107 lines
3.7 KiB
Python
107 lines
3.7 KiB
Python
import torch
|
||
from PIL import Image
|
||
from ultralytics import YOLO
|
||
from torchvision import transforms
|
||
from model import CRNN
|
||
import cv2
|
||
|
||
from config import CTC_BLANK, IDX_TO_CHAR
|
||
|
||
# ------------------------------------------------------------
# Main class combining YOLO + CRNN
# ------------------------------------------------------------
|
||
class LicensePlateRecognizer:
    """Two-stage plate reader: a YOLO model finds plates, a CRNN reads their text."""

    def __init__(self,
                 yolo_model_path: str,
                 crnn_model_path: str,
                 num_classes: int,
                 device: str = "cpu"):
        """
        Load both models and build the CRNN preprocessing pipeline.

        Args:
            yolo_model_path: path to the YOLO weights file (e.g. "yolo_plate.pt").
            crnn_model_path: path to the CRNN weights file
                (e.g. "best_accuracy_model_2.pth").
            num_classes: forwarded unchanged to ``CRNN(num_classes=...)``.
            device: device the CRNN and its input tensors are placed on.
        """
        self.device = device

        # NOTE(review): only the CRNN is moved to `device`; the YOLO model is
        # constructed without an explicit device -- confirm this is intended.
        self.yolo_model = YOLO(yolo_model_path)

        self.crnn_model = CRNN(num_classes=num_classes).to(device)
        # SECURITY NOTE: torch.load is pickle-based; only load checkpoints that
        # come from a trusted source (consider weights_only=True where the
        # installed torch version supports it).
        self.crnn_model.load_state_dict(torch.load(crnn_model_path, map_location=device))
        self.crnn_model.eval()

        # Crops are resized to 32x128, converted to a tensor and normalized
        # with mean 0.5 / std 0.5 on a single channel.
        self.transform = transforms.Compose([
            transforms.Resize((32, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def detect_and_recognize_frame(self, frame, padding: int = 5):
        """
        Detect plates in one frame and read the text on each of them.

        Args:
            frame: BGR image as a NumPy array, indexed as ``frame[y, x]``.
            padding: number of pixels added on every side of a detected box
                before cropping; the padded box is clamped to the frame.

        Returns:
            A list with one dict per detection whose crop is non-empty:
                {
                    "bbox": (x1, y1, x2, y2),   # unpadded box, integer pixels
                    "text": recognized_text
                }
            The list is empty when ``frame`` is None or zero-sized, or when
            the detector reports nothing.
        """
        detections_info = []

        # Nothing to detect in a missing or zero-sized frame; bail out before
        # calling the detector or touching frame.shape.
        if frame is None or frame.size == 0:
            return detections_info

        results = self.yolo_model.predict(frame)
        if not results:
            return detections_info

        frame_height, frame_width = frame.shape[:2]

        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue

            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())

                # Grow the box by `padding`, clamped to the frame borders.
                plate_crop_bgr = frame[
                    max(y1 - padding, 0):min(y2 + padding, frame_height),
                    max(x1 - padding, 0):min(x2 + padding, frame_width),
                ]
                if plate_crop_bgr.size == 0:
                    continue

                detections_info.append({
                    "bbox": (x1, y1, x2, y2),
                    "text": self._recognize_plate(plate_crop_bgr)
                })

        return detections_info

    def _recognize_plate(self, plate_crop_bgr):
        """Run the CRNN on a single BGR plate crop and return the decoded text."""
        # The CRNN pipeline works on a single-channel PIL image.
        plate_gray = cv2.cvtColor(plate_crop_bgr, cv2.COLOR_BGR2GRAY)
        plate_pil = Image.fromarray(plate_gray)

        # Add a batch dimension of 1 and move the input next to the model.
        plate_tensor = self.transform(plate_pil).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.crnn_model(plate_tensor)

        decoded_texts = decode_predictions(logits)
        return decoded_texts[0] if decoded_texts else ""
|
||
|
||
|
||
# ------------------------------------------------------------
# Decoding of CRNN outputs
# ------------------------------------------------------------
|
||
def decode_predictions(preds, blank=CTC_BLANK):
    """
    Greedy CTC decoding of CRNN output into strings.

    Picks the highest-scoring class at every time step, collapses runs of the
    same class, drops ``blank`` and maps the remaining indices to characters
    through ``IDX_TO_CHAR`` (indices missing from the table contribute nothing).

    Args:
        preds: tensor indexed as (seq_len, batch, num_classes).
        blank: class index treated as the CTC blank symbol.

    Returns:
        A list with one decoded string per batch element.
    """
    # (seq_len, batch, num_classes) -> (batch, seq_len) lists of class indices
    best_paths = preds.argmax(dim=2).transpose(0, 1).tolist()

    texts = []
    for path in best_paths:
        # Pair each step with its predecessor; the first step is compared
        # against `blank`, so a leading non-blank symbol is always kept.
        predecessors = [blank] + path[:-1]
        chars = [
            IDX_TO_CHAR.get(current, '')
            for before, current in zip(predecessors, path)
            if current != before and current != blank
        ]
        texts.append(''.join(chars))
    return texts