Initial commit: Add complete project files to the repository
This commit is contained in:
parent
74a8baefc9
commit
a4e6544db1
|
@ -0,0 +1,9 @@
|
|||
# Symbols the recognizer can emit: "-" followed by letters and digits.
ALPHABET = '-ABEKMHOPCTYX0123456789'

# Class index 0 is kept for the CTC blank, so alphabet symbols are numbered from 1.
CTC_BLANK = 0
IDX_TO_CHAR = dict(enumerate(ALPHABET, start=1))
CHAR_TO_IDX = {symbol: index for index, symbol in IDX_TO_CHAR.items()}

# One output class per alphabet symbol plus one for the blank.
NUM_CLASSES = 1 + len(ALPHABET)

# Paths to the model weight files.
YOLO_WEIGHTS_PATH = "models/yolo_plate.pt"
CRNN_WEIGHTS_PATH = "models/best_accuracy_model_3.pth"
|
|
@ -0,0 +1,107 @@
|
|||
import torch
|
||||
from PIL import Image
|
||||
from ultralytics import YOLO
|
||||
from torchvision import transforms
|
||||
from model import CRNN
|
||||
import cv2
|
||||
|
||||
from config import CTC_BLANK, IDX_TO_CHAR
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Основной класс, объединяющий YOLO + CRNN
|
||||
# ------------------------------------------------------------
|
||||
class LicensePlateRecognizer:
    """Two-stage plate reader: a YOLO detector finds plates, a CRNN reads their text."""

    def __init__(self,
                 yolo_model_path: str,
                 crnn_model_path: str,
                 num_classes: int,
                 device: str = "cpu"):
        """
        Load both models and build the CRNN preprocessing pipeline.

        yolo_model_path: path to the YOLO weights file (e.g. "yolo_plate.pt")
        crnn_model_path: path to the CRNN weights file holding a state dict
                         (e.g. "best_accuracy_model_2.pth")
        num_classes: number of CRNN output classes
        device: torch device string; used for the CRNN and passed to YOLO
        """
        self.device = device
        self.yolo_model = YOLO(yolo_model_path)

        self.crnn_model = CRNN(num_classes=num_classes).to(device)
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a tampered checkpoint file cannot execute arbitrary code on load.
        state_dict = torch.load(crnn_model_path, map_location=device, weights_only=True)
        self.crnn_model.load_state_dict(state_dict)
        self.crnn_model.eval()

        # Resize to 32x128, convert to a tensor, then scale values to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((32, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def detect_and_recognize_frame(self, frame, padding: int = 5):
        """
        Detect plates in a single frame (BGR, np.array) and read each one.

        frame: image to process
        padding: pixels added around each detected box before cropping

        Returns a list of dicts:
            {
                "bbox": (x1, y1, x2, y2),
                "text": recognized_plate
            }
        "bbox" is the detector's box without padding; the padding only
        widens the crop that is handed to the CRNN.
        """
        detections_info = []

        # Run the detector on the device requested in __init__.
        results = self.yolo_model.predict(frame, device=self.device)
        if not results:
            return detections_info

        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue

            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())

                left, top, right, bottom = self._padded_bounds(
                    (x1, y1, x2, y2), padding, frame.shape[1], frame.shape[0]
                )
                plate_crop_bgr = frame[top:bottom, left:right]
                if plate_crop_bgr.size == 0:
                    # Degenerate box: nothing to read.
                    continue

                detections_info.append({
                    "bbox": (x1, y1, x2, y2),
                    "text": self._read_plate(plate_crop_bgr),
                })

        return detections_info

    @staticmethod
    def _padded_bounds(bbox, padding, frame_width, frame_height):
        """Grow bbox by padding on every side, clamped to the frame limits."""
        x1, y1, x2, y2 = bbox
        return (
            max(x1 - padding, 0),
            max(y1 - padding, 0),
            min(x2 + padding, frame_width),
            min(y2 + padding, frame_height),
        )

    def _read_plate(self, plate_crop_bgr):
        """Run the CRNN on one BGR plate crop and return the decoded text."""
        # Convert to a grayscale PIL image, which the transform pipeline consumes.
        plate_gray = cv2.cvtColor(plate_crop_bgr, cv2.COLOR_BGR2GRAY)
        plate_pil = Image.fromarray(plate_gray)

        # Add a batch dimension and move the input to the model's device.
        plate_tensor = self.transform(plate_pil).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.crnn_model(plate_tensor)

        decoded_texts = decode_predictions(logits)
        return decoded_texts[0] if decoded_texts else ""
