import torch
from PIL import Image
from ultralytics import YOLO
from torchvision import transforms
from model import CRNN
import cv2
from config import CTC_BLANK, IDX_TO_CHAR


# ------------------------------------------------------------
# Main class combining YOLO (plate detection) + CRNN (plate text recognition)
# ------------------------------------------------------------
class LicensePlateRecognizer:
    """Detect license plates with a YOLO model and read their text with a CRNN."""

    def __init__(self, yolo_model_path: str, crnn_model_path: str, num_classes: int, device: str = "cpu"):
        """
        yolo_model_path: path to the YOLO weights file (e.g. "yolo_plate.pt")
        crnn_model_path: path to the CRNN state_dict file (e.g. "best_accuracy_model_2.pth")
        num_classes: number of output classes passed to ``CRNN``
        device: torch device the CRNN is moved to and its weights are mapped onto.

        NOTE(review): ``device`` is only applied to the CRNN; the YOLO model is
        not given a device here, so it picks its own -- confirm this is intended.
        """
        self.yolo_model = YOLO(yolo_model_path)
        self.crnn_model = CRNN(num_classes=num_classes).to(device)
        # NOTE(security): torch.load unpickles the file. On torch versions where
        # `weights_only` defaults to False this can execute arbitrary code, so only
        # load trusted checkpoints (or pass weights_only=True where supported).
        self.crnn_model.load_state_dict(torch.load(crnn_model_path, map_location=device))
        self.crnn_model.eval()
        # Single-channel pipeline: the crop is converted to grayscale before this,
        # hence the one-element mean/std in Normalize.
        self.transform = transforms.Compose([
            transforms.Resize((32, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        self.device = device

    def detect_and_recognize_frame(self, frame, padding: int = 5):
        """
        Detect plates in a frame and recognize the text on each one.

        frame: a BGR image as a numpy array (as produced by OpenCV), passed
            directly to YOLO.
        padding: number of pixels added on every side of each detected box
            before cropping for the CRNN (clamped to the frame bounds).

        Returns a list of dicts, one per detection with a non-empty crop:
            { "bbox": (x1, y1, x2, y2), "text": recognized_plate_text }
        where "bbox" is the unpadded YOLO box, truncated to ints.
        """
        results = self.yolo_model.predict(frame)
        detections_info = []

        if not results:
            return detections_info

        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue

            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                plate_crop_bgr = self._crop_with_padding(frame, (x1, y1, x2, y2), padding)
                # A box lying entirely outside the frame yields an empty slice.
                if plate_crop_bgr.size == 0:
                    continue

                detections_info.append({
                    "bbox": (x1, y1, x2, y2),
                    "text": self._recognize_crop(plate_crop_bgr)
                })

        return detections_info

    @staticmethod
    def _crop_with_padding(frame, bbox, padding: int):
        """Return the region of ``frame`` covered by ``bbox`` grown by ``padding`` px, clamped to the frame."""
        x1, y1, x2, y2 = bbox
        x1_padded = max(x1 - padding, 0)
        y1_padded = max(y1 - padding, 0)
        x2_padded = min(x2 + padding, frame.shape[1])
        y2_padded = min(y2 + padding, frame.shape[0])
        return frame[y1_padded:y2_padded, x1_padded:x2_padded]

    def _recognize_crop(self, plate_crop_bgr) -> str:
        """Run the CRNN on a single BGR plate crop and return the decoded text ("" if nothing decoded)."""
        # Convert to a grayscale PIL image, matching the single-channel transform.
        plate_gray = cv2.cvtColor(plate_crop_bgr, cv2.COLOR_BGR2GRAY)
        plate_pil = Image.fromarray(plate_gray)

        # Prepare a batch of one for the CRNN.
        plate_tensor = self.transform(plate_pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.crnn_model(plate_tensor)

        decoded_texts = decode_predictions(logits)
        return decoded_texts[0] if decoded_texts else ""


# ------------------------------------------------------------
# Decoding of CRNN outputs
# ------------------------------------------------------------
def decode_predictions(preds, blank=CTC_BLANK):
    """
    Greedy (best-path) CTC decoding of CRNN logits.

    preds: logits expected as (seq_len, batch, num_classes).
    blank: class index treated as the CTC blank.

    For every batch item, takes the argmax class per time step, collapses
    consecutive repeats, drops blanks, and maps the remaining indices through
    ``IDX_TO_CHAR`` (unknown indices map to ""). Returns one string per batch item.
    """
    # (seq_len, batch) -> (batch, seq_len), then to plain Python ints in one call.
    best_paths = preds.argmax(2).permute(1, 0).tolist()

    decoded = []
    for path in best_paths:
        chars = []
        previous = blank
        for idx in path:
            # A repeat is only emitted again if a blank (or another class) separated it.
            if idx != previous and idx != blank:
                chars.append(IDX_TO_CHAR.get(idx, ''))
            previous = idx
        decoded.append(''.join(chars))
    return decoded