feat: add train_crnn.ipynb for CRNN model training
Added a Jupyter notebook, `train_crnn.ipynb`, to the repository for training the CRNN model. The notebook includes:
- Steps for preparing the dataset.
- The training pipeline for the CRNN architecture.
- Evaluation of the trained model.

This addition enhances the project's flexibility for users who want to retrain the CRNN model on custom datasets.
This commit is contained in:
parent
a4e6544db1
commit
a21460ef16
|
@ -0,0 +1,647 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Используемое устройство: cuda\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import string\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
|
"from torchvision import transforms\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"\n",
|
||||||
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
|
"print(f\"Используемое устройство: {device}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"ALPHABET = '-ABEKMHOPCTYX0123456789'\n",
|
||||||
|
"CHAR_TO_IDX = {char: idx + 1 for idx, char in enumerate(ALPHABET)} # 0 будет использоваться для CTC blank\n",
|
||||||
|
"IDX_TO_CHAR = {idx + 1: char for idx, char in enumerate(ALPHABET)}\n",
|
||||||
|
"\n",
|
||||||
|
"NUM_CLASSES = len(ALPHABET) + 1\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class LicensePlateDataset(Dataset):\n",
|
||||||
|
" def __init__(self, root_dir, transform=None):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Args:\n",
|
||||||
|
" root_dir (string): Путь к директории с данными (train, val, test)\n",
|
||||||
|
" transform (callable, optional): Трансформации для изображений\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" self.root_dir = root_dir\n",
|
||||||
|
" self.img_dir = os.path.join(root_dir, 'img')\n",
|
||||||
|
" self.transform = transform\n",
|
||||||
|
" self.images = [img for img in os.listdir(self.img_dir) if img.endswith(('.png', '.jpg', '.jpeg'))]\n",
|
||||||
|
" \n",
|
||||||
|
" def __len__(self):\n",
|
||||||
|
" return len(self.images)\n",
|
||||||
|
" \n",
|
||||||
|
" def __getitem__(self, idx):\n",
|
||||||
|
" img_name = self.images[idx]\n",
|
||||||
|
" img_path = os.path.join(self.img_dir, img_name)\n",
|
||||||
|
" image = Image.open(img_path).convert('L') \n",
|
||||||
|
" if self.transform:\n",
|
||||||
|
" image = self.transform(image)\n",
|
||||||
|
" \n",
|
||||||
|
" label_str = os.path.splitext(img_name)[0].upper()\n",
|
||||||
|
" label = [CHAR_TO_IDX[char] for char in label_str if char in CHAR_TO_IDX]\n",
|
||||||
|
" label = torch.tensor(label, dtype=torch.long)\n",
|
||||||
|
" \n",
|
||||||
|
" return image, label\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"transform = transforms.Compose([\n",
|
||||||
|
" transforms.Resize((32, 128)),\n",
|
||||||
|
" transforms.ToTensor(),\n",
|
||||||
|
" transforms.Normalize((0.5,), (0.5,))\n",
|
||||||
|
"])\n",
|
||||||
|
"\n",
|
||||||
|
"import random\n",
|
||||||
|
"import torchvision.transforms.functional as F\n",
|
||||||
|
"from torchvision import transforms\n",
|
||||||
|
"\n",
|
||||||
|
"def random_padding(img):\n",
|
||||||
|
" pad_left = random.randint(5, 15)\n",
|
||||||
|
" pad_right = random.randint(5, 15)\n",
|
||||||
|
" pad_top = random.randint(5, 15)\n",
|
||||||
|
" pad_bottom = random.randint(5, 15)\n",
|
||||||
|
" return F.pad(img, (pad_left, pad_top, pad_right, pad_bottom), fill=0)\n",
|
||||||
|
"\n",
|
||||||
|
"# Обновленный transform\n",
|
||||||
|
"transform_train = transforms.Compose([\n",
|
||||||
|
" transforms.Lambda(random_padding),\n",
|
||||||
|
" transforms.Resize((32, 128)), # Изменяем размер изображения\n",
|
||||||
|
" transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), # Сдвиг, масштабирование, поворот\n",
|
||||||
|
" transforms.ColorJitter(brightness=0.2, contrast=0.2), # Изменение яркости и контраста\n",
|
||||||
|
" transforms.RandomPerspective(distortion_scale=0.2, p=0.5), # Перспективные искажения\n",
|
||||||
|
" transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), # Размытие\n",
|
||||||
|
" transforms.ToTensor(), # Преобразуем в тензор\n",
|
||||||
|
" transforms.Normalize((0.5,), (0.5,)), # Нормализация\n",
|
||||||
|
"])\n",
|
||||||
|
"\n",
|
||||||
|
"def collate_fn(batch):\n",
|
||||||
|
" images, labels = zip(*batch)\n",
|
||||||
|
" images = torch.stack(images, 0)\n",
|
||||||
|
" \n",
|
||||||
|
" # Соединяем все метки в один тензор\n",
|
||||||
|
" label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)\n",
|
||||||
|
" labels_concat = torch.cat(labels)\n",
|
||||||
|
" \n",
|
||||||
|
" return images, labels_concat, label_lengths\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"train_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/train', transform=transform_train)\n",
|
||||||
|
"val_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/val', transform=transform)\n",
|
||||||
|
"test_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/test', transform=transform)\n",
|
||||||
|
"test_10 = LicensePlateDataset(root_dir=r'dataset-ocr\\fine-tune-val', transform=transform)\n",
|
||||||
|
"\n",
|
||||||
|
"batch_size = 64\n",
|
||||||
|
"\n",
|
||||||
|
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=lambda x: collate_fn(x))\n",
|
||||||
|
"val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n",
|
||||||
|
"test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class CRNN(nn.Module):\n",
|
||||||
|
" def __init__(self, num_classes):\n",
|
||||||
|
" super(CRNN, self).__init__()\n",
|
||||||
|
" \n",
|
||||||
|
" # CNN часть\n",
|
||||||
|
" self.cnn = nn.Sequential(\n",
|
||||||
|
" nn.Conv2d(1, 64, kernel_size=3, padding=1), # (batch, 64, 32, 128)\n",
|
||||||
|
" nn.ReLU(inplace=True),\n",
|
||||||
|
" nn.MaxPool2d(2, 2), # (batch, 64, 16, 64)\n",
|
||||||
|
" \n",
|
||||||
|
" nn.Conv2d(64, 128, kernel_size=3, padding=1), # (batch, 128, 16, 64)\n",
|
||||||
|
" nn.ReLU(inplace=True),\n",
|
||||||
|
" nn.MaxPool2d(2, 2), # (batch, 128, 8, 32)\n",
|
||||||
|
" \n",
|
||||||
|
" nn.Conv2d(128, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n",
|
||||||
|
" nn.ReLU(inplace=True),\n",
|
||||||
|
" nn.BatchNorm2d(256),\n",
|
||||||
|
" \n",
|
||||||
|
" nn.Conv2d(256, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n",
|
||||||
|
" nn.ReLU(inplace=True),\n",
|
||||||
|
" nn.MaxPool2d((2,1), (2,1)), # (batch, 256, 4, 32)\n",
|
||||||
|
" \n",
|
||||||
|
" nn.Conv2d(256, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n",
|
||||||
|
" nn.ReLU(inplace=True),\n",
|
||||||
|
" nn.BatchNorm2d(512),\n",
|
||||||
|
" \n",
|
||||||
|
" nn.Conv2d(512, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n",
|
||||||
|
" nn.ReLU(inplace=True),\n",
|
||||||
|
" nn.MaxPool2d((2,1), (2,1)), # (batch, 512, 2, 32)\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" # RNN часть\n",
|
||||||
|
" self.linear1 = nn.Linear(512 * 2, 256)\n",
|
||||||
|
" self.relu = nn.ReLU(inplace=True)\n",
|
||||||
|
" self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)\n",
|
||||||
|
" self.linear2 = nn.Linear(512, num_classes)\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, x):\n",
|
||||||
|
" # CNN часть\n",
|
||||||
|
" conv = self.cnn(x) # (batch, 512, 2, 32)\n",
|
||||||
|
" \n",
|
||||||
|
" # Перестановка и изменение формы для RNN\n",
|
||||||
|
" conv = conv.permute(0, 3, 1, 2) # (batch, width=32, channels=512, height=2)\n",
|
||||||
|
" conv = conv.view(conv.size(0), conv.size(1), -1) # (batch, 32, 512*2)\n",
|
||||||
|
" \n",
|
||||||
|
" # RNN часть\n",
|
||||||
|
" out = self.linear1(conv) # (batch, 32, 256)\n",
|
||||||
|
" out = self.relu(out) # (batch, 32, 256)\n",
|
||||||
|
" out, _ = self.lstm(out) # (batch, 32, 512)\n",
|
||||||
|
" out = self.linear2(out) # (batch, 32, num_classes)\n",
|
||||||
|
" \n",
|
||||||
|
" # Перестановка для CTC loss\n",
|
||||||
|
" out = out.permute(1, 0, 2) # (32, batch, num_classes)\n",
|
||||||
|
" return out\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def train(model, loader, optimizer, criterion, device):\n",
|
||||||
|
" model.train()\n",
|
||||||
|
" epoch_loss = 0\n",
|
||||||
|
" for images, labels, label_lengths in tqdm(loader, desc='Training'):\n",
|
||||||
|
" images = images.to(device)\n",
|
||||||
|
" labels = labels.to(device)\n",
|
||||||
|
" label_lengths = label_lengths.to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" outputs = model(images) # (seq_len, batch, num_classes)\n",
|
||||||
|
" \n",
|
||||||
|
" # Определяем длину входных последовательностей (последний слой)\n",
|
||||||
|
" input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
" \n",
|
||||||
|
" epoch_loss += loss.item()\n",
|
||||||
|
" return epoch_loss / len(loader)\n",
|
||||||
|
"\n",
|
||||||
|
"def validate(model, loader, criterion, device):\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" epoch_loss = 0\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" for images, labels, label_lengths in tqdm(loader, desc='Validation'):\n",
|
||||||
|
" images = images.to(device)\n",
|
||||||
|
" labels = labels.to(device)\n",
|
||||||
|
" label_lengths = label_lengths.to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" outputs = model(images)\n",
|
||||||
|
" \n",
|
||||||
|
" input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n",
|
||||||
|
" epoch_loss += loss.item()\n",
|
||||||
|
" return epoch_loss / len(loader)\n",
|
||||||
|
"\n",
|
||||||
|
"def decode_predictions(preds, blank=0):\n",
|
||||||
|
" preds = preds.argmax(2) # (seq_len, batch)\n",
|
||||||
|
" preds = preds.permute(1, 0) # (batch, seq_len)\n",
|
||||||
|
" decoded = []\n",
|
||||||
|
" for pred in preds:\n",
|
||||||
|
" pred = pred.tolist()\n",
|
||||||
|
" decoded_seq = []\n",
|
||||||
|
" previous = blank\n",
|
||||||
|
" for p in pred:\n",
|
||||||
|
" if p != previous and p != blank:\n",
|
||||||
|
" decoded_seq.append(IDX_TO_CHAR.get(p, ''))\n",
|
||||||
|
" previous = p\n",
|
||||||
|
" decoded.append(''.join(decoded_seq))\n",
|
||||||
|
" return decoded\n",
|
||||||
|
"\n",
|
||||||
|
"def evaluate(model, loader, device):\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" correct = 0\n",
|
||||||
|
" total = 0\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" for images, labels, label_lengths in tqdm(loader, desc='Testing'):\n",
|
||||||
|
" images = images.to(device)\n",
|
||||||
|
" outputs = model(images)\n",
|
||||||
|
" preds = decode_predictions(outputs)\n",
|
||||||
|
" \n",
|
||||||
|
" batch_size = images.size(0)\n",
|
||||||
|
" start = 0\n",
|
||||||
|
" for i in range(batch_size):\n",
|
||||||
|
" length = label_lengths[i]\n",
|
||||||
|
" true_label = ''.join([IDX_TO_CHAR.get(idx.item(), '') for idx in labels[start:start+length]])\n",
|
||||||
|
" start += length\n",
|
||||||
|
" pred_label = preds[i]\n",
|
||||||
|
" if pred_label == true_label:\n",
|
||||||
|
" correct += 1\n",
|
||||||
|
" total += 1\n",
|
||||||
|
" accuracy = correct / total * 100\n",
|
||||||
|
" return accuracy\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = CRNN(num_classes=NUM_CLASSES).to(device)\n",
|
||||||
|
"\n",
|
||||||
|
"criterion = nn.CTCLoss(blank=0, zero_infinity=True)\n",
|
||||||
|
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 29,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 1/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:20<00:00, 7.03it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:04<00:00, 16.47it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:05<00:00, 14.37it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0850 | Val Loss: 0.0767 | Accuracy: 96.73%\n",
|
||||||
|
"Модель с лучшей потерей сохранена!\n",
|
||||||
|
"Модель с лучшей точностью сохранена!\n",
|
||||||
|
"Epoch 2/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:20<00:00, 7.07it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.92it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.78it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0829 | Val Loss: 0.0706 | Accuracy: 97.63%\n",
|
||||||
|
"Модель с лучшей потерей сохранена!\n",
|
||||||
|
"Модель с лучшей точностью сохранена!\n",
|
||||||
|
"Epoch 3/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:20<00:00, 7.09it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.20it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.51it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0796 | Val Loss: 0.0763 | Accuracy: 96.81%\n",
|
||||||
|
"Epoch 4/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.11it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.50it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:04<00:00, 15.43it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0805 | Val Loss: 0.0732 | Accuracy: 97.72%\n",
|
||||||
|
"Модель с лучшей точностью сохранена!\n",
|
||||||
|
"Epoch 5/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.17it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.76it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.62it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0787 | Val Loss: 0.0716 | Accuracy: 97.76%\n",
|
||||||
|
"Модель с лучшей точностью сохранена!\n",
|
||||||
|
"Epoch 6/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.13it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:03<00:00, 18.01it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.58it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0775 | Val Loss: 0.0732 | Accuracy: 97.63%\n",
|
||||||
|
"Epoch 7/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:18<00:00, 7.21it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.52it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.53it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0731 | Val Loss: 0.0746 | Accuracy: 97.58%\n",
|
||||||
|
"Epoch 8/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.64it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.43it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0745 | Val Loss: 0.0753 | Accuracy: 96.77%\n",
|
||||||
|
"Epoch 9/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.19it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.87it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:05<00:00, 14.38it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0734 | Val Loss: 0.0742 | Accuracy: 97.80%\n",
|
||||||
|
"Модель с лучшей точностью сохранена!\n",
|
||||||
|
"Epoch 10/10\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n",
|
||||||
|
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.24it/s]\n",
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train Loss: 0.0767 | Val Loss: 0.0776 | Accuracy: 96.95%\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"num_epochs = 10\n",
|
||||||
|
"best_val_loss = float('inf')\n",
|
||||||
|
"best_accuracy = 0.0\n",
|
||||||
|
"\n",
|
||||||
|
"for epoch in range(1, num_epochs + 1):\n",
|
||||||
|
" print(f'Epoch {epoch}/{num_epochs}')\n",
|
||||||
|
" \n",
|
||||||
|
" train_loss = train(model, train_loader, optimizer, criterion, device)\n",
|
||||||
|
" val_loss = validate(model, val_loader, criterion, device)\n",
|
||||||
|
" accuracy = evaluate(model, val_loader, device)\n",
|
||||||
|
" \n",
|
||||||
|
" print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Accuracy: {accuracy:.2f}%')\n",
|
||||||
|
" \n",
|
||||||
|
" if val_loss < best_val_loss:\n",
|
||||||
|
" best_val_loss = val_loss\n",
|
||||||
|
" torch.save(model.state_dict(), 'best_loss_model_3.pth')\n",
|
||||||
|
" print('Модель с лучшей потерей сохранена!')\n",
|
||||||
|
" \n",
|
||||||
|
" if accuracy > best_accuracy:\n",
|
||||||
|
" best_accuracy = accuracy\n",
|
||||||
|
" torch.save(model.state_dict(), 'best_accuracy_model_3.pth')\n",
|
||||||
|
" print('Модель с лучшей точностью сохранена!')\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 30,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Точность на тестовом наборе: 96.68%\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"#model.load_state_dict(torch.load('models/best_accuracy_model_3.pth'))\n",
|
||||||
|
"test_accuracy = evaluate(model, test_loader, device)\n",
|
||||||
|
"print(f'Точность на тестовом наборе: {test_accuracy:.2f}%')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 31,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Processing Images: 100%|██████████| 8/8 [00:00<00:00, 54.74it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"A023TY97.png: A023TY97\n",
|
||||||
|
"A413YE97.png: A413YE97\n",
|
||||||
|
"B642OT97.png: B642OT97\n",
|
||||||
|
"H702TH97.png: H702TH97\n",
|
||||||
|
"K263CO97.png: K263CO97\n",
|
||||||
|
"O571KT99.png: O571KT99\n",
|
||||||
|
"T829MK97.png: T829MK97\n",
|
||||||
|
"Y726PA97.png: Y726PA97\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from torchvision import transforms\n",
|
||||||
|
"from tqdm import tqdm \n",
|
||||||
|
"\n",
|
||||||
|
"def recognize_license_plates(model, folder_path, transform, device):\n",
|
||||||
|
" model.eval() \n",
|
||||||
|
" images = [img for img in os.listdir(folder_path) if img.endswith(('.png', '.jpg', '.jpeg'))]\n",
|
||||||
|
" \n",
|
||||||
|
" results = {}\n",
|
||||||
|
" \n",
|
||||||
|
" for img_name in tqdm(images, desc=\"Processing Images\"):\n",
|
||||||
|
" img_path = os.path.join(folder_path, img_name)\n",
|
||||||
|
" image = Image.open(img_path).convert('L') \n",
|
||||||
|
" \n",
|
||||||
|
" image_tensor = transform(image).unsqueeze(0).to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" output = model(image_tensor) # (seq_len, batch, num_classes)\n",
|
||||||
|
" \n",
|
||||||
|
" decoded_text = decode_predictions(output)\n",
|
||||||
|
" \n",
|
||||||
|
" results[img_name] = decoded_text[0]\n",
|
||||||
|
" \n",
|
||||||
|
" return results\n",
|
||||||
|
"\n",
|
||||||
|
"folder_path = 'dataset-ocr/fine-tune-train/img_plate_image' # Путь к папке с изображениями\n",
|
||||||
|
"#model.load_state_dict(torch.load('best_accuracy_model_2.pth'))\n",
|
||||||
|
"#model.to(device)\n",
|
||||||
|
"\n",
|
||||||
|
"results = recognize_license_plates(model, folder_path, transform, device)\n",
|
||||||
|
"\n",
|
||||||
|
"for img_name, text in results.items():\n",
|
||||||
|
" print(f\"{img_name}: {text}\")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.13"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
Loading…
Reference in New Issue