{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Используемое устройство: cuda\n" ] } ], "source": [ "import os\n", "import string\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision import transforms\n", "from PIL import Image\n", "from tqdm import tqdm\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(f\"Используемое устройство: {device}\")\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "ALPHABET = '-ABEKMHOPCTYX0123456789'\n", "CHAR_TO_IDX = {char: idx + 1 for idx, char in enumerate(ALPHABET)} # 0 будет использоваться для CTC blank\n", "IDX_TO_CHAR = {idx + 1: char for idx, char in enumerate(ALPHABET)}\n", "\n", "NUM_CLASSES = len(ALPHABET) + 1\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class LicensePlateDataset(Dataset):\n", " def __init__(self, root_dir, transform=None):\n", " \"\"\"\n", " Args:\n", " root_dir (string): Путь к директории с данными (train, val, test)\n", " transform (callable, optional): Трансформации для изображений\n", " \"\"\"\n", " self.root_dir = root_dir\n", " self.img_dir = os.path.join(root_dir, 'img')\n", " self.transform = transform\n", " self.images = [img for img in os.listdir(self.img_dir) if img.endswith(('.png', '.jpg', '.jpeg'))]\n", " \n", " def __len__(self):\n", " return len(self.images)\n", " \n", " def __getitem__(self, idx):\n", " img_name = self.images[idx]\n", " img_path = os.path.join(self.img_dir, img_name)\n", " image = Image.open(img_path).convert('L') \n", " if self.transform:\n", " image = self.transform(image)\n", " \n", " label_str = os.path.splitext(img_name)[0].upper()\n", " label = [CHAR_TO_IDX[char] for char in label_str if char in CHAR_TO_IDX]\n", " label = torch.tensor(label, dtype=torch.long)\n", " \n", " return image, label\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "transform = transforms.Compose([\n", " transforms.Resize((32, 128)),\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.5,), (0.5,))\n", "])\n", "\n", "import random\n", "import torchvision.transforms.functional as F\n", "from torchvision import transforms\n", "\n", "def random_padding(img):\n", " pad_left = random.randint(5, 15)\n", " pad_right = random.randint(5, 15)\n", " pad_top = random.randint(5, 15)\n", " pad_bottom = random.randint(5, 15)\n", " return F.pad(img, (pad_left, pad_top, pad_right, pad_bottom), fill=0)\n", "\n", "# Обновленный transform\n", "transform_train = transforms.Compose([\n", " transforms.Lambda(random_padding),\n", " transforms.Resize((32, 128)), # Изменяем размер изображения\n", " transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), # Сдвиг, масштабирование, поворот\n", " transforms.ColorJitter(brightness=0.2, contrast=0.2), # Изменение яркости и контраста\n", " transforms.RandomPerspective(distortion_scale=0.2, p=0.5), # Перспективные искажения\n", " transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), # Размытие\n", " transforms.ToTensor(), # Преобразуем в тензор\n", " transforms.Normalize((0.5,), (0.5,)), # Нормализация\n", "])\n", "\n", "def collate_fn(batch):\n", " images, labels = zip(*batch)\n", " images = torch.stack(images, 0)\n", " \n", " # Соединяем все метки в один тензор\n", " label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)\n", " labels_concat = torch.cat(labels)\n", " \n", " return images, labels_concat, label_lengths\n" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "train_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/train', transform=transform_train)\n", "val_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/val', transform=transform)\n", "test_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/test', transform=transform)\n", "test_10 = LicensePlateDataset(root_dir=r'dataset-ocr\\fine-tune-val', transform=transform)\n", "\n", "batch_size = 64\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=lambda x: collate_fn(x))\n", "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n", "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "class CRNN(nn.Module):\n", " def __init__(self, num_classes):\n", " super(CRNN, self).__init__()\n", " \n", " # CNN часть\n", " self.cnn = nn.Sequential(\n", " nn.Conv2d(1, 64, kernel_size=3, padding=1), # (batch, 64, 32, 128)\n", " nn.ReLU(inplace=True),\n", " nn.MaxPool2d(2, 2), # (batch, 64, 16, 64)\n", " \n", " nn.Conv2d(64, 128, kernel_size=3, padding=1), # (batch, 128, 16, 64)\n", " nn.ReLU(inplace=True),\n", " nn.MaxPool2d(2, 2), # (batch, 128, 8, 32)\n", " \n", " nn.Conv2d(128, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(256),\n", " \n", " nn.Conv2d(256, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n", " nn.ReLU(inplace=True),\n", " nn.MaxPool2d((2,1), (2,1)), # (batch, 256, 4, 32)\n", " \n", " nn.Conv2d(256, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(512),\n", " \n", " nn.Conv2d(512, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n", " nn.ReLU(inplace=True),\n", " nn.MaxPool2d((2,1), (2,1)), # (batch, 512, 2, 32)\n", " )\n", " \n", " # RNN часть\n", " self.linear1 = nn.Linear(512 * 2, 256)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)\n", " self.linear2 = nn.Linear(512, num_classes)\n", " \n", " def forward(self, x):\n", " # CNN часть\n", " conv = self.cnn(x) # (batch, 512, 2, 32)\n", " \n", " # Перестановка и изменение формы для RNN\n", " conv = conv.permute(0, 3, 1, 2) # (batch, width=32, channels=512, height=2)\n", " conv = conv.view(conv.size(0), conv.size(1), -1) # (batch, 32, 512*2)\n", " \n", " # RNN часть\n", " out = self.linear1(conv) # (batch, 32, 256)\n", " out = self.relu(out) # (batch, 32, 256)\n", " out, _ = self.lstm(out) # (batch, 32, 512)\n", " out = self.linear2(out) # (batch, 32, num_classes)\n", " \n", " # Перестановка для CTC loss\n", " out = out.permute(1, 0, 2) # (32, batch, num_classes)\n", " return out\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def train(model, loader, optimizer, criterion, device):\n", " model.train()\n", " epoch_loss = 0\n", " for images, labels, label_lengths in tqdm(loader, desc='Training'):\n", " images = images.to(device)\n", " labels = labels.to(device)\n", " label_lengths = label_lengths.to(device)\n", " \n", " optimizer.zero_grad()\n", " outputs = model(images) # (seq_len, batch, num_classes)\n", " \n", " # Определяем длину входных последовательностей (последний слой)\n", " input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n", " \n", " loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n", " loss.backward()\n", " optimizer.step()\n", " \n", " epoch_loss += loss.item()\n", " return epoch_loss / len(loader)\n", "\n", "def validate(model, loader, criterion, device):\n", " model.eval()\n", " epoch_loss = 0\n", " with torch.no_grad():\n", " for images, labels, label_lengths in tqdm(loader, desc='Validation'):\n", " images = images.to(device)\n", " labels = labels.to(device)\n", " label_lengths = label_lengths.to(device)\n", " \n", " outputs = model(images)\n", " \n", " input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n", " \n", " loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n", " epoch_loss += loss.item()\n", " return epoch_loss / len(loader)\n", "\n", "def decode_predictions(preds, blank=0):\n", " preds = preds.argmax(2) # (seq_len, batch)\n", " preds = preds.permute(1, 0) # (batch, seq_len)\n", " decoded = []\n", " for pred in preds:\n", " pred = pred.tolist()\n", " decoded_seq = []\n", " previous = blank\n", " for p in pred:\n", " if p != previous and p != blank:\n", " decoded_seq.append(IDX_TO_CHAR.get(p, ''))\n", " previous = p\n", " decoded.append(''.join(decoded_seq))\n", " return decoded\n", "\n", "def evaluate(model, loader, device):\n", " model.eval()\n", " correct = 0\n", " total = 0\n", " with torch.no_grad():\n", " for images, labels, label_lengths in tqdm(loader, desc='Testing'):\n", " images = images.to(device)\n", " outputs = model(images)\n", " preds = decode_predictions(outputs)\n", " \n", " batch_size = images.size(0)\n", " start = 0\n", " for i in range(batch_size):\n", " length = label_lengths[i]\n", " true_label = ''.join([IDX_TO_CHAR.get(idx.item(), '') for idx in labels[start:start+length]])\n", " start += length\n", " pred_label = preds[i]\n", " if pred_label == true_label:\n", " correct += 1\n", " total += 1\n", " accuracy = correct / total * 100\n", " return accuracy\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "model = CRNN(num_classes=NUM_CLASSES).to(device)\n", "\n", "criterion = nn.CTCLoss(blank=0, zero_infinity=True)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:20<00:00, 7.03it/s]\n", "Validation: 100%|██████████| 72/72 [00:04<00:00, 16.47it/s]\n", "Testing: 100%|██████████| 72/72 [00:05<00:00, 14.37it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0850 | Val Loss: 0.0767 | Accuracy: 96.73%\n", "Модель с лучшей потерей сохранена!\n", "Модель с лучшей точностью сохранена!\n", "Epoch 2/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:20<00:00, 7.07it/s]\n", "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.92it/s]\n", "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.78it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0829 | Val Loss: 0.0706 | Accuracy: 97.63%\n", "Модель с лучшей потерей сохранена!\n", "Модель с лучшей точностью сохранена!\n", "Epoch 3/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:20<00:00, 7.09it/s]\n", "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.20it/s]\n", "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.51it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0796 | Val Loss: 0.0763 | Accuracy: 96.81%\n", "Epoch 4/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:19<00:00, 7.11it/s]\n", "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.50it/s]\n", "Testing: 100%|██████████| 72/72 [00:04<00:00, 15.43it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0805 | Val Loss: 0.0732 | Accuracy: 97.72%\n", "Модель с лучшей точностью сохранена!\n", "Epoch 5/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:19<00:00, 7.17it/s]\n", "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.76it/s]\n", "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.62it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0787 | Val Loss: 0.0716 | Accuracy: 97.76%\n", "Модель с лучшей точностью сохранена!\n", "Epoch 6/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:19<00:00, 7.13it/s]\n", "Validation: 100%|██████████| 72/72 [00:03<00:00, 18.01it/s]\n", "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0775 | Val Loss: 0.0732 | Accuracy: 97.63%\n", "Epoch 7/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:18<00:00, 7.21it/s]\n", "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.52it/s]\n", "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.53it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0731 | Val Loss: 0.0746 | Accuracy: 97.58%\n", "Epoch 8/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n", "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.64it/s]\n", "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.43it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0745 | Val Loss: 0.0753 | Accuracy: 96.77%\n", "Epoch 9/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:19<00:00, 7.19it/s]\n", "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.87it/s]\n", "Testing: 100%|██████████| 72/72 [00:05<00:00, 14.38it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0734 | Val Loss: 0.0742 | Accuracy: 97.80%\n", "Модель с лучшей точностью сохранена!\n", "Epoch 10/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n", "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.24it/s]\n", "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0767 | Val Loss: 0.0776 | Accuracy: 96.95%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "num_epochs = 10\n", "best_val_loss = float('inf')\n", "best_accuracy = 0.0\n", "\n", "for epoch in range(1, num_epochs + 1):\n", " print(f'Epoch {epoch}/{num_epochs}')\n", " \n", " train_loss = train(model, train_loader, optimizer, criterion, device)\n", " val_loss = validate(model, val_loader, criterion, device)\n", " accuracy = evaluate(model, val_loader, device)\n", " \n", " print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Accuracy: {accuracy:.2f}%')\n", " \n", " if val_loss < best_val_loss:\n", " best_val_loss = val_loss\n", " torch.save(model.state_dict(), 'best_loss_model_3.pth')\n", " print('Модель с лучшей потерей сохранена!')\n", " \n", " if accuracy > best_accuracy:\n", " best_accuracy = accuracy\n", " torch.save(model.state_dict(), 'best_accuracy_model_3.pth')\n", " print('Модель с лучшей точностью сохранена!')\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Точность на тестовом наборе: 96.68%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "#model.load_state_dict(torch.load('models/best_accuracy_model_3.pth'))\n", "test_accuracy = evaluate(model, test_loader, device)\n", "print(f'Точность на тестовом наборе: {test_accuracy:.2f}%')" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing Images: 100%|██████████| 8/8 [00:00<00:00, 54.74it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "A023TY97.png: A023TY97\n", "A413YE97.png: A413YE97\n", "B642OT97.png: B642OT97\n", "H702TH97.png: H702TH97\n", "K263CO97.png: K263CO97\n", "O571KT99.png: O571KT99\n", "T829MK97.png: T829MK97\n", "Y726PA97.png: Y726PA97\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "import os\n", "from PIL import Image\n", "import torch\n", "from torchvision import transforms\n", "from tqdm import tqdm \n", "\n", "def recognize_license_plates(model, folder_path, transform, device):\n", " model.eval() \n", " images = [img for img in os.listdir(folder_path) if img.endswith(('.png', '.jpg', '.jpeg'))]\n", " \n", " results = {}\n", " \n", " for img_name in tqdm(images, desc=\"Processing Images\"):\n", " img_path = os.path.join(folder_path, img_name)\n", " image = Image.open(img_path).convert('L') \n", " \n", " image_tensor = transform(image).unsqueeze(0).to(device)\n", " \n", " with torch.no_grad():\n", " output = model(image_tensor) # (seq_len, batch, num_classes)\n", " \n", " decoded_text = decode_predictions(output)\n", " \n", " results[img_name] = decoded_text[0]\n", " \n", " return results\n", "\n", "folder_path = 'dataset-ocr/fine-tune-train/img_plate_image' # Путь к папке с изображениями\n", "#model.load_state_dict(torch.load('best_accuracy_model_2.pth'))\n", "#model.to(device)\n", "\n", "results = recognize_license_plates(model, folder_path, transform, device)\n", "\n", "for img_name, text in results.items():\n", " print(f\"{img_name}: {text}\")\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 2 }