number-plate/license_plate_recognizer.py

107 lines
3.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from PIL import Image
from ultralytics import YOLO
from torchvision import transforms
from model import CRNN
import cv2
from config import CTC_BLANK, IDX_TO_CHAR
# ------------------------------------------------------------
# Main class combining YOLO (plate detection) + CRNN (text recognition)
# ------------------------------------------------------------
class LicensePlateRecognizer:
    """Detect licence plates with YOLO and read their text with a CRNN.

    Attributes:
        yolo_model: detector built from the YOLO weights file.
        crnn_model: CRNN recognizer, moved to ``device`` and set to eval mode.
        transform: preprocessing applied to each grayscale plate crop.
        device: device the CRNN input tensors are sent to.
    """

    def __init__(self,
                 yolo_model_path: str,
                 crnn_model_path: str,
                 num_classes: int,
                 device: str = "cpu"):
        """Build both models and the CRNN preprocessing pipeline.

        Args:
            yolo_model_path: path to the YOLO weights (e.g. "yolo_plate.pt").
            crnn_model_path: path to the saved CRNN state dict
                (e.g. "best_accuracy_model_2.pth").
            num_classes: number of output classes the CRNN is built with.
            device: device the CRNN is moved to and its weights are mapped onto.
        """
        self.yolo_model = YOLO(yolo_model_path)

        self.crnn_model = CRNN(num_classes=num_classes).to(device)
        # NOTE(review): torch.load unpickles the checkpoint, so only point it
        # at trusted files (weights_only=True is safer on torch versions that
        # support it).
        crnn_weights = torch.load(crnn_model_path, map_location=device)
        self.crnn_model.load_state_dict(crnn_weights)
        self.crnn_model.eval()  # inference only

        # Resize to 32x128 (height x width), convert to a tensor, then map
        # pixel values from [0, 1] to [-1, 1] with a single-channel normalize.
        self.transform = transforms.Compose([
            transforms.Resize((32, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        self.device = device

    def detect_and_recognize_frame(self, frame, padding: int = 5):
        """Detect plates in one frame and recognize the text on each one.

        Args:
            frame: BGR image as an ``np.ndarray``.
            padding: pixels added on every side of each detected box before
                cropping; the padded region is clamped to the frame bounds.

        Returns:
            One dict per detection whose padded crop is non-empty::

                {"bbox": (x1, y1, x2, y2), "text": recognized_plate}

            ``bbox`` is the detector's box converted to ints, WITHOUT padding.
        """
        # NOTE(review): self.device is not forwarded to predict(), so the
        # detector presumably chooses its own device — confirm if that matters.
        results = self.yolo_model.predict(frame)
        if not results:
            return []

        plates = []
        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue
            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                crop = self._crop_with_padding(frame, (x1, y1, x2, y2), padding)
                if crop.size == 0:
                    # Nothing left after clamping the box to the frame.
                    continue
                plates.append({
                    "bbox": (x1, y1, x2, y2),
                    "text": self._read_plate_text(crop),
                })
        return plates

    @staticmethod
    def _crop_with_padding(frame, bbox, padding):
        """Return ``frame`` cut to ``bbox`` grown by ``padding``, clamped to the frame."""
        x1, y1, x2, y2 = bbox
        left = max(x1 - padding, 0)
        top = max(y1 - padding, 0)
        right = min(x2 + padding, frame.shape[1])
        bottom = min(y2 + padding, frame.shape[0])
        return frame[top:bottom, left:right]

    def _read_plate_text(self, plate_bgr):
        """Run the CRNN on one BGR plate crop; return the decoded text or ""."""
        plate_gray = cv2.cvtColor(plate_bgr, cv2.COLOR_BGR2GRAY)
        batch = self.transform(Image.fromarray(plate_gray)).unsqueeze(0)
        batch = batch.to(self.device)
        with torch.no_grad():
            logits = self.crnn_model(batch)
        texts = decode_predictions(logits)
        return texts[0] if texts else ""
# ------------------------------------------------------------
# Decoding of CRNN outputs
# ------------------------------------------------------------
def decode_predictions(preds, blank=CTC_BLANK):
    """Greedily decode CRNN scores into strings, one per batch item.

    At each time step the highest-scoring class is taken; consecutive
    repeats of the same index are collapsed and ``blank`` indices are
    dropped (greedy CTC decoding). Indices missing from ``IDX_TO_CHAR``
    contribute an empty string.

    Args:
        preds: tensor of class scores, expected as (seq_len, batch, num_classes).
        blank: class index treated as the CTC blank.

    Returns:
        A list of decoded strings, in batch order.
    """
    # (seq_len, batch, C) -> (seq_len, batch) -> (batch, seq_len) -> nested lists
    best_paths = preds.argmax(2).permute(1, 0).tolist()

    texts = []
    for path in best_paths:
        # Pair each step with the one before it; a virtual blank precedes
        # step 0, so a leading non-blank symbol is always kept.
        chars = [
            IDX_TO_CHAR.get(cur, '')
            for prev, cur in zip([blank] + path, path)
            if cur != prev and cur != blank
        ]
        texts.append(''.join(chars))
    return texts