# Source: number-plate/license_plate_recognizer.py
# (Python; 107 lines, 3.7 KiB as reported by the repository viewer at export time.)

import torch
from PIL import Image
from ultralytics import YOLO
from torchvision import transforms
from model import CRNN
import cv2
from config import CTC_BLANK, IDX_TO_CHAR
# ------------------------------------------------------------
# Main class combining YOLO (plate detection) + CRNN (text recognition)
# ------------------------------------------------------------
class LicensePlateRecognizer:
    """Two-stage license plate reader.

    A YOLO detector locates plates in a frame; each detected region is
    cropped, converted to grayscale and passed to a CRNN whose output is
    decoded into text with ``decode_predictions``.
    """

    def __init__(self,
                 yolo_model_path: str,
                 crnn_model_path: str,
                 num_classes: int,
                 device: str = "cpu"):
        """Load both models and build the CRNN preprocessing pipeline.

        Args:
            yolo_model_path: Path to the YOLO weights file
                (e.g. "yolo_plate.pt").
            crnn_model_path: Path to the CRNN weights file
                (e.g. "best_accuracy_model_2.pth").
            num_classes: Forwarded to ``CRNN(num_classes=...)``.
            device: Torch device string. The CRNN is moved to it and every
                plate tensor is sent to it before inference.
        """
        self.device = device
        self.yolo_model = YOLO(yolo_model_path)
        self.crnn_model = CRNN(num_classes=num_classes).to(device)
        # SECURITY NOTE: torch.load unpickles the checkpoint, so only load
        # weight files from a trusted source. Where the installed PyTorch
        # supports it, consider passing weights_only=True.
        self.crnn_model.load_state_dict(
            torch.load(crnn_model_path, map_location=device)
        )
        # Inference only: switch the CRNN to evaluation mode.
        self.crnn_model.eval()
        # Preprocessing applied to every grayscale plate crop:
        # resize to 32x128, convert to a tensor, normalize with mean/std 0.5.
        self.transform = transforms.Compose([
            transforms.Resize((32, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def detect_and_recognize_frame(self, frame, padding: int = 5):
        """Detect plates in a single frame and read their text.

        Args:
            frame: Image as a BGR ``np.array``. ``None`` or an empty array
                (e.g. a failed capture read) yields an empty result instead
                of being forwarded to the detector.
            padding: Number of pixels added on every side of a detected box
                before cropping; the crop is clamped to the frame borders.

        Returns:
            A list of dictionaries, one per successfully cropped detection::

                {
                    "bbox": (x1, y1, x2, y2),   # original, unpadded box
                    "text": recognized_plate_text
                }
        """
        detections_info = []

        # Guard against missing/empty frames. Without it, None would be
        # handed to YOLO (which, per the Ultralytics docs, does not treat a
        # missing source as an error) and `frame.shape` would fail later.
        if frame is None or frame.size == 0:
            return detections_info

        results = self.yolo_model.predict(frame)
        if not results:
            return detections_info

        # Frame bounds are loop-invariant; read them once.
        frame_height, frame_width = frame.shape[:2]

        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue

            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())

                # Expand the box by `padding`, clamped to the frame.
                left = max(x1 - padding, 0)
                top = max(y1 - padding, 0)
                right = min(x2 + padding, frame_width)
                bottom = min(y2 + padding, frame_height)

                plate_crop_bgr = frame[top:bottom, left:right]
                if plate_crop_bgr.size == 0:
                    # Degenerate box: nothing to recognize.
                    continue

                detections_info.append({
                    "bbox": (x1, y1, x2, y2),
                    "text": self._recognize_plate(plate_crop_bgr)
                })

        return detections_info

    def _recognize_plate(self, plate_crop_bgr):
        """Run the CRNN on one BGR plate crop and return the decoded text."""
        # The CRNN pipeline works on a grayscale PIL image.
        plate_gray = cv2.cvtColor(plate_crop_bgr, cv2.COLOR_BGR2GRAY)
        plate_pil = Image.fromarray(plate_gray)

        # Add a batch dimension and move the tensor to the model's device.
        plate_tensor = self.transform(plate_pil).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.crnn_model(plate_tensor)

        decoded_texts = decode_predictions(logits)
        return decoded_texts[0] if decoded_texts else ""
# ------------------------------------------------------------
# Decoding of the CRNN output
# ------------------------------------------------------------
def decode_predictions(preds, blank=CTC_BLANK):
    """Greedily decode CRNN outputs into strings (CTC best-path decoding).

    Args:
        preds: Tensor laid out as (seq_len, batch, num_classes).
        blank: Class index of the CTC blank symbol.

    Returns:
        A list with one decoded string per batch element. Consecutive
        repeats of the same index are collapsed, blanks are dropped, and
        indices missing from ``IDX_TO_CHAR`` contribute an empty string.
    """
    # Best class per time step, rearranged to (batch, seq_len).
    best_paths = preds.argmax(2).permute(1, 0)

    texts = []
    for path in best_paths:
        indices = path.tolist()
        # Pair each index with its predecessor; the first step is compared
        # against `blank`, so a leading non-blank symbol is always kept.
        predecessors = [blank] + indices[:-1]
        chars = [
            IDX_TO_CHAR.get(current, '')
            for before, current in zip(predecessors, indices)
            if current != before and current != blank
        ]
        texts.append(''.join(chars))
    return texts