feat: add train_crnn.ipynb for CRNN model training
Added a Jupyter Notebook `train_crnn.ipynb` to the repository for training the CRNN model. This notebook includes: - Steps for preparing the dataset. - Training pipeline for the CRNN architecture. - Evaluation of the trained model. This addition enhances the project's flexibility for users who want to retrain the CRNN model on custom datasets.
This commit is contained in:
		
							parent
							
								
									a4e6544db1
								
							
						
					
					
						commit
						a21460ef16
					
				| 
						 | 
				
			
			@ -0,0 +1,647 @@
 | 
			
		|||
{
 | 
			
		||||
 "cells": [
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 1,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Используемое устройство: cuda\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "import os\n",
 | 
			
		||||
    "import string\n",
 | 
			
		||||
    "import torch\n",
 | 
			
		||||
    "import torch.nn as nn\n",
 | 
			
		||||
    "from torch.utils.data import Dataset, DataLoader\n",
 | 
			
		||||
    "from torchvision import transforms\n",
 | 
			
		||||
    "from PIL import Image\n",
 | 
			
		||||
    "from tqdm import tqdm\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
 | 
			
		||||
    "print(f\"Используемое устройство: {device}\")\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 2,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "ALPHABET = '-ABEKMHOPCTYX0123456789'\n",
 | 
			
		||||
    "CHAR_TO_IDX = {char: idx + 1 for idx, char in enumerate(ALPHABET)}  # 0 будет использоваться для CTC blank\n",
 | 
			
		||||
    "IDX_TO_CHAR = {idx + 1: char for idx, char in enumerate(ALPHABET)}\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "NUM_CLASSES = len(ALPHABET) + 1\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 3,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "class LicensePlateDataset(Dataset):\n",
 | 
			
		||||
    "    def __init__(self, root_dir, transform=None):\n",
 | 
			
		||||
    "        \"\"\"\n",
 | 
			
		||||
    "        Args:\n",
 | 
			
		||||
    "            root_dir (string): Путь к директории с данными (train, val, test)\n",
 | 
			
		||||
    "            transform (callable, optional): Трансформации для изображений\n",
 | 
			
		||||
    "        \"\"\"\n",
 | 
			
		||||
    "        self.root_dir = root_dir\n",
 | 
			
		||||
    "        self.img_dir = os.path.join(root_dir, 'img')\n",
 | 
			
		||||
    "        self.transform = transform\n",
 | 
			
		||||
    "        self.images = [img for img in os.listdir(self.img_dir) if img.endswith(('.png', '.jpg', '.jpeg'))]\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    def __len__(self):\n",
 | 
			
		||||
    "        return len(self.images)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    def __getitem__(self, idx):\n",
 | 
			
		||||
    "        img_name = self.images[idx]\n",
 | 
			
		||||
    "        img_path = os.path.join(self.img_dir, img_name)\n",
 | 
			
		||||
    "        image = Image.open(img_path).convert('L') \n",
 | 
			
		||||
    "        if self.transform:\n",
 | 
			
		||||
    "            image = self.transform(image)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        label_str = os.path.splitext(img_name)[0].upper()\n",
 | 
			
		||||
    "        label = [CHAR_TO_IDX[char] for char in label_str if char in CHAR_TO_IDX]\n",
 | 
			
		||||
    "        label = torch.tensor(label, dtype=torch.long)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        return image, label\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 6,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "transform = transforms.Compose([\n",
 | 
			
		||||
    "    transforms.Resize((32, 128)),\n",
 | 
			
		||||
    "    transforms.ToTensor(),\n",
 | 
			
		||||
    "    transforms.Normalize((0.5,), (0.5,))\n",
 | 
			
		||||
    "])\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import random\n",
 | 
			
		||||
    "import torchvision.transforms.functional as F\n",
 | 
			
		||||
    "from torchvision import transforms\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "def random_padding(img):\n",
 | 
			
		||||
    "    pad_left = random.randint(5, 15)\n",
 | 
			
		||||
    "    pad_right = random.randint(5, 15)\n",
 | 
			
		||||
    "    pad_top = random.randint(5, 15)\n",
 | 
			
		||||
    "    pad_bottom = random.randint(5, 15)\n",
 | 
			
		||||
    "    return F.pad(img, (pad_left, pad_top, pad_right, pad_bottom), fill=0)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# Обновленный transform\n",
 | 
			
		||||
    "transform_train = transforms.Compose([\n",
 | 
			
		||||
    "    transforms.Lambda(random_padding),\n",
 | 
			
		||||
    "    transforms.Resize((32, 128)),  # Изменяем размер изображения\n",
 | 
			
		||||
    "    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # Сдвиг, масштабирование, поворот\n",
 | 
			
		||||
    "    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Изменение яркости и контраста\n",
 | 
			
		||||
    "    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),  # Перспективные искажения\n",
 | 
			
		||||
    "    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),  # Размытие\n",
 | 
			
		||||
    "    transforms.ToTensor(),  # Преобразуем в тензор\n",
 | 
			
		||||
    "    transforms.Normalize((0.5,), (0.5,)),  # Нормализация\n",
 | 
			
		||||
    "])\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "def collate_fn(batch):\n",
 | 
			
		||||
    "    images, labels = zip(*batch)\n",
 | 
			
		||||
    "    images = torch.stack(images, 0)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    # Соединяем все метки в один тензор\n",
 | 
			
		||||
    "    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)\n",
 | 
			
		||||
    "    labels_concat = torch.cat(labels)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    return images, labels_concat, label_lengths\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 26,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "train_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/train', transform=transform_train)\n",
 | 
			
		||||
    "val_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/val', transform=transform)\n",
 | 
			
		||||
    "test_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/test', transform=transform)\n",
 | 
			
		||||
    "test_10 = LicensePlateDataset(root_dir=r'dataset-ocr\\fine-tune-val', transform=transform)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "batch_size = 64\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=lambda x: collate_fn(x))\n",
 | 
			
		||||
    "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n",
 | 
			
		||||
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 23,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "class CRNN(nn.Module):\n",
 | 
			
		||||
    "    def __init__(self, num_classes):\n",
 | 
			
		||||
    "        super(CRNN, self).__init__()\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        # CNN часть\n",
 | 
			
		||||
    "        self.cnn = nn.Sequential(\n",
 | 
			
		||||
    "            nn.Conv2d(1, 64, kernel_size=3, padding=1),  # (batch, 64, 32, 128)\n",
 | 
			
		||||
    "            nn.ReLU(inplace=True),\n",
 | 
			
		||||
    "            nn.MaxPool2d(2, 2),  # (batch, 64, 16, 64)\n",
 | 
			
		||||
    "            \n",
 | 
			
		||||
    "            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # (batch, 128, 16, 64)\n",
 | 
			
		||||
    "            nn.ReLU(inplace=True),\n",
 | 
			
		||||
    "            nn.MaxPool2d(2, 2),  # (batch, 128, 8, 32)\n",
 | 
			
		||||
    "            \n",
 | 
			
		||||
    "            nn.Conv2d(128, 256, kernel_size=3, padding=1),  # (batch, 256, 8, 32)\n",
 | 
			
		||||
    "            nn.ReLU(inplace=True),\n",
 | 
			
		||||
    "            nn.BatchNorm2d(256),\n",
 | 
			
		||||
    "            \n",
 | 
			
		||||
    "            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # (batch, 256, 8, 32)\n",
 | 
			
		||||
    "            nn.ReLU(inplace=True),\n",
 | 
			
		||||
    "            nn.MaxPool2d((2,1), (2,1)),  # (batch, 256, 4, 32)\n",
 | 
			
		||||
    "            \n",
 | 
			
		||||
    "            nn.Conv2d(256, 512, kernel_size=3, padding=1),  # (batch, 512, 4, 32)\n",
 | 
			
		||||
    "            nn.ReLU(inplace=True),\n",
 | 
			
		||||
    "            nn.BatchNorm2d(512),\n",
 | 
			
		||||
    "            \n",
 | 
			
		||||
    "            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # (batch, 512, 4, 32)\n",
 | 
			
		||||
    "            nn.ReLU(inplace=True),\n",
 | 
			
		||||
    "            nn.MaxPool2d((2,1), (2,1)),  # (batch, 512, 2, 32)\n",
 | 
			
		||||
    "        )\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        # RNN часть\n",
 | 
			
		||||
    "        self.linear1 = nn.Linear(512 * 2, 256)\n",
 | 
			
		||||
    "        self.relu = nn.ReLU(inplace=True)\n",
 | 
			
		||||
    "        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)\n",
 | 
			
		||||
    "        self.linear2 = nn.Linear(512, num_classes)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "    def forward(self, x):\n",
 | 
			
		||||
    "        # CNN часть\n",
 | 
			
		||||
    "        conv = self.cnn(x)  # (batch, 512, 2, 32)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        # Перестановка и изменение формы для RNN\n",
 | 
			
		||||
    "        conv = conv.permute(0, 3, 1, 2)  # (batch, width=32, channels=512, height=2)\n",
 | 
			
		||||
    "        conv = conv.view(conv.size(0), conv.size(1), -1)  # (batch, 32, 512*2)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        # RNN часть\n",
 | 
			
		||||
    "        out = self.linear1(conv)  # (batch, 32, 256)\n",
 | 
			
		||||
    "        out = self.relu(out)      # (batch, 32, 256)\n",
 | 
			
		||||
    "        out, _ = self.lstm(out)   # (batch, 32, 512)\n",
 | 
			
		||||
    "        out = self.linear2(out)   # (batch, 32, num_classes)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        # Перестановка для CTC loss\n",
 | 
			
		||||
    "        out = out.permute(1, 0, 2)  # (32, batch, num_classes)\n",
 | 
			
		||||
    "        return out\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 10,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "def train(model, loader, optimizer, criterion, device):\n",
 | 
			
		||||
    "    model.train()\n",
 | 
			
		||||
    "    epoch_loss = 0\n",
 | 
			
		||||
    "    for images, labels, label_lengths in tqdm(loader, desc='Training'):\n",
 | 
			
		||||
    "        images = images.to(device)\n",
 | 
			
		||||
    "        labels = labels.to(device)\n",
 | 
			
		||||
    "        label_lengths = label_lengths.to(device)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        optimizer.zero_grad()\n",
 | 
			
		||||
    "        outputs = model(images)  # (seq_len, batch, num_classes)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        # Определяем длину входных последовательностей (последний слой)\n",
 | 
			
		||||
    "        input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n",
 | 
			
		||||
    "        loss.backward()\n",
 | 
			
		||||
    "        optimizer.step()\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        epoch_loss += loss.item()\n",
 | 
			
		||||
    "    return epoch_loss / len(loader)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "def validate(model, loader, criterion, device):\n",
 | 
			
		||||
    "    model.eval()\n",
 | 
			
		||||
    "    epoch_loss = 0\n",
 | 
			
		||||
    "    with torch.no_grad():\n",
 | 
			
		||||
    "        for images, labels, label_lengths in tqdm(loader, desc='Validation'):\n",
 | 
			
		||||
    "            images = images.to(device)\n",
 | 
			
		||||
    "            labels = labels.to(device)\n",
 | 
			
		||||
    "            label_lengths = label_lengths.to(device)\n",
 | 
			
		||||
    "            \n",
 | 
			
		||||
    "            outputs = model(images)\n",
 | 
			
		||||
    "            \n",
 | 
			
		||||
    "            input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n",
 | 
			
		||||
    "            \n",
 | 
			
		||||
    "            loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n",
 | 
			
		||||
    "            epoch_loss += loss.item()\n",
 | 
			
		||||
    "    return epoch_loss / len(loader)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "def decode_predictions(preds, blank=0):\n",
 | 
			
		||||
    "    preds = preds.argmax(2)  # (seq_len, batch)\n",
 | 
			
		||||
    "    preds = preds.permute(1, 0)  # (batch, seq_len)\n",
 | 
			
		||||
    "    decoded = []\n",
 | 
			
		||||
    "    for pred in preds:\n",
 | 
			
		||||
    "        pred = pred.tolist()\n",
 | 
			
		||||
    "        decoded_seq = []\n",
 | 
			
		||||
    "        previous = blank\n",
 | 
			
		||||
    "        for p in pred:\n",
 | 
			
		||||
    "            if p != previous and p != blank:\n",
 | 
			
		||||
    "                decoded_seq.append(IDX_TO_CHAR.get(p, ''))\n",
 | 
			
		||||
    "            previous = p\n",
 | 
			
		||||
    "        decoded.append(''.join(decoded_seq))\n",
 | 
			
		||||
    "    return decoded\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "def evaluate(model, loader, device):\n",
 | 
			
		||||
    "    model.eval()\n",
 | 
			
		||||
    "    correct = 0\n",
 | 
			
		||||
    "    total = 0\n",
 | 
			
		||||
    "    with torch.no_grad():\n",
 | 
			
		||||
    "        for images, labels, label_lengths in tqdm(loader, desc='Testing'):\n",
 | 
			
		||||
    "            images = images.to(device)\n",
 | 
			
		||||
    "            outputs = model(images)\n",
 | 
			
		||||
    "            preds = decode_predictions(outputs)\n",
 | 
			
		||||
    "            \n",
 | 
			
		||||
    "            batch_size = images.size(0)\n",
 | 
			
		||||
    "            start = 0\n",
 | 
			
		||||
    "            for i in range(batch_size):\n",
 | 
			
		||||
    "                length = label_lengths[i]\n",
 | 
			
		||||
    "                true_label = ''.join([IDX_TO_CHAR.get(idx.item(), '') for idx in labels[start:start+length]])\n",
 | 
			
		||||
    "                start += length\n",
 | 
			
		||||
    "                pred_label = preds[i]\n",
 | 
			
		||||
    "                if pred_label == true_label:\n",
 | 
			
		||||
    "                    correct += 1\n",
 | 
			
		||||
    "                total += 1\n",
 | 
			
		||||
    "    accuracy = correct / total * 100\n",
 | 
			
		||||
    "    return accuracy\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 27,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "model = CRNN(num_classes=NUM_CLASSES).to(device)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "criterion = nn.CTCLoss(blank=0, zero_infinity=True)\n",
 | 
			
		||||
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 29,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Epoch 1/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:20<00:00,  7.03it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:04<00:00, 16.47it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:05<00:00, 14.37it/s]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0850 | Val Loss: 0.0767 | Accuracy: 96.73%\n",
 | 
			
		||||
      "Модель с лучшей потерей сохранена!\n",
 | 
			
		||||
      "Модель с лучшей точностью сохранена!\n",
 | 
			
		||||
      "Epoch 2/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:20<00:00,  7.07it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.92it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.78it/s]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0829 | Val Loss: 0.0706 | Accuracy: 97.63%\n",
 | 
			
		||||
      "Модель с лучшей потерей сохранена!\n",
 | 
			
		||||
      "Модель с лучшей точностью сохранена!\n",
 | 
			
		||||
      "Epoch 3/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:20<00:00,  7.09it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.20it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.51it/s]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0796 | Val Loss: 0.0763 | Accuracy: 96.81%\n",
 | 
			
		||||
      "Epoch 4/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:19<00:00,  7.11it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.50it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:04<00:00, 15.43it/s]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0805 | Val Loss: 0.0732 | Accuracy: 97.72%\n",
 | 
			
		||||
      "Модель с лучшей точностью сохранена!\n",
 | 
			
		||||
      "Epoch 5/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:19<00:00,  7.17it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.76it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.62it/s]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0787 | Val Loss: 0.0716 | Accuracy: 97.76%\n",
 | 
			
		||||
      "Модель с лучшей точностью сохранена!\n",
 | 
			
		||||
      "Epoch 6/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:19<00:00,  7.13it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:03<00:00, 18.01it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.58it/s]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0775 | Val Loss: 0.0732 | Accuracy: 97.63%\n",
 | 
			
		||||
      "Epoch 7/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:18<00:00,  7.21it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.52it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.53it/s]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0731 | Val Loss: 0.0746 | Accuracy: 97.58%\n",
 | 
			
		||||
      "Epoch 8/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:19<00:00,  7.15it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.64it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.43it/s]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0745 | Val Loss: 0.0753 | Accuracy: 96.77%\n",
 | 
			
		||||
      "Epoch 9/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:19<00:00,  7.19it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.87it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:05<00:00, 14.38it/s]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0734 | Val Loss: 0.0742 | Accuracy: 97.80%\n",
 | 
			
		||||
      "Модель с лучшей точностью сохранена!\n",
 | 
			
		||||
      "Epoch 10/10\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Training: 100%|██████████| 569/569 [01:19<00:00,  7.15it/s]\n",
 | 
			
		||||
      "Validation: 100%|██████████| 72/72 [00:04<00:00, 17.24it/s]\n",
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Train Loss: 0.0767 | Val Loss: 0.0776 | Accuracy: 96.95%\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "num_epochs = 10\n",
 | 
			
		||||
    "best_val_loss = float('inf')\n",
 | 
			
		||||
    "best_accuracy = 0.0\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "for epoch in range(1, num_epochs + 1):\n",
 | 
			
		||||
    "    print(f'Epoch {epoch}/{num_epochs}')\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    train_loss = train(model, train_loader, optimizer, criterion, device)\n",
 | 
			
		||||
    "    val_loss = validate(model, val_loader, criterion, device)\n",
 | 
			
		||||
    "    accuracy = evaluate(model, val_loader, device)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Accuracy: {accuracy:.2f}%')\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    if val_loss < best_val_loss:\n",
 | 
			
		||||
    "        best_val_loss = val_loss\n",
 | 
			
		||||
    "        torch.save(model.state_dict(), 'best_loss_model_3.pth')\n",
 | 
			
		||||
    "        print('Модель с лучшей потерей сохранена!')\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    if accuracy > best_accuracy:\n",
 | 
			
		||||
    "        best_accuracy = accuracy\n",
 | 
			
		||||
    "        torch.save(model.state_dict(), 'best_accuracy_model_3.pth')\n",
 | 
			
		||||
    "        print('Модель с лучшей точностью сохранена!')\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 30,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Точность на тестовом наборе: 96.68%\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "#model.load_state_dict(torch.load('models/best_accuracy_model_3.pth'))\n",
 | 
			
		||||
    "test_accuracy = evaluate(model, test_loader, device)\n",
 | 
			
		||||
    "print(f'Точность на тестовом наборе: {test_accuracy:.2f}%')"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 31,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Processing Images: 100%|██████████| 8/8 [00:00<00:00, 54.74it/s]"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "A023TY97.png: A023TY97\n",
 | 
			
		||||
      "A413YE97.png: A413YE97\n",
 | 
			
		||||
      "B642OT97.png: B642OT97\n",
 | 
			
		||||
      "H702TH97.png: H702TH97\n",
 | 
			
		||||
      "K263CO97.png: K263CO97\n",
 | 
			
		||||
      "O571KT99.png: O571KT99\n",
 | 
			
		||||
      "T829MK97.png: T829MK97\n",
 | 
			
		||||
      "Y726PA97.png: Y726PA97\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "import os\n",
 | 
			
		||||
    "from PIL import Image\n",
 | 
			
		||||
    "import torch\n",
 | 
			
		||||
    "from torchvision import transforms\n",
 | 
			
		||||
    "from tqdm import tqdm \n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "def recognize_license_plates(model, folder_path, transform, device):\n",
 | 
			
		||||
    "    model.eval() \n",
 | 
			
		||||
    "    images = [img for img in os.listdir(folder_path) if img.endswith(('.png', '.jpg', '.jpeg'))]\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    results = {}\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    for img_name in tqdm(images, desc=\"Processing Images\"):\n",
 | 
			
		||||
    "        img_path = os.path.join(folder_path, img_name)\n",
 | 
			
		||||
    "        image = Image.open(img_path).convert('L') \n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        image_tensor = transform(image).unsqueeze(0).to(device)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        with torch.no_grad():\n",
 | 
			
		||||
    "            output = model(image_tensor)  # (seq_len, batch, num_classes)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        decoded_text = decode_predictions(output)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "        results[img_name] = decoded_text[0]\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    return results\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "folder_path = 'dataset-ocr/fine-tune-train/img_plate_image'  # Путь к папке с изображениями\n",
 | 
			
		||||
    "#model.load_state_dict(torch.load('best_accuracy_model_2.pth'))\n",
 | 
			
		||||
    "#model.to(device)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "results = recognize_license_plates(model, folder_path, transform, device)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "for img_name, text in results.items():\n",
 | 
			
		||||
    "    print(f\"{img_name}: {text}\")\n"
 | 
			
		||||
   ]
 | 
			
		||||
  }
 | 
			
		||||
 ],
 | 
			
		||||
 "metadata": {
 | 
			
		||||
  "kernelspec": {
 | 
			
		||||
   "display_name": ".venv",
 | 
			
		||||
   "language": "python",
 | 
			
		||||
   "name": "python3"
 | 
			
		||||
  },
 | 
			
		||||
  "language_info": {
 | 
			
		||||
   "codemirror_mode": {
 | 
			
		||||
    "name": "ipython",
 | 
			
		||||
    "version": 3
 | 
			
		||||
   },
 | 
			
		||||
   "file_extension": ".py",
 | 
			
		||||
   "mimetype": "text/x-python",
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.9.13"
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 "nbformat_minor": 2
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue