107 lines
3.7 KiB
Python
107 lines
3.7 KiB
Python
|
import torch
|
|||
|
from PIL import Image
|
|||
|
from ultralytics import YOLO
|
|||
|
from torchvision import transforms
|
|||
|
from model import CRNN
|
|||
|
import cv2
|
|||
|
|
|||
|
from config import CTC_BLANK, IDX_TO_CHAR
|
|||
|
|
|||
|
# ------------------------------------------------------------
|
|||
|
# Основной класс, объединяющий YOLO + CRNN
|
|||
|
# ------------------------------------------------------------
|
|||
|
class LicensePlateRecognizer:
    """License-plate pipeline: a YOLO detector followed by a CRNN text recognizer."""

    def __init__(self,
                 yolo_model_path: str,
                 crnn_model_path: str,
                 num_classes: int,
                 device: str = "cpu"):
        """
        Load both models and build the preprocessing pipeline for the CRNN.

        yolo_model_path: path to the YOLO weights file (e.g. "yolo_plate.pt")
        crnn_model_path: path to the CRNN weights file (e.g. "best_accuracy_model_2.pth")
        num_classes: forwarded to the CRNN constructor as ``num_classes``
        device: torch device string; the CRNN, its checkpoint tensors and its
            input tensors are placed on this device
        """
        self.device = device

        self.yolo_model = YOLO(yolo_model_path)

        self.crnn_model = CRNN(num_classes=num_classes).to(device)
        # NOTE(review): torch.load unpickles the file, so only checkpoints from a
        # trusted source should be loaded here. Consider weights_only=True if the
        # installed torch version supports it and the file is a plain state dict.
        state_dict = torch.load(crnn_model_path, map_location=device)
        self.crnn_model.load_state_dict(state_dict)
        # Inference only: switch the CRNN to evaluation mode.
        self.crnn_model.eval()

        # Preprocessing applied to every grayscale plate crop before the CRNN:
        # fixed (32, 128) resize, tensor conversion, single-channel normalization.
        self.transform = transforms.Compose([
            transforms.Resize((32, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def detect_and_recognize_frame(self, frame, padding: int = 5):
        """
        Detect plates in a single frame and read the text of each one.

        frame: image as a BGR np.array; it is passed to YOLO as-is
        padding: number of pixels added on every side of a detected box before
            cropping; the padded box is clipped to the frame borders

        Returns a list of dicts, one per detection whose padded crop is not empty:
            {
                "bbox": (x1, y1, x2, y2),   # detector box, without padding
                "text": recognized plate string
            }
        """
        detections_info = []

        results = self.yolo_model.predict(frame)
        if not results:
            return detections_info

        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue

            frame_h, frame_w = frame.shape[:2]

            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())

                # Grow the box by `padding` but never leave the frame.
                left = max(x1 - padding, 0)
                top = max(y1 - padding, 0)
                right = min(x2 + padding, frame_w)
                bottom = min(y2 + padding, frame_h)

                plate_crop_bgr = frame[top:bottom, left:right]
                if plate_crop_bgr.size == 0:
                    # Degenerate box (nothing left after clipping): skip it.
                    continue

                detections_info.append({
                    "bbox": (x1, y1, x2, y2),
                    "text": self._recognize_crop(plate_crop_bgr)
                })

        return detections_info

    def _recognize_crop(self, plate_crop_bgr):
        """Run the CRNN on one BGR plate crop and return the decoded string ("" if none)."""
        # The CRNN pipeline works on a grayscale PIL image.
        plate_gray = cv2.cvtColor(plate_crop_bgr, cv2.COLOR_BGR2GRAY)
        plate_pil = Image.fromarray(plate_gray)

        # Add a leading batch axis of size 1 and move the tensor to the CRNN device.
        plate_tensor = self.transform(plate_pil).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.crnn_model(plate_tensor)

        decoded_texts = decode_predictions(logits)
        return decoded_texts[0] if len(decoded_texts) > 0 else ""
|
|||
|
|
|||
|
|
|||
|
# ------------------------------------------------------------
|
|||
|
# Расшифровка результатов CRNN
|
|||
|
# ------------------------------------------------------------
|
|||
|
def decode_predictions(preds, blank=CTC_BLANK):
    """
    Greedy CTC decoding of CRNN outputs.

    preds: tensor laid out as (seq_len, batch, num_classes)
    blank: class index treated as the CTC blank symbol

    Returns a list with one decoded string per batch item. For each item the
    highest-scoring class is taken at every time step, runs of the same class
    are collapsed, blanks are dropped, and the remaining indices are mapped
    through IDX_TO_CHAR (indices missing from the table contribute nothing).
    """
    # Best class per time step, then batch-first: (batch, seq_len).
    best_paths = preds.argmax(2).permute(1, 0)

    texts = []
    for path in best_paths:
        labels = path.tolist()
        # Pair each label with the one before it; the first label is paired
        # with `blank`, so a leading non-blank label is always kept.
        predecessors = [blank] + labels[:-1]
        kept_chars = [
            IDX_TO_CHAR.get(label, '')
            for before, label in zip(predecessors, labels)
            if label != blank and label != before
        ]
        texts.append(''.join(kept_chars))

    return texts
|