From a21460ef165543627394acceac78e5d9b7ceef65 Mon Sep 17 00:00:00 2001
From: itqop
Date: Wed, 15 Jan 2025 05:45:32 +0300
Subject: [PATCH] feat: add train_crnn.ipynb for CRNN model training

Added a Jupyter notebook `train_crnn.ipynb` to the repository for training
the CRNN model. The notebook includes:

- steps for preparing the dataset;
- a training pipeline for the CRNN architecture;
- evaluation of the trained model.

This addition makes it straightforward for users to retrain the CRNN model
on custom datasets.
---
 train_crnn.ipynb | 647 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 647 insertions(+)
 create mode 100644 train_crnn.ipynb

diff --git a/train_crnn.ipynb b/train_crnn.ipynb
new file mode 100644
index 0000000..438db5a
--- /dev/null
+++ b/train_crnn.ipynb
@@ -0,0 +1,647 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Device: cuda\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import string\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.utils.data import Dataset, DataLoader\n",
+    "from torchvision import transforms\n",
+    "from PIL import Image\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+    "print(f\"Device: {device}\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ALPHABET = '-ABEKMHOPCTYX0123456789'\n",
+    "CHAR_TO_IDX = {char: idx + 1 for idx, char in enumerate(ALPHABET)}  # index 0 is reserved for the CTC blank\n",
+    "IDX_TO_CHAR = {idx + 1: char for idx, char in enumerate(ALPHABET)}\n",
+    "\n",
+    "NUM_CLASSES = len(ALPHABET) + 1\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class LicensePlateDataset(Dataset):\n",
+    "    def __init__(self, root_dir, transform=None):\n",
+    "        \"\"\"\n",
+    "        Args:\n",
+    "            root_dir (string): Path to the data directory (train, val, or test)\n",
+    "            transform (callable, optional): Transforms applied to each image\n",
+    "        \"\"\"\n",
+    "        self.root_dir = root_dir\n",
+    "        self.img_dir = os.path.join(root_dir, 'img')\n",
+    "        self.transform = transform\n",
+    "        self.images = [img for img in os.listdir(self.img_dir) if img.endswith(('.png', '.jpg', '.jpeg'))]\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.images)\n",
+    "\n",
+    "    def __getitem__(self, idx):\n",
+    "        img_name = self.images[idx]\n",
+    "        img_path = os.path.join(self.img_dir, img_name)\n",
+    "        image = Image.open(img_path).convert('L')\n",
+    "        if self.transform:\n",
+    "            image = self.transform(image)\n",
+    "\n",
+    "        # The ground-truth text is encoded in the file name, e.g. A023TY97.png\n",
+    "        label_str = os.path.splitext(img_name)[0].upper()\n",
+    "        label = [CHAR_TO_IDX[char] for char in label_str if char in CHAR_TO_IDX]\n",
+    "        label = torch.tensor(label, dtype=torch.long)\n",
+    "\n",
+    "        return image, label\n"
+   ]
+  },
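+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Quick sanity check (an illustrative sketch added here, not part of the original\n",
+    "# run; it assumes the dataset-ocr-new/train layout used below). It verifies that\n",
+    "# images load and that an encoded label decodes back to the file name.\n",
+    "_check_ds = LicensePlateDataset(root_dir='dataset-ocr-new/train')\n",
+    "_img, _label = _check_ds[0]\n",
+    "print(len(_check_ds), _img.size, ''.join(IDX_TO_CHAR[i.item()] for i in _label))\n"
+   ]
+  },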
15)\n", + " pad_bottom = random.randint(5, 15)\n", + " return F.pad(img, (pad_left, pad_top, pad_right, pad_bottom), fill=0)\n", + "\n", + "# Обновленный transform\n", + "transform_train = transforms.Compose([\n", + " transforms.Lambda(random_padding),\n", + " transforms.Resize((32, 128)), # Изменяем размер изображения\n", + " transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), # Сдвиг, масштабирование, поворот\n", + " transforms.ColorJitter(brightness=0.2, contrast=0.2), # Изменение яркости и контраста\n", + " transforms.RandomPerspective(distortion_scale=0.2, p=0.5), # Перспективные искажения\n", + " transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), # Размытие\n", + " transforms.ToTensor(), # Преобразуем в тензор\n", + " transforms.Normalize((0.5,), (0.5,)), # Нормализация\n", + "])\n", + "\n", + "def collate_fn(batch):\n", + " images, labels = zip(*batch)\n", + " images = torch.stack(images, 0)\n", + " \n", + " # Соединяем все метки в один тензор\n", + " label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)\n", + " labels_concat = torch.cat(labels)\n", + " \n", + " return images, labels_concat, label_lengths\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/train', transform=transform_train)\n", + "val_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/val', transform=transform)\n", + "test_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/test', transform=transform)\n", + "test_10 = LicensePlateDataset(root_dir=r'dataset-ocr\\fine-tune-val', transform=transform)\n", + "\n", + "batch_size = 64\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=lambda x: collate_fn(x))\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n", + "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "class CRNN(nn.Module):\n", + " def __init__(self, num_classes):\n", + " super(CRNN, self).__init__()\n", + " \n", + " # CNN часть\n", + " self.cnn = nn.Sequential(\n", + " nn.Conv2d(1, 64, kernel_size=3, padding=1), # (batch, 64, 32, 128)\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2), # (batch, 64, 16, 64)\n", + " \n", + " nn.Conv2d(64, 128, kernel_size=3, padding=1), # (batch, 128, 16, 64)\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2), # (batch, 128, 8, 32)\n", + " \n", + " nn.Conv2d(128, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n", + " nn.ReLU(inplace=True),\n", + " nn.BatchNorm2d(256),\n", + " \n", + " nn.Conv2d(256, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d((2,1), (2,1)), # (batch, 256, 4, 32)\n", + " \n", + " nn.Conv2d(256, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n", + " nn.ReLU(inplace=True),\n", + " nn.BatchNorm2d(512),\n", + " \n", + " nn.Conv2d(512, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d((2,1), (2,1)), # (batch, 512, 2, 32)\n", + " )\n", + " \n", + " # RNN часть\n", + " self.linear1 = nn.Linear(512 * 2, 256)\n", + " self.relu = nn.ReLU(inplace=True)\n", + " self.lstm = nn.LSTM(256, 256, 
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CRNN(nn.Module):\n",
+    "    def __init__(self, num_classes):\n",
+    "        super(CRNN, self).__init__()\n",
+    "\n",
+    "        # CNN part\n",
+    "        self.cnn = nn.Sequential(\n",
+    "            nn.Conv2d(1, 64, kernel_size=3, padding=1),    # (batch, 64, 32, 128)\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.MaxPool2d(2, 2),                            # (batch, 64, 16, 64)\n",
+    "\n",
+    "            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # (batch, 128, 16, 64)\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.MaxPool2d(2, 2),                            # (batch, 128, 8, 32)\n",
+    "\n",
+    "            nn.Conv2d(128, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.BatchNorm2d(256),\n",
+    "\n",
+    "            nn.Conv2d(256, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.MaxPool2d((2, 1), (2, 1)),                  # (batch, 256, 4, 32)\n",
+    "\n",
+    "            nn.Conv2d(256, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.BatchNorm2d(512),\n",
+    "\n",
+    "            nn.Conv2d(512, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.MaxPool2d((2, 1), (2, 1)),                  # (batch, 512, 2, 32)\n",
+    "        )\n",
+    "\n",
+    "        # RNN part\n",
+    "        self.linear1 = nn.Linear(512 * 2, 256)\n",
+    "        self.relu = nn.ReLU(inplace=True)\n",
+    "        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)\n",
+    "        self.linear2 = nn.Linear(512, num_classes)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        # CNN part\n",
+    "        conv = self.cnn(x)  # (batch, 512, 2, 32)\n",
+    "\n",
+    "        # Permute and reshape for the RNN\n",
+    "        conv = conv.permute(0, 3, 1, 2)  # (batch, width=32, channels=512, height=2)\n",
+    "        conv = conv.view(conv.size(0), conv.size(1), -1)  # (batch, 32, 512*2)\n",
+    "\n",
+    "        # RNN part\n",
+    "        out = self.linear1(conv)  # (batch, 32, 256)\n",
+    "        out = self.relu(out)      # (batch, 32, 256)\n",
+    "        out, _ = self.lstm(out)   # (batch, 32, 512)\n",
+    "        out = self.linear2(out)   # (batch, 32, num_classes)\n",
+    "\n",
+    "        # Permute to (seq_len, batch, num_classes) for CTC loss\n",
+    "        out = out.permute(1, 0, 2)  # (32, batch, num_classes)\n",
+    "        return out\n"
+   ]
+  },
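+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch (added, not part of the original run): trace the CRNN output\n",
+    "# shape with a dummy batch. For 32x128 inputs the CNN yields 32 time steps, which\n",
+    "# must be at least as long as the longest label for CTC to be applicable.\n",
+    "_probe = CRNN(num_classes=NUM_CLASSES)\n",
+    "with torch.no_grad():\n",
+    "    print(_probe(torch.zeros(2, 1, 32, 128)).shape)  # torch.Size([32, 2, 24])\n"
+   ]
+  },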
+= 1\n", + " total += 1\n", + " accuracy = correct / total * 100\n", + " return accuracy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "model = CRNN(num_classes=NUM_CLASSES).to(device)\n", + "\n", + "criterion = nn.CTCLoss(blank=0, zero_infinity=True)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:20<00:00, 7.03it/s]\n", + "Validation: 100%|██████████| 72/72 [00:04<00:00, 16.47it/s]\n", + "Testing: 100%|██████████| 72/72 [00:05<00:00, 14.37it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0850 | Val Loss: 0.0767 | Accuracy: 96.73%\n", + "Модель с лучшей потерей сохранена!\n", + "Модель с лучшей точностью сохранена!\n", + "Epoch 2/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:20<00:00, 7.07it/s]\n", + "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.92it/s]\n", + "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.78it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0829 | Val Loss: 0.0706 | Accuracy: 97.63%\n", + "Модель с лучшей потерей сохранена!\n", + "Модель с лучшей точностью сохранена!\n", + "Epoch 3/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:20<00:00, 7.09it/s]\n", + "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.20it/s]\n", + "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.51it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0796 | Val Loss: 0.0763 | Accuracy: 96.81%\n", + "Epoch 4/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:19<00:00, 7.11it/s]\n", + "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.50it/s]\n", + "Testing: 100%|██████████| 72/72 [00:04<00:00, 15.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0805 | Val Loss: 0.0732 | Accuracy: 97.72%\n", + "Модель с лучшей точностью сохранена!\n", + "Epoch 5/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:19<00:00, 7.17it/s]\n", + "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.76it/s]\n", + "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.62it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0787 | Val Loss: 0.0716 | Accuracy: 97.76%\n", + "Модель с лучшей точностью сохранена!\n", + "Epoch 6/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:19<00:00, 7.13it/s]\n", + "Validation: 100%|██████████| 72/72 [00:03<00:00, 18.01it/s]\n", + "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.58it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0775 | Val Loss: 0.0732 | Accuracy: 97.63%\n", + "Epoch 7/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:18<00:00, 7.21it/s]\n", + "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.52it/s]\n", 
+ "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.53it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0731 | Val Loss: 0.0746 | Accuracy: 97.58%\n", + "Epoch 8/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n", + "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.64it/s]\n", + "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0745 | Val Loss: 0.0753 | Accuracy: 96.77%\n", + "Epoch 9/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:19<00:00, 7.19it/s]\n", + "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.87it/s]\n", + "Testing: 100%|██████████| 72/72 [00:05<00:00, 14.38it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0734 | Val Loss: 0.0742 | Accuracy: 97.80%\n", + "Модель с лучшей точностью сохранена!\n", + "Epoch 10/10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n", + "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.24it/s]\n", + "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0767 | Val Loss: 0.0776 | Accuracy: 96.95%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "num_epochs = 10\n", + "best_val_loss = float('inf')\n", + "best_accuracy = 0.0\n", + "\n", + "for epoch in range(1, num_epochs + 1):\n", + " print(f'Epoch {epoch}/{num_epochs}')\n", + " \n", + " train_loss = train(model, train_loader, optimizer, criterion, device)\n", + " val_loss = validate(model, val_loader, criterion, device)\n", + " accuracy = evaluate(model, val_loader, device)\n", + " \n", + " print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Accuracy: {accuracy:.2f}%')\n", + " \n", + " if val_loss < best_val_loss:\n", + " best_val_loss = val_loss\n", + " torch.save(model.state_dict(), 'best_loss_model_3.pth')\n", + " print('Модель с лучшей потерей сохранена!')\n", + " \n", + " if accuracy > best_accuracy:\n", + " best_accuracy = accuracy\n", + " torch.save(model.state_dict(), 'best_accuracy_model_3.pth')\n", + " print('Модель с лучшей точностью сохранена!')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Точность на тестовом наборе: 96.68%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "#model.load_state_dict(torch.load('models/best_accuracy_model_3.pth'))\n", + "test_accuracy = evaluate(model, test_loader, device)\n", + "print(f'Точность на тестовом наборе: {test_accuracy:.2f}%')" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing Images: 100%|██████████| 8/8 [00:00<00:00, 54.74it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A023TY97.png: A023TY97\n", + "A413YE97.png: A413YE97\n", + "B642OT97.png: B642OT97\n", + "H702TH97.png: 
H702TH97\n", + "K263CO97.png: K263CO97\n", + "O571KT99.png: O571KT99\n", + "T829MK97.png: T829MK97\n", + "Y726PA97.png: Y726PA97\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "import os\n", + "from PIL import Image\n", + "import torch\n", + "from torchvision import transforms\n", + "from tqdm import tqdm \n", + "\n", + "def recognize_license_plates(model, folder_path, transform, device):\n", + " model.eval() \n", + " images = [img for img in os.listdir(folder_path) if img.endswith(('.png', '.jpg', '.jpeg'))]\n", + " \n", + " results = {}\n", + " \n", + " for img_name in tqdm(images, desc=\"Processing Images\"):\n", + " img_path = os.path.join(folder_path, img_name)\n", + " image = Image.open(img_path).convert('L') \n", + " \n", + " image_tensor = transform(image).unsqueeze(0).to(device)\n", + " \n", + " with torch.no_grad():\n", + " output = model(image_tensor) # (seq_len, batch, num_classes)\n", + " \n", + " decoded_text = decode_predictions(output)\n", + " \n", + " results[img_name] = decoded_text[0]\n", + " \n", + " return results\n", + "\n", + "folder_path = 'dataset-ocr/fine-tune-train/img_plate_image' # Путь к папке с изображениями\n", + "#model.load_state_dict(torch.load('best_accuracy_model_2.pth'))\n", + "#model.to(device)\n", + "\n", + "results = recognize_license_plates(model, folder_path, transform, device)\n", + "\n", + "for img_name, text in results.items():\n", + " print(f\"{img_name}: {text}\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}