number-plate-study/emnist.ipynb

683 lines
30 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 316,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"18\n",
"{'A': 0, 'B': 1, 'C': 2, 'E': 3, 'H': 4, 'K': 5, 'M': 6, 'O': 7, 'P': 8, 'T': 9, 'X': 10, 'y': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21}\n",
"697932\n",
"433256\n",
"433256\n",
"433256\n",
"{'A': 0, 'B': 1, 'C': 2, 'E': 3, 'H': 4, 'K': 5, 'M': 6, 'O': 7, 'P': 8, 'T': 9, 'X': 10, 'y': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21}\n",
"116323\n",
"72548\n",
"72548\n",
"72548\n",
"Размер обучающего набора: (433256, 28, 28), Метки: (433256,)\n",
"Размер тестового набора: (72548, 28, 28), Метки: (72548,)\n"
]
}
],
"source": [
"import numpy as np\n",
"import torch\n",
"from emnist import extract_training_samples, extract_test_samples\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from torchvision import transforms\n",
"\n",
"# Characters the classifier should recognise. A character's position in this\n",
"# string is the class index that filter_data assigns to it.\n",
"# NOTE(review): 'y' is lowercase while every other letter is uppercase -\n",
"# confirm this is intentional.\n",
"whitelist = 'ABCEHKMOPTXy0123456789'\n",
"\n",
"# Lookup string that turns an EMNIST integer label into its character\n",
"# (label i -> emnist_classes[i]): digits, then uppercase, then lowercase.\n",
"# Assumed to match the label order of the 'byclass' split - TODO confirm.\n",
"emnist_classes = (\n",
" \"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n",
" \"abcdefghijklmnopqrstuvwxyz\"\n",
")\n",
"\n",
"\n",
"# Keep only whitelisted characters and re-index their labels\n",
"def filter_data(X, y, whitelist, emnist_classes, verbose=True):\n",
"    \"\"\"\n",
"    Select the samples whose label is a whitelisted character.\n",
"\n",
"    Labels are re-indexed so that every kept sample is labelled with the\n",
"    position of its character inside ``whitelist``.\n",
"\n",
"    :param X: Array-like of samples; the first axis indexes samples.\n",
"    :param y: Array-like of integer labels indexing into ``emnist_classes``.\n",
"    :param whitelist: Sequence of characters to keep. Characters that do not\n",
"        occur in ``emnist_classes`` match no samples. If a character is\n",
"        repeated, its last position is used as the class index.\n",
"    :param emnist_classes: Sequence mapping an original label to its character.\n",
"    :param verbose: When True (default) print the character mapping and the\n",
"        sample counts before and after filtering.\n",
"    :return: Tuple ``(X_filtered, y_filtered)`` of NumPy arrays in the original\n",
"        sample order; ``y_filtered`` is always an int64 array.\n",
"    :raises IndexError: If a label is out of range for ``emnist_classes``.\n",
"    \"\"\"\n",
"    X = np.asarray(X)\n",
"    y = np.asarray(y)\n",
"    if y.size == 0:\n",
"        # An empty list converts to float64, which cannot be used as an index.\n",
"        y = y.astype(np.int64)\n",
"\n",
"    # Character -> new class index (its position in the whitelist)\n",
"    whitelist_mapping = {char: idx for idx, char in enumerate(whitelist)}\n",
"\n",
"    # Lookup table: original label -> new class index, -1 meaning \"drop\".\n",
"    remap = np.full(len(emnist_classes), -1, dtype=np.int64)\n",
"    for label, char in enumerate(emnist_classes):\n",
"        remap[label] = whitelist_mapping.get(char, -1)\n",
"\n",
"    # A single vectorised gather replaces two Python-level passes over y.\n",
"    y_mapped = remap[y]\n",
"    keep = y_mapped >= 0\n",
"    X_filtered = X[keep]\n",
"    y_filtered = y_mapped[keep]\n",
"\n",
"    if verbose:\n",
"        print(whitelist_mapping)\n",
"        print(len(y))\n",
"        print(len(y_filtered))\n",
"        print(len(X_filtered))\n",
"        print(len(y_filtered))\n",
"    return X_filtered, y_filtered\n",
"\n",
"# Load the training and test samples of the 'byclass' subset\n",
"X_train_byclass, y_train_byclass = extract_training_samples('byclass')\n",
"X_test_byclass, y_test_byclass = extract_test_samples('byclass')\n",
"# Debug aid: show the raw (unfiltered) label of the first test sample\n",
"print(y_test_byclass[0])\n",
"# Keep only whitelisted characters; labels become positions in the whitelist\n",
"X_train, y_train = filter_data(X_train_byclass, y_train_byclass, whitelist, emnist_classes)\n",
"X_test, y_test = filter_data(X_test_byclass, y_test_byclass, whitelist, emnist_classes)\n",
"\n",
"# Sanity check: report the shapes of the filtered arrays\n",
"print(f\"Размер обучающего набора: {X_train.shape}, Метки: {y_train.shape}\")\n",
"print(f\"Размер тестового набора: {X_test.shape}, Метки: {y_test.shape}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 317,
"metadata": {},
"outputs": [],
"source": [
"class EMNISTDataset(Dataset):\n",
"    \"\"\"\n",
"    In-memory dataset pairing single-channel image tensors with class labels.\n",
"\n",
"    The inputs are converted to tensors once, at construction time: images\n",
"    become float32 with a new axis inserted at position 1 (the channel axis),\n",
"    labels become int64. An optional callable is applied to each image when\n",
"    it is accessed.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, X, y, transform=None):\n",
"        images = torch.tensor(X, dtype=torch.float32)\n",
"        self.X = images.unsqueeze(1)  # (N, ...) -> (N, 1, ...)\n",
"        self.y = torch.tensor(y, dtype=torch.long)\n",
"        self.transform = transform\n",
"\n",
"    def __len__(self):\n",
"        return self.X.shape[0]\n",
"\n",
"    def __getitem__(self, idx):\n",
"        sample, target = self.X[idx], self.y[idx]\n",
"        if not self.transform:\n",
"            return sample, target\n",
"        return self.transform(sample), target\n",
"\n",
"# Rescale pixel intensities to [-1, 1].\n",
"# EMNISTDataset only casts the images to float32, so the values reaching this\n",
"# transform are still raw intensities (assumed 0-255 uint8 images from the\n",
"# emnist package - TODO confirm). Mean and std must therefore be 127.5, not\n",
"# 0.5: (x - 127.5) / 127.5 maps 0 -> -1 and 255 -> 1.\n",
"# Weights trained with the previous (0.5, 0.5) constants expect differently\n",
"# scaled inputs, so retrain after this change.\n",
"transform = transforms.Compose([\n",
"    transforms.Normalize((127.5,), (127.5,))\n",
"])\n",
"\n",
"# Wrap the filtered arrays as datasets\n",
"train_dataset = EMNISTDataset(X_train, y_train, transform=transform)\n",
"test_dataset = EMNISTDataset(X_test, y_test, transform=transform)\n",
"\n",
"# Mini-batch loaders: shuffle the training data, keep the test order fixed\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 318,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"def compute_output_size(input_size, kernel_size, stride, padding):\n",
"    \"\"\"\n",
"    Return the output length of a sliding window along one axis.\n",
"\n",
"    Evaluates floor((input_size + 2 * padding - kernel_size) / stride) + 1.\n",
"    \"\"\"\n",
"    padded_size = input_size + 2 * padding\n",
"    full_steps = (padded_size - kernel_size) // stride\n",
"    return full_steps + 1\n",
"\n",
"class CNN(nn.Module):\n",
"    \"\"\"\n",
"    Two-convolution classifier sized for 28x28 single-channel images.\n",
"\n",
"    Layout: conv(1->32) -> ReLU -> conv(32->64) -> ReLU -> 2x2 max-pool ->\n",
"    flatten -> linear(->128) -> ReLU -> dropout(0.5) -> linear(->num_classes).\n",
"    The forward pass returns the output of the last linear layer unchanged.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_classes):\n",
"        super().__init__()\n",
"\n",
"        # Feature extractor\n",
"        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)\n",
"        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)\n",
"        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"\n",
"        # Follow the spatial side length through conv1, conv2 and the pool\n",
"        side = 28\n",
"        for kernel, step, pad in ((3, 1, 1), (3, 1, 1), (2, 2, 0)):\n",
"            side = compute_output_size(side, kernel_size=kernel, stride=step, padding=pad)\n",
"        flat_features = 64 * side * side  # square input: height == width\n",
"\n",
"        # Classifier head\n",
"        self.fc1 = nn.Linear(flat_features, 128)\n",
"        self.fc2 = nn.Linear(128, num_classes)\n",
"        self.dropout = nn.Dropout(0.5)\n",
"\n",
"    def forward(self, x):\n",
"        features = F.relu(self.conv1(x))\n",
"        features = F.relu(self.conv2(features))\n",
"        features = self.pool(features)\n",
"        flat = features.view(features.size(0), -1)\n",
"        hidden = self.dropout(F.relu(self.fc1(flat)))\n",
"        return self.fc2(hidden)\n",
"\n",
" \n",
"\n",
"\n",
"\n",
"# Build the model with one output per whitelisted character\n",
"num_classes = len(whitelist)\n",
"model = CNN(num_classes)\n"
]
},
{
"cell_type": "code",
"execution_count": 319,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"# Use the GPU when CUDA is available, otherwise fall back to the CPU\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device)\n",
"\n",
"# Loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"from tqdm import tqdm\n",
"import torch\n",
"\n",
"# One training epoch with a tqdm progress bar\n",
"def train(model, loader, optimizer, criterion, device):\n",
"    \"\"\"\n",
"    Run one optimisation pass over ``loader`` and update ``model`` in place.\n",
"\n",
"    :param model: Module to train; switched to training mode.\n",
"    :param loader: Iterable of ``(images, labels)`` batches exposing ``dataset``.\n",
"    :param optimizer: Optimizer stepping the model parameters.\n",
"    :param criterion: Callable returning the batch loss from outputs and labels.\n",
"    :param device: Device every batch is moved to before the forward pass.\n",
"    :return: Tuple ``(mean_loss, accuracy)``: the sum of per-batch losses\n",
"        divided by the number of batches, and the percentage of correct\n",
"        predictions relative to ``len(loader.dataset)``.\n",
"    \"\"\"\n",
"    model.train()\n",
"    loss_sum, hits = 0, 0\n",
"    progress = tqdm(loader, desc=\"Обучение\", leave=False)\n",
"    for images, labels in progress:\n",
"        images = images.to(device)\n",
"        labels = labels.to(device)\n",
"\n",
"        # Clear gradients left over from the previous step\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Forward pass\n",
"        outputs = model(images)\n",
"        loss = criterion(outputs, labels)\n",
"\n",
"        # Backward pass and parameter update\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # Bookkeeping: accumulate the loss and count correct predictions\n",
"        batch_loss = loss.item()\n",
"        loss_sum += batch_loss\n",
"        predicted = outputs.max(dim=1).indices\n",
"        hits += (predicted == labels).sum().item()\n",
"\n",
"        # Show the current batch loss on the progress bar\n",
"        progress.set_postfix(loss=batch_loss)\n",
"\n",
"    accuracy = 100 * hits / len(loader.dataset)\n",
"    mean_loss = loss_sum / len(loader)\n",
"    return mean_loss, accuracy\n",
"\n",
"# One evaluation pass with a tqdm progress bar\n",
"def evaluate(model, loader, criterion, device):\n",
"    \"\"\"\n",
"    Score ``model`` on ``loader`` without updating any parameters.\n",
"\n",
"    :param model: Module to evaluate; switched to evaluation mode.\n",
"    :param loader: Iterable of ``(images, labels)`` batches exposing ``dataset``.\n",
"    :param criterion: Callable returning the batch loss from outputs and labels.\n",
"    :param device: Device every batch is moved to before the forward pass.\n",
"    :return: Tuple ``(mean_loss, accuracy)``: the sum of per-batch losses\n",
"        divided by the number of batches, and the percentage of correct\n",
"        predictions relative to ``len(loader.dataset)``.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    loss_sum, hits = 0, 0\n",
"    progress = tqdm(loader, desc=\"Оценка\", leave=False)\n",
"    with torch.no_grad():  # gradients are not needed for scoring\n",
"        for images, labels in progress:\n",
"            images = images.to(device)\n",
"            labels = labels.to(device)\n",
"\n",
"            # Forward pass\n",
"            outputs = model(images)\n",
"            batch_loss = criterion(outputs, labels).item()\n",
"\n",
"            # Bookkeeping: accumulate the loss and count correct predictions\n",
"            loss_sum += batch_loss\n",
"            predicted = outputs.max(dim=1).indices\n",
"            hits += (predicted == labels).sum().item()\n",
"\n",
"            # Show the current batch loss on the progress bar\n",
"            progress.set_postfix(loss=batch_loss)\n",
"\n",
"    accuracy = 100 * hits / len(loader.dataset)\n",
"    mean_loss = loss_sum / len(loader)\n",
"    return mean_loss, accuracy\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 320,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 320,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Re-evaluate the device selection (same expression as in the training setup)\n",
"# and display it as the cell output.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 321,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Эпоха 1/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.4385, Train Accuracy: 86.80%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1677, Test Accuracy: 93.46%\n",
"Эпоха 2/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.2405, Train Accuracy: 91.25%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1499, Test Accuracy: 93.91%\n",
"Эпоха 3/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.2105, Train Accuracy: 92.21%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1608, Test Accuracy: 93.23%\n",
"Эпоха 4/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.1978, Train Accuracy: 92.65%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1394, Test Accuracy: 94.31%\n",
"Эпоха 5/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.1887, Train Accuracy: 92.92%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1417, Test Accuracy: 94.46%\n",
"Эпоха 6/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.1841, Train Accuracy: 93.05%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1473, Test Accuracy: 94.38%\n",
"Эпоха 7/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.1794, Train Accuracy: 93.18%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1418, Test Accuracy: 94.52%\n",
"Эпоха 8/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.1759, Train Accuracy: 93.30%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1432, Test Accuracy: 94.42%\n",
"Эпоха 9/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.1728, Train Accuracy: 93.41%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1406, Test Accuracy: 94.53%\n",
"Эпоха 10/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Train Loss: 0.1710, Train Accuracy: 93.46%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Test Loss: 0.1397, Test Accuracy: 94.47%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r"
]
}
],
"source": [
"# Training loop: one training pass and one test-set evaluation per epoch\n",
"num_epochs = 10\n",
"for epoch in range(num_epochs):\n",
" print(f\"Эпоха {epoch + 1}/{num_epochs}\")\n",
" \n",
" # Train for one pass over the training loader\n",
" train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)\n",
" print(f\" Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%\")\n",
"\n",
" # Evaluate on the test loader\n",
" test_loss, test_acc = evaluate(model, test_loader, criterion, device)\n",
" print(f\" Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%\")\n"
]
},
{
"cell_type": "code",
"execution_count": 322,
"metadata": {},
"outputs": [],
"source": [
"# Save the model's state_dict (not the whole module object) to a file path\n",
"# relative to the current working directory.\n",
"torch.save(model.state_dict(), \"letter_recognition_model1.pth\")\n"
]
},
{
"cell_type": "code",
"execution_count": 337,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 500x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import random\n",
"\n",
"def visualize_test_sample(X_test, y_test, index, whitelist):\n",
"    \"\"\"\n",
"    Display one test image with its label as the figure title.\n",
"\n",
"    :param X_test: Indexable collection of images (NumPy array or CPU tensor).\n",
"        Each image may be 2-D (H, W) or carry singleton axes such as a\n",
"        channel axis (1, H, W); singleton axes are dropped before plotting.\n",
"    :param y_test: Indexable collection of integer labels whose elements\n",
"        support ``.item()`` (NumPy array or tensor).\n",
"    :param index: Position of the sample to show.\n",
"    :param whitelist: Sequence mapping a label index to its character.\n",
"    \"\"\"\n",
"    # Convert to NumPy and drop singleton axes so that channel-first samples\n",
"    # (1, H, W) are accepted by imshow as well as plain (H, W) images.\n",
"    image = np.squeeze(np.asarray(X_test[index]))\n",
"    label_idx = y_test[index].item()  # class index as a plain Python number\n",
"    label_char = whitelist[label_idx]  # character for that class index\n",
"\n",
"    # Draw the image with the label in the title\n",
"    plt.figure(figsize=(5, 5))\n",
"    plt.imshow(image, cmap='gray')\n",
"    plt.title(f\"Метка: {label_char} (Index: {label_idx})\", fontsize=16)\n",
"    plt.axis('off')\n",
"    plt.show()\n",
"\n",
"\n",
"# Example: show a uniformly random test sample (randrange avoids building a\n",
"# throwaway list of every index just to pick one)\n",
"visualize_test_sample(X_test, y_test, index=random.randrange(len(X_test)), whitelist=whitelist)\n"
]
},
{
"cell_type": "code",
"execution_count": 216,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10000"
]
},
"execution_count": 216,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of samples in the filtered test set\n",
"len(X_test)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}