number-plate/train_crnn.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Используемое устройство: cuda\n"
]
}
],
"source": [
"import os\n",
"import string\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"Используемое устройство: {device}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"ALPHABET = '-ABEKMHOPCTYX0123456789'\n",
"CHAR_TO_IDX = {char: idx + 1 for idx, char in enumerate(ALPHABET)} # 0 будет использоваться для CTC blank\n",
"IDX_TO_CHAR = {idx + 1: char for idx, char in enumerate(ALPHABET)}\n",
"\n",
"NUM_CLASSES = len(ALPHABET) + 1\n"
]
},
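{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check (illustrative only, not part of the pipeline): encode a\n",
"# made-up plate string with CHAR_TO_IDX and map it back with IDX_TO_CHAR.\n",
"sample = 'A123BC77'\n",
"encoded = [CHAR_TO_IDX[c] for c in sample]\n",
"decoded = ''.join(IDX_TO_CHAR[i] for i in encoded)\n",
"print(encoded)  # [2, 15, 16, 17, 3, 10, 21, 21]\n",
"print(decoded)  # A123BC77\n"
]
},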
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class LicensePlateDataset(Dataset):\n",
" def __init__(self, root_dir, transform=None):\n",
" \"\"\"\n",
" Args:\n",
" root_dir (string): Path to the data split directory (train, val, test)\n",
" transform (callable, optional): Transforms applied to each image\n",
" \"\"\"\n",
" self.root_dir = root_dir\n",
" self.img_dir = os.path.join(root_dir, 'img')\n",
" self.transform = transform\n",
" self.images = [img for img in os.listdir(self.img_dir) if img.endswith(('.png', '.jpg', '.jpeg'))]\n",
" \n",
" def __len__(self):\n",
" return len(self.images)\n",
" \n",
" def __getitem__(self, idx):\n",
" img_name = self.images[idx]\n",
" img_path = os.path.join(self.img_dir, img_name)\n",
" image = Image.open(img_path).convert('L') \n",
" if self.transform:\n",
" image = self.transform(image)\n",
" \n",
" label_str = os.path.splitext(img_name)[0].upper()\n",
" label = [CHAR_TO_IDX[char] for char in label_str if char in CHAR_TO_IDX]\n",
" label = torch.tensor(label, dtype=torch.long)\n",
" \n",
" return image, label\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"transform = transforms.Compose([\n",
" transforms.Resize((32, 128)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,))\n",
"])\n",
"\n",
"import random\n",
"import torchvision.transforms.functional as F\n",
"from torchvision import transforms\n",
"\n",
"def random_padding(img):\n",
" pad_left = random.randint(5, 15)\n",
" pad_right = random.randint(5, 15)\n",
" pad_top = random.randint(5, 15)\n",
" pad_bottom = random.randint(5, 15)\n",
" return F.pad(img, (pad_left, pad_top, pad_right, pad_bottom), fill=0)\n",
"\n",
"# Обновленный transform\n",
"transform_train = transforms.Compose([\n",
" transforms.Lambda(random_padding),\n",
" transforms.Resize((32, 128)), # Изменяем размер изображения\n",
" transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), # Сдвиг, масштабирование, поворот\n",
" transforms.ColorJitter(brightness=0.2, contrast=0.2), # Изменение яркости и контраста\n",
" transforms.RandomPerspective(distortion_scale=0.2, p=0.5), # Перспективные искажения\n",
" transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), # Размытие\n",
" transforms.ToTensor(), # Преобразуем в тензор\n",
" transforms.Normalize((0.5,), (0.5,)), # Нормализация\n",
"])\n",
"\n",
"def collate_fn(batch):\n",
" images, labels = zip(*batch)\n",
" images = torch.stack(images, 0)\n",
" \n",
" # Concatenate all labels into one flat tensor (CTC target format)\n",
" label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)\n",
" labels_concat = torch.cat(labels)\n",
" \n",
" return images, labels_concat, label_lengths\n"
]
},
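{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal collate_fn demo on synthetic tensors (no dataset required): two fake\n",
"# 1x32x128 'images' with labels of different lengths, as the loaders will see.\n",
"fake_batch = [\n",
"    (torch.randn(1, 32, 128), torch.tensor([2, 15, 16], dtype=torch.long)),\n",
"    (torch.randn(1, 32, 128), torch.tensor([3, 14, 17, 21], dtype=torch.long)),\n",
"]\n",
"imgs, labs, lab_lens = collate_fn(fake_batch)\n",
"print(imgs.shape)   # torch.Size([2, 1, 32, 128])\n",
"print(labs)         # flat tensor with 3 + 4 = 7 label indices\n",
"print(lab_lens)     # tensor([3, 4])\n"
]
},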
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/train', transform=transform_train)\n",
"val_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/val', transform=transform)\n",
"test_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/test', transform=transform)\n",
"test_10 = LicensePlateDataset(root_dir=r'dataset-ocr\\fine-tune-val', transform=transform)\n",
"\n",
"batch_size = 64\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=lambda x: collate_fn(x))\n",
"val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n",
"test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n"
]
},
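{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at one training batch (assumes the loaders above were built, i.e. the\n",
"# dataset directories exist). Label tensor length varies with the batch.\n",
"images, labels_concat, label_lengths = next(iter(train_loader))\n",
"print(images.shape)        # torch.Size([64, 1, 32, 128])\n",
"print(labels_concat.shape) # flat: sum of all label lengths in the batch\n",
"print(label_lengths[:5])\n"
]
},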
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"class CRNN(nn.Module):\n",
" def __init__(self, num_classes):\n",
" super(CRNN, self).__init__()\n",
" \n",
" # CNN part\n",
" self.cnn = nn.Sequential(\n",
" nn.Conv2d(1, 64, kernel_size=3, padding=1), # (batch, 64, 32, 128)\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d(2, 2), # (batch, 64, 16, 64)\n",
" \n",
" nn.Conv2d(64, 128, kernel_size=3, padding=1), # (batch, 128, 16, 64)\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d(2, 2), # (batch, 128, 8, 32)\n",
" \n",
" nn.Conv2d(128, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n",
" nn.ReLU(inplace=True),\n",
" nn.BatchNorm2d(256),\n",
" \n",
" nn.Conv2d(256, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d((2,1), (2,1)), # (batch, 256, 4, 32)\n",
" \n",
" nn.Conv2d(256, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n",
" nn.ReLU(inplace=True),\n",
" nn.BatchNorm2d(512),\n",
" \n",
" nn.Conv2d(512, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d((2,1), (2,1)), # (batch, 512, 2, 32)\n",
" )\n",
" \n",
" # RNN part\n",
" self.linear1 = nn.Linear(512 * 2, 256)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)\n",
" self.linear2 = nn.Linear(512, num_classes)\n",
" \n",
" def forward(self, x):\n",
" # CNN part\n",
" conv = self.cnn(x) # (batch, 512, 2, 32)\n",
" \n",
" # Permute and reshape for the RNN: width becomes the time axis\n",
" conv = conv.permute(0, 3, 1, 2) # (batch, width=32, channels=512, height=2)\n",
" conv = conv.view(conv.size(0), conv.size(1), -1) # (batch, 32, 512*2)\n",
" \n",
" # RNN part\n",
" out = self.linear1(conv) # (batch, 32, 256)\n",
" out = self.relu(out) # (batch, 32, 256)\n",
" out, _ = self.lstm(out) # (batch, 32, 512)\n",
" out = self.linear2(out) # (batch, 32, num_classes)\n",
" \n",
" # Permute to (seq_len, batch, num_classes) for CTC loss\n",
" out = out.permute(1, 0, 2) # (32, batch, num_classes)\n",
" return out\n"
]
},
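{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shape sanity check with a random batch: the CRNN should map\n",
"# (batch, 1, 32, 128) inputs to (seq_len=32, batch, num_classes) for CTC.\n",
"_m = CRNN(num_classes=NUM_CLASSES)\n",
"with torch.no_grad():\n",
"    _out = _m(torch.randn(4, 1, 32, 128))\n",
"print(_out.shape)  # torch.Size([32, 4, 24]) since NUM_CLASSES == 24\n"
]
},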
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def train(model, loader, optimizer, criterion, device):\n",
" model.train()\n",
" epoch_loss = 0\n",
" for images, labels, label_lengths in tqdm(loader, desc='Training'):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" label_lengths = label_lengths.to(device)\n",
" \n",
" optimizer.zero_grad()\n",
" outputs = model(images) # (seq_len, batch, num_classes)\n",
" \n",
" # Input lengths for CTC: every sample spans the full output sequence\n",
" input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n",
" \n",
" loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" epoch_loss += loss.item()\n",
" return epoch_loss / len(loader)\n",
"\n",
"def validate(model, loader, criterion, device):\n",
" model.eval()\n",
" epoch_loss = 0\n",
" with torch.no_grad():\n",
" for images, labels, label_lengths in tqdm(loader, desc='Validation'):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" label_lengths = label_lengths.to(device)\n",
" \n",
" outputs = model(images)\n",
" \n",
" input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n",
" \n",
" loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n",
" epoch_loss += loss.item()\n",
" return epoch_loss / len(loader)\n",
"\n",
"def decode_predictions(preds, blank=0):\n",
" preds = preds.argmax(2) # (seq_len, batch)\n",
" preds = preds.permute(1, 0) # (batch, seq_len)\n",
" decoded = []\n",
" for pred in preds:\n",
" pred = pred.tolist()\n",
" decoded_seq = []\n",
" previous = blank\n",
" for p in pred:\n",
" if p != previous and p != blank:\n",
" decoded_seq.append(IDX_TO_CHAR.get(p, ''))\n",
" previous = p\n",
" decoded.append(''.join(decoded_seq))\n",
" return decoded\n",
"\n",
"def evaluate(model, loader, device):\n",
" model.eval()\n",
" correct = 0\n",
" total = 0\n",
" with torch.no_grad():\n",
" for images, labels, label_lengths in tqdm(loader, desc='Testing'):\n",
" images = images.to(device)\n",
" outputs = model(images)\n",
" preds = decode_predictions(outputs)\n",
" \n",
" batch_size = images.size(0)\n",
" start = 0\n",
" for i in range(batch_size):\n",
" length = label_lengths[i]\n",
" true_label = ''.join([IDX_TO_CHAR.get(idx.item(), '') for idx in labels[start:start+length]])\n",
" start += length\n",
" pred_label = preds[i]\n",
" if pred_label == true_label:\n",
" correct += 1\n",
" total += 1\n",
" accuracy = correct / total * 100\n",
" return accuracy\n"
]
},
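{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Greedy CTC decoding demo (illustrative): build logits whose argmax path is\n",
"# [A, A, blank, 1, 1, blank]; repeats collapse and blanks drop, giving 'A1'.\n",
"seq = torch.full((6, 1, NUM_CLASSES), -10.0)\n",
"path = [CHAR_TO_IDX['A'], CHAR_TO_IDX['A'], 0, CHAR_TO_IDX['1'], CHAR_TO_IDX['1'], 0]\n",
"for t, cls in enumerate(path):\n",
"    seq[t, 0, cls] = 10.0\n",
"print(decode_predictions(seq))  # ['A1']\n"
]
},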
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"model = CRNN(num_classes=NUM_CLASSES).to(device)\n",
"\n",
"criterion = nn.CTCLoss(blank=0, zero_infinity=True)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n"
]
},
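{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick model-size check before training (illustrative): count trainable\n",
"# parameters of the freshly built CRNN.\n",
"n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f'Trainable parameters: {n_params:,}')\n"
]
},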
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:20<00:00, 7.03it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 16.47it/s]\n",
"Testing: 100%|██████████| 72/72 [00:05<00:00, 14.37it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0850 | Val Loss: 0.0767 | Accuracy: 96.73%\n",
"Модель с лучшей потерей сохранена!\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 2/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:20<00:00, 7.07it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.92it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0829 | Val Loss: 0.0706 | Accuracy: 97.63%\n",
"Модель с лучшей потерей сохранена!\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 3/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:20<00:00, 7.09it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.20it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.51it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0796 | Val Loss: 0.0763 | Accuracy: 96.81%\n",
"Epoch 4/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.11it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.50it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 15.43it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0805 | Val Loss: 0.0732 | Accuracy: 97.72%\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 5/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.17it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.76it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.62it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0787 | Val Loss: 0.0716 | Accuracy: 97.76%\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 6/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.13it/s]\n",
"Validation: 100%|██████████| 72/72 [00:03<00:00, 18.01it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.58it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0775 | Val Loss: 0.0732 | Accuracy: 97.63%\n",
"Epoch 7/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:18<00:00, 7.21it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.52it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.53it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0731 | Val Loss: 0.0746 | Accuracy: 97.58%\n",
"Epoch 8/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.64it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.43it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0745 | Val Loss: 0.0753 | Accuracy: 96.77%\n",
"Epoch 9/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.19it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.87it/s]\n",
"Testing: 100%|██████████| 72/72 [00:05<00:00, 14.38it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0734 | Val Loss: 0.0742 | Accuracy: 97.80%\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 10/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.24it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0767 | Val Loss: 0.0776 | Accuracy: 96.95%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"num_epochs = 10\n",
"best_val_loss = float('inf')\n",
"best_accuracy = 0.0\n",
"\n",
"for epoch in range(1, num_epochs + 1):\n",
" print(f'Epoch {epoch}/{num_epochs}')\n",
" \n",
" train_loss = train(model, train_loader, optimizer, criterion, device)\n",
" val_loss = validate(model, val_loader, criterion, device)\n",
" accuracy = evaluate(model, val_loader, device)\n",
" \n",
" print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Accuracy: {accuracy:.2f}%')\n",
" \n",
" if val_loss < best_val_loss:\n",
" best_val_loss = val_loss\n",
" torch.save(model.state_dict(), 'best_loss_model_3.pth')\n",
" print('Saved new best-loss model!')\n",
" \n",
" if accuracy > best_accuracy:\n",
" best_accuracy = accuracy\n",
" torch.save(model.state_dict(), 'best_accuracy_model_3.pth')\n",
" print('Saved new best-accuracy model!')\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Точность на тестовом наборе: 96.68%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"#model.load_state_dict(torch.load('models/best_accuracy_model_3.pth'))\n",
"test_accuracy = evaluate(model, test_loader, device)\n",
"print(f'Точность на тестовом наборе: {test_accuracy:.2f}%')"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing Images: 100%|██████████| 8/8 [00:00<00:00, 54.74it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"A023TY97.png: A023TY97\n",
"A413YE97.png: A413YE97\n",
"B642OT97.png: B642OT97\n",
"H702TH97.png: H702TH97\n",
"K263CO97.png: K263CO97\n",
"O571KT99.png: O571KT99\n",
"T829MK97.png: T829MK97\n",
"Y726PA97.png: Y726PA97\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"import os\n",
"from PIL import Image\n",
"import torch\n",
"from torchvision import transforms\n",
"from tqdm import tqdm \n",
"\n",
"def recognize_license_plates(model, folder_path, transform, device):\n",
" model.eval() \n",
" images = [img for img in os.listdir(folder_path) if img.endswith(('.png', '.jpg', '.jpeg'))]\n",
" \n",
" results = {}\n",
" \n",
" for img_name in tqdm(images, desc=\"Processing Images\"):\n",
" img_path = os.path.join(folder_path, img_name)\n",
" image = Image.open(img_path).convert('L') \n",
" \n",
" image_tensor = transform(image).unsqueeze(0).to(device)\n",
" \n",
" with torch.no_grad():\n",
" output = model(image_tensor) # (seq_len, batch, num_classes)\n",
" \n",
" decoded_text = decode_predictions(output)\n",
" \n",
" results[img_name] = decoded_text[0]\n",
" \n",
" return results\n",
"\n",
"folder_path = 'dataset-ocr/fine-tune-train/img_plate_image' # Путь к папке с изображениями\n",
"#model.load_state_dict(torch.load('best_accuracy_model_2.pth'))\n",
"#model.to(device)\n",
"\n",
"results = recognize_license_plates(model, folder_path, transform, device)\n",
"\n",
"for img_name, text in results.items():\n",
" print(f\"{img_name}: {text}\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}