Compare commits

...

3 Commits

Author SHA1 Message Date
itqop 5923b469bf docs: update README.md with project details
Updated the `README.md` to include:
- Comprehensive project overview.
- Detailed setup and usage instructions.
- Information on system architecture.
- Guidance for real-time recognition, static image testing, and CRNN training.
- License information.

This update ensures clarity for users and contributors, improving the overall documentation quality.
2025-01-15 05:54:44 +03:00
itqop 13898eda38 feat: add static-test.ipynb and img folder with test images
- Added `static-test.ipynb`, a Jupyter Notebook for testing the license plate recognition system on static images. The notebook demonstrates:
  - Loading and preprocessing of test images.
  - Inference using the YOLO and CRNN models.
  - Visualization of detection results.

- Added the `img` folder containing test images of vehicles with license plates for validation and demonstration purposes.

These additions provide a comprehensive framework for testing and showcasing the system's capabilities on pre-captured images.
2025-01-15 05:46:33 +03:00
itqop a21460ef16 feat: add train_crnn.ipynb for CRNN model training
Added a Jupyter Notebook `train_crnn.ipynb` to the repository for training the CRNN model. This notebook includes:
- Steps for preparing the dataset.
- Training pipeline for the CRNN architecture.
- Evaluation of the trained model.

This addition enhances the project's flexibility for users who want to retrain the CRNN model on custom datasets.
2025-01-15 05:45:32 +03:00
11 changed files with 1061 additions and 2 deletions

157
README.md
@@ -1,3 +1,156 @@
# number-plate
# License Plate Recognition System
A robust and efficient real-time license plate recognition system using YOLO for plate detection and CRNN for text recognition. The system is specifically designed for Russian license plates and provides high accuracy in both detection and character recognition.
## 🚀 Features
- Real-time license plate detection and recognition
- Support for both real-time webcam input and static image batch processing
- Support for Russian license plate format
- High-accuracy YOLO-based plate detection
- Advanced CRNN-based text recognition with CTC loss
- Real-time visualization with OpenCV
- Multi-stage pipeline with separate detection and recognition components
- Built-in data augmentation for robust training
## 🛠️ Technical Architecture
The system consists of two main components (a sketch of how they compose follows this list):
1. **License Plate Detection (YOLO)**
- Uses YOLO (You Only Look Once) for real-time object detection
- Trained specifically for license plate detection
- Provides bounding box coordinates for detected plates
2. **Text Recognition (CRNN)**
- Convolutional Recurrent Neural Network architecture
- CNN backbone for feature extraction
- Bidirectional LSTM for sequence modeling
- CTC loss for end-to-end training
- Supports the specific character set used in Russian license plates
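As an illustration, here is a minimal sketch of how the two stages compose. It reuses `CRNN`, `NUM_CLASSES`, and `decode_predictions` as defined in `train_crnn.ipynb`; the weight filenames are assumptions, and the actual glue lives in `license_plate_recognizer.py`:
```python
import cv2
import torch
from PIL import Image
from torchvision import transforms
from ultralytics import YOLO

detector = YOLO('models/yolo_plate.pt')        # assumed YOLO weights filename
crnn = CRNN(num_classes=NUM_CLASSES)           # CRNN class from train_crnn.ipynb
crnn.load_state_dict(torch.load('models/best_accuracy_model_3.pth', map_location='cpu'))
crnn.eval()

to_tensor = transforms.Compose([
    transforms.Resize((32, 128)),              # CRNN input size
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

def recognize_frame(frame_bgr):
    """Detect plates with YOLO, then read each crop with the CRNN."""
    texts = []
    for box in detector(frame_bgr)[0].boxes.xyxy:          # plate bounding boxes
        x1, y1, x2, y2 = map(int, box.tolist())
        gray = cv2.cvtColor(frame_bgr[y1:y2, x1:x2], cv2.COLOR_BGR2GRAY)
        with torch.no_grad():
            logits = crnn(to_tensor(Image.fromarray(gray)).unsqueeze(0))
        texts.append(decode_predictions(logits)[0])        # greedy CTC decode
    return texts
```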
## 📊 Performance
- CRNN Text Recognition Accuracy: ~97% on the validation set
- Real-time processing capability on CPU
- Support for various lighting conditions and angles through data augmentation
## 🔧 Requirements
```
numpy==1.26.3
pillow==10.2.0
opencv-python==4.10.0.84
scikit-learn==1.5.2
scipy==1.13.1
matplotlib==3.9.2
tqdm
ultralytics
torch==2.0.1
torchvision==0.15.2
```
## 💻 Installation
1. Clone the repository:
```bash
git clone https://git.itqop.pw/itqop/number-plate.git
cd number-plate
```
2. Create and activate a virtual environment (optional but recommended):
```bash
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
```
3. Install the required packages:
```bash
pip install -r requirements.txt
```
## 🚗 Usage
### Real-time Recognition (Webcam)
To run the system with webcam input:
```bash
python main.py
```
The application will open a window showing the webcam feed with detected license plates and their recognized text. Press 'q' or 'ESC' to exit.
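For reference, a minimal sketch of the capture-and-exit loop behind `main.py` (assuming webcam index 0; the actual script also runs detection and draws the overlays):
```python
import cv2

cap = cv2.VideoCapture(0)                  # default webcam
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    # ... run detection/recognition and draw boxes + text on `frame` ...
    cv2.imshow('License Plate Recognition', frame)
    key = cv2.waitKey(1) & 0xFF
    if key in (ord('q'), 27):              # 'q' or ESC exits
        break
cap.release()
cv2.destroyAllWindows()
```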
### Static Image Processing
To process a batch of images:
1. Place your images in the `img` folder
2. Open `static-test.ipynb`
3. Run all cells
The notebook will (a sketch of this flow follows the list):
- Process all images in the specified folder
- Display results with bounding boxes and recognized text
- Show results in a matplotlib window for each image
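A minimal sketch of that loop, reusing the hypothetical `recognize_frame` helper sketched in the architecture section (the notebook's actual cells may differ):
```python
import os
import cv2
from matplotlib import pyplot as plt

folder = 'img'
for name in sorted(os.listdir(folder)):
    if not name.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue
    frame = cv2.imread(os.path.join(folder, name))
    texts = recognize_frame(frame)                       # two-stage pipeline
    plt.figure(figsize=(8, 5))
    plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))   # BGR -> RGB for matplotlib
    plt.title(f"{name}: {', '.join(texts)}")
    plt.axis('off')
    plt.show()
```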
### Training the CRNN Model
The CRNN model can be trained using the provided Jupyter notebook:
1. Open `train_crnn.ipynb`
2. Configure the training parameters if needed
3. Run all cells to start training
4. The best models will be saved based on validation loss and accuracy
## 📁 Project Structure
```
├── img/ # Directory to store images for static tests
├── models/ # Directory to store YOLO and CRNN weights
├── config.py # Configuration for character mappings and model paths
├── license_plate_recognizer.py # Main class for license plate detection and recognition
├── main.py # Script for real-time license plate recognition
├── model.py # CRNN model definition
├── requirements.txt # Project dependencies
├── static-test.ipynb # Notebook for testing inference on static images
├── train_crnn.ipynb # Notebook for training the CRNN model
└── README.md # Project documentation
```
## 🧬 Model Architecture
### CRNN Architecture
- **CNN Layers**: Multiple convolutional layers with max pooling and batch normalization
- **RNN Layers**: Bidirectional LSTM for sequence modeling
- **Dense Layers**: Final classification layers
- **Input Size**: 32x128 grayscale images
- **Output**: Character sequence predictions (see the shape check below)
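To make those shapes concrete, a quick sanity check using the `CRNN` class from `train_crnn.ipynb`:
```python
import torch

model = CRNN(num_classes=24)            # 23 characters + 1 CTC blank
dummy = torch.randn(1, 1, 32, 128)      # (batch, channels, height, width)
out = model(dummy)
print(out.shape)                        # torch.Size([32, 1, 24]) = (seq_len, batch, num_classes)
```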
## 📝 License Plate Format
The system is designed to recognize Russian license plates with the following character set (a simple format check is sketched after the list):
- Letters: A, B, E, K, M, H, O, P, C, T, Y, X
- Numbers: 0-9
- Special characters: - (hyphen)
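For illustration only, the standard Russian plate layout (one letter, three digits, two letters, then a two- or three-digit region code) can be checked with a regex over this character set; this validation step is a sketch, not necessarily part of the shipped pipeline:
```python
import re

# Letter, 3 digits, 2 letters, 2-3 digit region code
PLATE_RE = re.compile(r'[ABEKMHOPCTYX]\d{3}[ABEKMHOPCTYX]{2}\d{2,3}')

def is_valid_plate(text: str) -> bool:
    return bool(PLATE_RE.fullmatch(text))

print(is_valid_plate('A023TY97'))   # True  (a sample output from the notebook)
print(is_valid_plate('0A23TY97'))   # False (digit in a letter position)
```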
## ⚙️ Configuration
Key parameters can be modified in `config.py` (sketched below):
- `ALPHABET`: Supported characters
- `NUM_CLASSES`: Number of character classes
- `YOLO_WEIGHTS_PATH`: Path to YOLO model weights
- `CRNN_WEIGHTS_PATH`: Path to CRNN model weights
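A minimal sketch of what `config.py` plausibly contains, matching the constants used in `train_crnn.ipynb` (the YOLO weights filename is an assumption):
```python
# config.py (sketch; the actual file may differ)
ALPHABET = '-ABEKMHOPCTYX0123456789'
NUM_CLASSES = len(ALPHABET) + 1                         # +1 for the CTC blank
YOLO_WEIGHTS_PATH = 'models/yolo_plate.pt'              # assumed filename
CRNN_WEIGHTS_PATH = 'models/best_accuracy_model_3.pth'  # checkpoint name from the notebook
```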
## 🤝 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
## 📄 License
This project is licensed under the MIT License.
You are free to use, modify, and distribute this software as per the terms of the MIT License. For detailed terms, see the `LICENSE` file included in this repository.

BIN img/01-1721.jpg Normal file (73 KiB)
BIN img/01-2060.jpg Normal file (63 KiB)
BIN img/01-2193.jpg Normal file (63 KiB)
BIN img/01-2489.jpg Normal file (71 KiB)
BIN img/01-393.jpg Normal file (75 KiB)
BIN img/01-486.jpg Normal file (64 KiB)
BIN img/01-541.jpg Normal file (62 KiB)
BIN img/01-715.jpg Normal file (66 KiB)

259
static-test.ipynb Normal file

File diff suppressed because one or more lines are too long

647
train_crnn.ipynb Normal file

@@ -0,0 +1,647 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Используемое устройство: cuda\n"
]
}
],
"source": [
"import os\n",
"import string\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"Используемое устройство: {device}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"ALPHABET = '-ABEKMHOPCTYX0123456789'\n",
"CHAR_TO_IDX = {char: idx + 1 for idx, char in enumerate(ALPHABET)} # 0 будет использоваться для CTC blank\n",
"IDX_TO_CHAR = {idx + 1: char for idx, char in enumerate(ALPHABET)}\n",
"\n",
"NUM_CLASSES = len(ALPHABET) + 1\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class LicensePlateDataset(Dataset):\n",
" def __init__(self, root_dir, transform=None):\n",
" \"\"\"\n",
" Args:\n",
" root_dir (string): Путь к директории с данными (train, val, test)\n",
" transform (callable, optional): Трансформации для изображений\n",
" \"\"\"\n",
" self.root_dir = root_dir\n",
" self.img_dir = os.path.join(root_dir, 'img')\n",
" self.transform = transform\n",
" self.images = [img for img in os.listdir(self.img_dir) if img.endswith(('.png', '.jpg', '.jpeg'))]\n",
" \n",
" def __len__(self):\n",
" return len(self.images)\n",
" \n",
" def __getitem__(self, idx):\n",
" img_name = self.images[idx]\n",
" img_path = os.path.join(self.img_dir, img_name)\n",
" image = Image.open(img_path).convert('L') \n",
" if self.transform:\n",
" image = self.transform(image)\n",
" \n",
" label_str = os.path.splitext(img_name)[0].upper()\n",
" label = [CHAR_TO_IDX[char] for char in label_str if char in CHAR_TO_IDX]\n",
" label = torch.tensor(label, dtype=torch.long)\n",
" \n",
" return image, label\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"transform = transforms.Compose([\n",
" transforms.Resize((32, 128)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,))\n",
"])\n",
"\n",
"import random\n",
"import torchvision.transforms.functional as F\n",
"from torchvision import transforms\n",
"\n",
"def random_padding(img):\n",
" pad_left = random.randint(5, 15)\n",
" pad_right = random.randint(5, 15)\n",
" pad_top = random.randint(5, 15)\n",
" pad_bottom = random.randint(5, 15)\n",
" return F.pad(img, (pad_left, pad_top, pad_right, pad_bottom), fill=0)\n",
"\n",
"# Обновленный transform\n",
"transform_train = transforms.Compose([\n",
" transforms.Lambda(random_padding),\n",
" transforms.Resize((32, 128)), # Изменяем размер изображения\n",
" transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), # Сдвиг, масштабирование, поворот\n",
" transforms.ColorJitter(brightness=0.2, contrast=0.2), # Изменение яркости и контраста\n",
" transforms.RandomPerspective(distortion_scale=0.2, p=0.5), # Перспективные искажения\n",
" transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), # Размытие\n",
" transforms.ToTensor(), # Преобразуем в тензор\n",
" transforms.Normalize((0.5,), (0.5,)), # Нормализация\n",
"])\n",
"\n",
"def collate_fn(batch):\n",
" images, labels = zip(*batch)\n",
" images = torch.stack(images, 0)\n",
" \n",
" # Соединяем все метки в один тензор\n",
" label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)\n",
" labels_concat = torch.cat(labels)\n",
" \n",
" return images, labels_concat, label_lengths\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/train', transform=transform_train)\n",
"val_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/val', transform=transform)\n",
"test_dataset = LicensePlateDataset(root_dir='dataset-ocr-new/test', transform=transform)\n",
"test_10 = LicensePlateDataset(root_dir=r'dataset-ocr\\fine-tune-val', transform=transform)\n",
"\n",
"batch_size = 64\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=lambda x: collate_fn(x))\n",
"val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n",
"test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=lambda x: collate_fn(x))\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"class CRNN(nn.Module):\n",
" def __init__(self, num_classes):\n",
" super(CRNN, self).__init__()\n",
" \n",
" # CNN часть\n",
" self.cnn = nn.Sequential(\n",
" nn.Conv2d(1, 64, kernel_size=3, padding=1), # (batch, 64, 32, 128)\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d(2, 2), # (batch, 64, 16, 64)\n",
" \n",
" nn.Conv2d(64, 128, kernel_size=3, padding=1), # (batch, 128, 16, 64)\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d(2, 2), # (batch, 128, 8, 32)\n",
" \n",
" nn.Conv2d(128, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n",
" nn.ReLU(inplace=True),\n",
" nn.BatchNorm2d(256),\n",
" \n",
" nn.Conv2d(256, 256, kernel_size=3, padding=1), # (batch, 256, 8, 32)\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d((2,1), (2,1)), # (batch, 256, 4, 32)\n",
" \n",
" nn.Conv2d(256, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n",
" nn.ReLU(inplace=True),\n",
" nn.BatchNorm2d(512),\n",
" \n",
" nn.Conv2d(512, 512, kernel_size=3, padding=1), # (batch, 512, 4, 32)\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d((2,1), (2,1)), # (batch, 512, 2, 32)\n",
" )\n",
" \n",
" # RNN часть\n",
" self.linear1 = nn.Linear(512 * 2, 256)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)\n",
" self.linear2 = nn.Linear(512, num_classes)\n",
" \n",
" def forward(self, x):\n",
" # CNN часть\n",
" conv = self.cnn(x) # (batch, 512, 2, 32)\n",
" \n",
" # Перестановка и изменение формы для RNN\n",
" conv = conv.permute(0, 3, 1, 2) # (batch, width=32, channels=512, height=2)\n",
" conv = conv.view(conv.size(0), conv.size(1), -1) # (batch, 32, 512*2)\n",
" \n",
" # RNN часть\n",
" out = self.linear1(conv) # (batch, 32, 256)\n",
" out = self.relu(out) # (batch, 32, 256)\n",
" out, _ = self.lstm(out) # (batch, 32, 512)\n",
" out = self.linear2(out) # (batch, 32, num_classes)\n",
" \n",
" # Перестановка для CTC loss\n",
" out = out.permute(1, 0, 2) # (32, batch, num_classes)\n",
" return out\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def train(model, loader, optimizer, criterion, device):\n",
" model.train()\n",
" epoch_loss = 0\n",
" for images, labels, label_lengths in tqdm(loader, desc='Training'):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" label_lengths = label_lengths.to(device)\n",
" \n",
" optimizer.zero_grad()\n",
" outputs = model(images) # (seq_len, batch, num_classes)\n",
" \n",
" # Определяем длину входных последовательностей (последний слой)\n",
" input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n",
" \n",
" loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" epoch_loss += loss.item()\n",
" return epoch_loss / len(loader)\n",
"\n",
"def validate(model, loader, criterion, device):\n",
" model.eval()\n",
" epoch_loss = 0\n",
" with torch.no_grad():\n",
" for images, labels, label_lengths in tqdm(loader, desc='Validation'):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" label_lengths = label_lengths.to(device)\n",
" \n",
" outputs = model(images)\n",
" \n",
" input_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0), dtype=torch.long).to(device)\n",
" \n",
" loss = criterion(outputs.log_softmax(2), labels, input_lengths, label_lengths)\n",
" epoch_loss += loss.item()\n",
" return epoch_loss / len(loader)\n",
"\n",
"def decode_predictions(preds, blank=0):\n",
" preds = preds.argmax(2) # (seq_len, batch)\n",
" preds = preds.permute(1, 0) # (batch, seq_len)\n",
" decoded = []\n",
" for pred in preds:\n",
" pred = pred.tolist()\n",
" decoded_seq = []\n",
" previous = blank\n",
" for p in pred:\n",
" if p != previous and p != blank:\n",
" decoded_seq.append(IDX_TO_CHAR.get(p, ''))\n",
" previous = p\n",
" decoded.append(''.join(decoded_seq))\n",
" return decoded\n",
"\n",
"def evaluate(model, loader, device):\n",
" model.eval()\n",
" correct = 0\n",
" total = 0\n",
" with torch.no_grad():\n",
" for images, labels, label_lengths in tqdm(loader, desc='Testing'):\n",
" images = images.to(device)\n",
" outputs = model(images)\n",
" preds = decode_predictions(outputs)\n",
" \n",
" batch_size = images.size(0)\n",
" start = 0\n",
" for i in range(batch_size):\n",
" length = label_lengths[i]\n",
" true_label = ''.join([IDX_TO_CHAR.get(idx.item(), '') for idx in labels[start:start+length]])\n",
" start += length\n",
" pred_label = preds[i]\n",
" if pred_label == true_label:\n",
" correct += 1\n",
" total += 1\n",
" accuracy = correct / total * 100\n",
" return accuracy\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"model = CRNN(num_classes=NUM_CLASSES).to(device)\n",
"\n",
"criterion = nn.CTCLoss(blank=0, zero_infinity=True)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:20<00:00, 7.03it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 16.47it/s]\n",
"Testing: 100%|██████████| 72/72 [00:05<00:00, 14.37it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0850 | Val Loss: 0.0767 | Accuracy: 96.73%\n",
"Модель с лучшей потерей сохранена!\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 2/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:20<00:00, 7.07it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.92it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0829 | Val Loss: 0.0706 | Accuracy: 97.63%\n",
"Модель с лучшей потерей сохранена!\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 3/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:20<00:00, 7.09it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.20it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.51it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0796 | Val Loss: 0.0763 | Accuracy: 96.81%\n",
"Epoch 4/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.11it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.50it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 15.43it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0805 | Val Loss: 0.0732 | Accuracy: 97.72%\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 5/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.17it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.76it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.62it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0787 | Val Loss: 0.0716 | Accuracy: 97.76%\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 6/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.13it/s]\n",
"Validation: 100%|██████████| 72/72 [00:03<00:00, 18.01it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.58it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0775 | Val Loss: 0.0732 | Accuracy: 97.63%\n",
"Epoch 7/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:18<00:00, 7.21it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.52it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.53it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0731 | Val Loss: 0.0746 | Accuracy: 97.58%\n",
"Epoch 8/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.64it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.43it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0745 | Val Loss: 0.0753 | Accuracy: 96.77%\n",
"Epoch 9/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.19it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.87it/s]\n",
"Testing: 100%|██████████| 72/72 [00:05<00:00, 14.38it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0734 | Val Loss: 0.0742 | Accuracy: 97.80%\n",
"Модель с лучшей точностью сохранена!\n",
"Epoch 10/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|██████████| 569/569 [01:19<00:00, 7.15it/s]\n",
"Validation: 100%|██████████| 72/72 [00:04<00:00, 17.24it/s]\n",
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.0767 | Val Loss: 0.0776 | Accuracy: 96.95%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"num_epochs = 10\n",
"best_val_loss = float('inf')\n",
"best_accuracy = 0.0\n",
"\n",
"for epoch in range(1, num_epochs + 1):\n",
" print(f'Epoch {epoch}/{num_epochs}')\n",
" \n",
" train_loss = train(model, train_loader, optimizer, criterion, device)\n",
" val_loss = validate(model, val_loader, criterion, device)\n",
" accuracy = evaluate(model, val_loader, device)\n",
" \n",
" print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Accuracy: {accuracy:.2f}%')\n",
" \n",
" if val_loss < best_val_loss:\n",
" best_val_loss = val_loss\n",
" torch.save(model.state_dict(), 'best_loss_model_3.pth')\n",
" print('Модель с лучшей потерей сохранена!')\n",
" \n",
" if accuracy > best_accuracy:\n",
" best_accuracy = accuracy\n",
" torch.save(model.state_dict(), 'best_accuracy_model_3.pth')\n",
" print('Модель с лучшей точностью сохранена!')\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Testing: 100%|██████████| 72/72 [00:04<00:00, 14.55it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Точность на тестовом наборе: 96.68%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"#model.load_state_dict(torch.load('models/best_accuracy_model_3.pth'))\n",
"test_accuracy = evaluate(model, test_loader, device)\n",
"print(f'Точность на тестовом наборе: {test_accuracy:.2f}%')"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing Images: 100%|██████████| 8/8 [00:00<00:00, 54.74it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"A023TY97.png: A023TY97\n",
"A413YE97.png: A413YE97\n",
"B642OT97.png: B642OT97\n",
"H702TH97.png: H702TH97\n",
"K263CO97.png: K263CO97\n",
"O571KT99.png: O571KT99\n",
"T829MK97.png: T829MK97\n",
"Y726PA97.png: Y726PA97\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"import os\n",
"from PIL import Image\n",
"import torch\n",
"from torchvision import transforms\n",
"from tqdm import tqdm \n",
"\n",
"def recognize_license_plates(model, folder_path, transform, device):\n",
" model.eval() \n",
" images = [img for img in os.listdir(folder_path) if img.endswith(('.png', '.jpg', '.jpeg'))]\n",
" \n",
" results = {}\n",
" \n",
" for img_name in tqdm(images, desc=\"Processing Images\"):\n",
" img_path = os.path.join(folder_path, img_name)\n",
" image = Image.open(img_path).convert('L') \n",
" \n",
" image_tensor = transform(image).unsqueeze(0).to(device)\n",
" \n",
" with torch.no_grad():\n",
" output = model(image_tensor) # (seq_len, batch, num_classes)\n",
" \n",
" decoded_text = decode_predictions(output)\n",
" \n",
" results[img_name] = decoded_text[0]\n",
" \n",
" return results\n",
"\n",
"folder_path = 'dataset-ocr/fine-tune-train/img_plate_image' # Путь к папке с изображениями\n",
"#model.load_state_dict(torch.load('best_accuracy_model_2.pth'))\n",
"#model.to(device)\n",
"\n",
"results = recognize_license_plates(model, folder_path, transform, device)\n",
"\n",
"for img_name, text in results.items():\n",
" print(f\"{img_name}: {text}\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}