Initial commit: Add complete project files to the repository
This commit is contained in:
parent
74a8baefc9
commit
a4e6544db1
|
@ -0,0 +1,9 @@
|
||||||
|
ALPHABET = '-ABEKMHOPCTYX0123456789'
|
||||||
|
CHAR_TO_IDX = {char: idx + 1 for idx, char in enumerate(ALPHABET)}
|
||||||
|
IDX_TO_CHAR = {idx + 1: char for idx, char in enumerate(ALPHABET)}
|
||||||
|
|
||||||
|
NUM_CLASSES = len(ALPHABET) + 1
|
||||||
|
CTC_BLANK = 0
|
||||||
|
|
||||||
|
YOLO_WEIGHTS_PATH = "models/yolo_plate.pt"
|
||||||
|
CRNN_WEIGHTS_PATH = "models/best_accuracy_model_3.pth"
|
|
@ -0,0 +1,107 @@
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from torchvision import transforms
|
||||||
|
from model import CRNN
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from config import CTC_BLANK, IDX_TO_CHAR
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Основной класс, объединяющий YOLO + CRNN
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
class LicensePlateRecognizer:
|
||||||
|
def __init__(self,
|
||||||
|
yolo_model_path: str,
|
||||||
|
crnn_model_path: str,
|
||||||
|
num_classes: int,
|
||||||
|
device: str = "cpu"):
|
||||||
|
"""
|
||||||
|
yolo_model_path: путь к файлу весов YOLO (напр. "yolo_plate.pt")
|
||||||
|
crnn_model_path: путь к файлу весов CRNN (напр. "best_accuracy_model_2.pth")
|
||||||
|
"""
|
||||||
|
self.yolo_model = YOLO(yolo_model_path)
|
||||||
|
|
||||||
|
self.crnn_model = CRNN(num_classes=num_classes).to(device)
|
||||||
|
self.crnn_model.load_state_dict(torch.load(crnn_model_path, map_location=device))
|
||||||
|
self.crnn_model.eval()
|
||||||
|
|
||||||
|
self.transform = transforms.Compose([
|
||||||
|
transforms.Resize((32, 128)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5,), (0.5,))
|
||||||
|
])
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def detect_and_recognize_frame(self, frame, padding: int = 5):
|
||||||
|
"""
|
||||||
|
Принимает кадр (BGR, np.array) напрямую,
|
||||||
|
возвращает список словарей:
|
||||||
|
{
|
||||||
|
"bbox": (x1, y1, x2, y2),
|
||||||
|
"text": распознанный_номер
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
results = self.yolo_model.predict(frame)
|
||||||
|
detections_info = []
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return detections_info
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
boxes = result.boxes
|
||||||
|
if boxes is None or len(boxes) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for box in boxes:
|
||||||
|
x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
|
||||||
|
|
||||||
|
x1_padded = max(x1 - padding, 0)
|
||||||
|
y1_padded = max(y1 - padding, 0)
|
||||||
|
x2_padded = min(x2 + padding, frame.shape[1])
|
||||||
|
y2_padded = min(y2 + padding, frame.shape[0])
|
||||||
|
|
||||||
|
plate_crop_bgr = frame[y1_padded:y2_padded, x1_padded:x2_padded]
|
||||||
|
if plate_crop_bgr.size == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Конвертация в PIL (grayscale)
|
||||||
|
plate_gray = cv2.cvtColor(plate_crop_bgr, cv2.COLOR_BGR2GRAY)
|
||||||
|
plate_pil = Image.fromarray(plate_gray)
|
||||||
|
|
||||||
|
# Подготовка к CRNN
|
||||||
|
plate_tensor = self.transform(plate_pil).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = self.crnn_model(plate_tensor)
|
||||||
|
decoded_texts = decode_predictions(logits)
|
||||||
|
recognized_text = decoded_texts[0] if len(decoded_texts) > 0 else ""
|
||||||
|
|
||||||
|
detections_info.append({
|
||||||
|
"bbox": (x1, y1, x2, y2),
|
||||||
|
"text": recognized_text
|
||||||
|
})
|
||||||
|
|
||||||
|
return detections_info
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Расшифровка результатов CRNN
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
def decode_predictions(preds, blank=CTC_BLANK):
|
||||||
|
# preds: (seq_len, batch, num_classes)
|
||||||
|
preds = preds.argmax(2) # (seq_len, batch)
|
||||||
|
preds = preds.permute(1, 0) # (batch, seq_len)
|
||||||
|
|
||||||
|
decoded = []
|
||||||
|
for pred in preds:
|
||||||
|
pred = pred.tolist()
|
||||||
|
decoded_seq = []
|
||||||
|
previous = blank
|
||||||
|
for p in pred:
|
||||||
|
if p != previous and p != blank:
|
||||||
|
decoded_seq.append(IDX_TO_CHAR.get(p, ''))
|
||||||
|
previous = p
|
||||||
|
decoded.append(''.join(decoded_seq))
|
||||||
|
return decoded
|
|
@ -0,0 +1,69 @@
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from config import NUM_CLASSES, CRNN_WEIGHTS_PATH, YOLO_WEIGHTS_PATH
|
||||||
|
from license_plate_recognizer import LicensePlateRecognizer
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Запуск в режиме реального времени (веб-камера)
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
if __name__ == "__main__":
|
||||||
|
lpr = LicensePlateRecognizer(
|
||||||
|
yolo_model_path=YOLO_WEIGHTS_PATH,
|
||||||
|
crnn_model_path=CRNN_WEIGHTS_PATH,
|
||||||
|
num_classes=NUM_CLASSES,
|
||||||
|
device="cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
cap = cv2.VideoCapture(0)
|
||||||
|
|
||||||
|
# cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
|
||||||
|
# cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
print("Не удалось считать кадр с веб-камеры.")
|
||||||
|
break
|
||||||
|
|
||||||
|
detections = lpr.detect_and_recognize_frame(frame, padding=5)
|
||||||
|
|
||||||
|
for det in detections:
|
||||||
|
x1, y1, x2, y2 = det["bbox"]
|
||||||
|
text = det["text"]
|
||||||
|
|
||||||
|
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||||
|
|
||||||
|
cv2.putText(
|
||||||
|
frame, text, (x1, y1 - 10),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.7, (0, 255, 0), 2
|
||||||
|
)
|
||||||
|
|
||||||
|
height, width, _ = frame.shape
|
||||||
|
black_bar_width = 300
|
||||||
|
black_bar = np.zeros((height, black_bar_width, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
y_start = 40
|
||||||
|
for i, det in enumerate(detections):
|
||||||
|
txt = det["text"]
|
||||||
|
cv2.putText(
|
||||||
|
black_bar,
|
||||||
|
f"Plate #{i+1}: {txt}",
|
||||||
|
(10, y_start),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.7, (255, 255, 255), 2
|
||||||
|
)
|
||||||
|
y_start += 40
|
||||||
|
|
||||||
|
display_frame = np.hstack((frame, black_bar))
|
||||||
|
|
||||||
|
cv2.imshow("License Plate Recognition", display_frame)
|
||||||
|
|
||||||
|
key = cv2.waitKey(1) & 0xFF
|
||||||
|
if key == 27 or key == ord('q'):
|
||||||
|
break
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
|
@ -0,0 +1,56 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# CRNN-модель
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
class CRNN(nn.Module):
|
||||||
|
def __init__(self, num_classes):
|
||||||
|
super(CRNN, self).__init__()
|
||||||
|
|
||||||
|
# CNN часть
|
||||||
|
self.cnn = nn.Sequential(
|
||||||
|
nn.Conv2d(1, 64, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
|
||||||
|
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
|
||||||
|
nn.Conv2d(128, 256, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.BatchNorm2d(256),
|
||||||
|
|
||||||
|
nn.Conv2d(256, 256, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d((2,1), (2,1)),
|
||||||
|
|
||||||
|
nn.Conv2d(256, 512, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.BatchNorm2d(512),
|
||||||
|
|
||||||
|
nn.Conv2d(512, 512, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d((2,1), (2,1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# RNN часть
|
||||||
|
self.linear1 = nn.Linear(512 * 2, 256)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
|
||||||
|
self.linear2 = nn.Linear(512, num_classes)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x: (batch, 1, 32, 128) — после Resize
|
||||||
|
conv = self.cnn(x) # (batch, 512, 2, 32)
|
||||||
|
conv = conv.permute(0, 3, 1, 2) # (batch, width=32, channels=512, height=2)
|
||||||
|
conv = conv.view(conv.size(0), conv.size(1), -1) # (batch, 32, 512*2)
|
||||||
|
|
||||||
|
out = self.linear1(conv) # (batch, 32, 256)
|
||||||
|
out = self.relu(out) # (batch, 32, 256)
|
||||||
|
out, _ = self.lstm(out) # (batch, 32, 512) — bidirectional
|
||||||
|
out = self.linear2(out) # (batch, 32, num_classes)
|
||||||
|
|
||||||
|
out = out.permute(1, 0, 2) # (32, batch, num_classes)
|
||||||
|
return out
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,11 @@
|
||||||
|
--index-url https://pypi.org/simple
|
||||||
|
numpy==1.26.3
|
||||||
|
pillow==10.2.0
|
||||||
|
opencv-python==4.10.0.84
|
||||||
|
scikit-learn==1.5.2
|
||||||
|
scipy==1.13.1
|
||||||
|
matplotlib==3.9.2
|
||||||
|
tqdm
|
||||||
|
ultralytics
|
||||||
|
torch==2.0.1
|
||||||
|
torchvision==0.15.2
|
Loading…
Reference in New Issue