{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EMNIST character recognition (whitelisted characters)\n",
    "\n",
    "Filters the EMNIST `byclass` split down to the characters listed in `whitelist`, trains a small CNN on the result and saves its weights.\n",
    "\n",
    "The notebook is meant to be run top to bottom (**Restart Kernel -> Run All**): every cell depends only on cells above it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from emnist import extract_training_samples, extract_test_samples\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility and training hyper-parameters.\n",
    "SEED = 42\n",
    "BATCH_SIZE = 64\n",
    "LEARNING_RATE = 1e-3\n",
    "NUM_EPOCHS = 10\n",
    "MODEL_PATH = \"letter_recognition_model1.pth\"\n",
    "\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and filter EMNIST\n",
    "\n",
    "EMNIST `byclass` labels are used as indices into `emnist_classes` (digits, then upper-case, then lower-case letters). Only characters listed in `whitelist` are kept, and each kept label is remapped to that character's position in `whitelist`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Characters to recognise; a character's position in this string becomes its class id.\n",
    "whitelist = \"ABCEHKMOPTXy0123456789\"\n",
    "\n",
    "# EMNIST 'byclass' class id -> character.\n",
    "emnist_classes = (\n",
    "    \"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n",
    "    \"abcdefghijklmnopqrstuvwxyz\"\n",
    ")\n",
    "\n",
    "\n",
    "def filter_data(X, y, whitelist, emnist_classes):\n",
    "    \"\"\"Keep only the samples whose EMNIST label is a character in `whitelist`.\n",
    "\n",
    "    :param X: array of samples, aligned with `y` along axis 0.\n",
    "    :param y: integer EMNIST class ids (indices into `emnist_classes`).\n",
    "    :param whitelist: characters to keep; a character's index in this string\n",
    "        becomes its new label.\n",
    "    :param emnist_classes: string mapping EMNIST class id -> character.\n",
    "    :return: tuple (X_filtered, y_filtered); sample order is preserved and the\n",
    "        labels are remapped to whitelist indices (int64).\n",
    "    \"\"\"\n",
    "    whitelist_mapping = {char: idx for idx, char in enumerate(whitelist)}\n",
    "    # Vectorised lookup table: EMNIST class id -> whitelist index, or -1 if the class is dropped.\n",
    "    remap = np.array([whitelist_mapping.get(char, -1) for char in emnist_classes], dtype=np.int64)\n",
    "    new_labels = remap[np.asarray(y, dtype=np.int64)]\n",
    "    keep = new_labels >= 0\n",
    "    return X[keep], new_labels[keep]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_byclass, y_train_byclass = extract_training_samples(\"byclass\")\n",
    "X_test_byclass, y_test_byclass = extract_test_samples(\"byclass\")\n",
    "\n",
    "X_train, y_train = filter_data(X_train_byclass, y_train_byclass, whitelist, emnist_classes)\n",
    "X_test, y_test = filter_data(X_test_byclass, y_test_byclass, whitelist, emnist_classes)\n",
    "\n",
    "# The unfiltered splits are no longer needed; free their memory.\n",
    "del X_train_byclass, y_train_byclass, X_test_byclass, y_test_byclass\n",
    "\n",
    "assert len(X_train) == len(y_train) and len(X_test) == len(y_test)\n",
    "assert y_train.min() >= 0 and y_train.max() < len(whitelist)\n",
    "\n",
    "print(f\"Размер обучающего набора: {X_train.shape}, Метки: {y_train.shape}\")\n",
    "print(f\"Размер тестового набора: {X_test.shape}, Метки: {y_test.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset and data loaders\n",
    "\n",
    "Integer pixel values are scaled to [0, 1] and then normalised to [-1, 1] (mean 0.5, std 0.5)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EMNISTDataset(Dataset):\n",
    "    \"\"\"Wrap (N, H, W) image and (N,) label arrays as a PyTorch dataset.\n",
    "\n",
    "    Items are (image, label) with image shaped (1, H, W), float32. Integer images\n",
    "    (e.g. uint8 pixels in [0, 255]) are scaled to [0, 1] so that a mean-0.5 /\n",
    "    std-0.5 normalisation really maps them into [-1, 1].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, X, y, transform=None):\n",
    "        images = torch.tensor(X, dtype=torch.float32)\n",
    "        if np.issubdtype(np.asarray(X).dtype, np.integer):\n",
    "            images /= 255.0\n",
    "        self.X = images.unsqueeze(1)  # add the channel dimension: (N, 1, H, W)\n",
    "        self.y = torch.tensor(y, dtype=torch.long)\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image = self.X[idx]\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        return image, self.y[idx]\n",
    "\n",
    "\n",
    "def normalize(image):\n",
    "    \"\"\"Map an image tensor with values in [0, 1] to [-1, 1] (mean 0.5, std 0.5).\"\"\"\n",
    "    return (image - 0.5) / 0.5\n",
    "\n",
    "\n",
    "train_dataset = EMNISTDataset(X_train, y_train, transform=normalize)\n",
    "test_dataset = EMNISTDataset(X_test, y_test, transform=normalize)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_output_size(input_size, kernel_size, stride, padding):\n",
    "    \"\"\"Output length along one spatial axis of a conv/pool layer (dilation 1).\"\"\"\n",
    "    return (input_size - kernel_size + 2 * padding) // stride + 1\n",
    "\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    \"\"\"Two 3x3 conv layers and one 2x2 max-pool, followed by a two-layer classifier head.\n",
    "\n",
    "    :param num_classes: number of output logits.\n",
    "    :param input_size: side length of the square single-channel input images (default 28).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes, input_size=28):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)\n",
    "        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)\n",
    "        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "\n",
    "        # Feature-map side length after conv1 -> conv2 -> pool, used to size fc1.\n",
    "        side = compute_output_size(input_size, kernel_size=3, stride=1, padding=1)\n",
    "        side = compute_output_size(side, kernel_size=3, stride=1, padding=1)\n",
    "        side = compute_output_size(side, kernel_size=2, stride=2, padding=0)\n",
    "\n",
    "        self.fc1 = nn.Linear(64 * side * side, 128)\n",
    "        self.fc2 = nn.Linear(128, num_classes)\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return raw class logits for a batch shaped (N, 1, input_size, input_size).\"\"\"\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.dropout(F.relu(self.fc1(x)))\n",
    "        return self.fc2(x)\n",
    "\n",
    "\n",
    "model = CNN(num_classes=len(whitelist)).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _run_epoch(model, loader, criterion, device, optimizer=None, desc=\"\"):\n",
    "    \"\"\"Run one pass over `loader` and return (mean batch loss, accuracy in %).\n",
    "\n",
    "    Weights are updated only when `optimizer` is given; otherwise the model is put\n",
    "    in eval mode and gradient tracking is disabled.\n",
    "    \"\"\"\n",
    "    training = optimizer is not None\n",
    "    model.train(training)  # switches dropout on for training, off for evaluation\n",
    "    total_loss = 0.0\n",
    "    correct = 0\n",
    "    loop = tqdm(loader, desc=desc, leave=False)\n",
    "    with torch.set_grad_enabled(training):\n",
    "        for images, labels in loop:\n",
    "            images, labels = images.to(device), labels.to(device)\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, labels)\n",
    "            if training:\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "            batch_loss = loss.item()\n",
    "            total_loss += batch_loss\n",
    "            correct += (outputs.argmax(dim=1) == labels).sum().item()\n",
    "            loop.set_postfix(loss=batch_loss)\n",
    "\n",
    "    # max(..., 1) avoids ZeroDivisionError on an empty loader.\n",
    "    accuracy = 100 * correct / max(len(loader.dataset), 1)\n",
    "    return total_loss / max(len(loader), 1), accuracy\n",
    "\n",
    "\n",
    "def train(model, loader, optimizer, criterion, device):\n",
    "    \"\"\"Train `model` for one epoch; return (mean batch loss, accuracy in %).\"\"\"\n",
    "    return _run_epoch(model, loader, criterion, device, optimizer=optimizer, desc=\"Обучение\")\n",
    "\n",
    "\n",
    "def evaluate(model, loader, criterion, device):\n",
    "    \"\"\"Evaluate `model` without updating weights; return (mean batch loss, accuracy in %).\"\"\"\n",
    "    return _run_epoch(model, loader, criterion, device, desc=\"Оценка\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    print(f\"Эпоха {epoch + 1}/{NUM_EPOCHS}\")\n",
    "\n",
    "    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)\n",
    "    print(f\" Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%\")\n",
    "\n",
    "    test_loss, test_acc = evaluate(model, test_loader, criterion, device)\n",
    "    print(f\" Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), MODEL_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspect a test sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_test_sample(X_test, y_test, index, whitelist):\n",
    "    \"\"\"Show one test image with its label character in the title.\n",
    "\n",
    "    :param X_test: test images (numpy array or CPU tensor); size-1 axes such as a\n",
    "        channel axis are squeezed away before plotting.\n",
    "    :param y_test: integer labels (whitelist indices), aligned with `X_test`.\n",
    "    :param index: position of the sample to show.\n",
    "    :param whitelist: string mapping label index -> character.\n",
    "    \"\"\"\n",
    "    image = np.squeeze(np.asarray(X_test[index]))\n",
    "    label_idx = int(y_test[index])\n",
    "    label_char = whitelist[label_idx]\n",
    "\n",
    "    # NOTE(review): raw EMNIST images are often reported as stored transposed; if the\n",
    "    # characters look mirrored/rotated here, verify whether `image.T` is needed.\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    ax.imshow(image, cmap=\"gray\")\n",
    "    ax.set_title(f\"Метка: {label_char} (Index: {label_idx})\", fontsize=16)\n",
    "    ax.axis(\"off\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_test_sample(X_test, y_test, index=random.randrange(len(X_test)), whitelist=whitelist)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}