|
||||
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Расшифровка результатов CRNN
|
||||
# ------------------------------------------------------------
|
||||
def decode_predictions(preds, blank=CTC_BLANK):
    """
    Greedy CTC decoding of CRNN output.

    Takes the highest-scoring class at every step, merges consecutive
    repeats and drops blanks.

    preds: scores shaped (seq_len, batch, num_classes)
    blank: index of the CTC blank class
    Returns a list with one decoded string per batch item.
    """
    # (seq_len, batch, num_classes) -> (batch, seq_len) of winning class indices
    best_path = preds.argmax(2).permute(1, 0)

    texts = []
    for row in best_path:
        steps = row.tolist()
        # Pair every step with its predecessor; the first step is compared to blank.
        shifted = [blank] + steps[:-1]
        symbols = [
            IDX_TO_CHAR.get(current, '')
            for before, current in zip(shifted, steps)
            if current != before and current != blank
        ]
        texts.append(''.join(symbols))
    return texts
|
|
@ -0,0 +1,69 @@
|
|||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from config import NUM_CLASSES, CRNN_WEIGHTS_PATH, YOLO_WEIGHTS_PATH
|
||||
from license_plate_recognizer import LicensePlateRecognizer
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Запуск в режиме реального времени (веб-камера)
|
||||
# ------------------------------------------------------------
|
||||
def main():
    """Run live plate recognition on the default webcam until Esc or 'q' is pressed.

    Each frame is shown with detected plates outlined and labelled, next to a
    black side panel that lists the recognized texts.
    """
    lpr = LicensePlateRecognizer(
        yolo_model_path=YOLO_WEIGHTS_PATH,
        crnn_model_path=CRNN_WEIGHTS_PATH,
        num_classes=NUM_CLASSES,
        device="cpu",
    )

    cap = cv2.VideoCapture(0)

    # The capture resolution can be lowered here if needed, e.g.:
    # cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    # cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

    # try/finally guarantees the camera and the window are released even if
    # detection or drawing raises in the middle of the loop.
    try:
        if not cap.isOpened():
            print("Не удалось открыть веб-камеру.")
            return

        while True:
            ret, frame = cap.read()
            if not ret:
                print("Не удалось считать кадр с веб-камеры.")
                break

            detections = lpr.detect_and_recognize_frame(frame, padding=5)

            # Outline every plate and write its text just above the box.
            for det in detections:
                x1, y1, x2, y2 = det["bbox"]
                text = det["text"]

                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(
                    frame, text, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.7, (0, 255, 0), 2
                )

            # Side panel with one text line per detection, 40 px apart.
            height, width, _ = frame.shape
            black_bar_width = 300
            black_bar = np.zeros((height, black_bar_width, 3), dtype=np.uint8)

            for i, det in enumerate(detections):
                cv2.putText(
                    black_bar,
                    f"Plate #{i+1}: {det['text']}",
                    (10, 40 + 40 * i),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.7, (255, 255, 255), 2
                )

            display_frame = np.hstack((frame, black_bar))
            cv2.imshow("License Plate Recognition", display_frame)

            # 27 is the Esc key code.
            key = cv2.waitKey(1) & 0xFF
            if key == 27 or key == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
|
|
@ -0,0 +1,56 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# CRNN-модель
|
||||
# ------------------------------------------------------------
|
||||
class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-column class scores.

    A CNN stack extracts features from a 1-channel image, each feature-map
    column becomes one sequence step, and a bidirectional LSTM followed by a
    linear layer scores every step over num_classes classes.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv3x3(in_channels, out_channels):
            # 3x3 convolution with padding 1.
            return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        # CNN part. Layer order is fixed: state_dict keys depend on the indices.
        self.cnn = nn.Sequential(
            conv3x3(1, 64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            conv3x3(64, 128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            conv3x3(128, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),

            conv3x3(256, 256),
            nn.ReLU(inplace=True),
            # Pool over height only, so the width (sequence length) is kept.
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),

            conv3x3(256, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),

            conv3x3(512, 512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
        )

        # RNN part
        self.linear1 = nn.Linear(512 * 2, 256)
        self.relu = nn.ReLU(inplace=True)
        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
        self.linear2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map x of shape (batch, 1, 32, 128) to scores of shape (32, batch, num_classes)."""
        features = self.cnn(x)                       # (batch, 512, 2, 32)

        # Put width first, then merge channels and height into one feature vector.
        columns = features.permute(0, 3, 1, 2)       # (batch, 32, 512, 2)
        columns = columns.flatten(start_dim=2)       # (batch, 32, 512*2)

        hidden = self.relu(self.linear1(columns))    # (batch, 32, 256)
        hidden, _ = self.lstm(hidden)                # (batch, 32, 512), bidirectional
        scores = self.linear2(hidden)                # (batch, 32, num_classes)

        # Sequence-first layout: (32, batch, num_classes)
        return scores.permute(1, 0, 2)
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,11 @@
|
|||
--index-url https://pypi.org/simple
|
||||
numpy==1.26.3
|
||||
pillow==10.2.0
|
||||
opencv-python==4.10.0.84
|
||||
scikit-learn==1.5.2
|
||||
scipy==1.13.1
|
||||
matplotlib==3.9.2
|
||||
tqdm
|
||||
ultralytics
|
||||
torch==2.0.1
|
||||
torchvision==0.15.2
|
Loading…
Reference in New Issue