{ "cells": [ { "cell_type": "code", "execution_count": 316, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "18\n", "{'A': 0, 'B': 1, 'C': 2, 'E': 3, 'H': 4, 'K': 5, 'M': 6, 'O': 7, 'P': 8, 'T': 9, 'X': 10, 'y': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21}\n", "697932\n", "433256\n", "433256\n", "433256\n", "{'A': 0, 'B': 1, 'C': 2, 'E': 3, 'H': 4, 'K': 5, 'M': 6, 'O': 7, 'P': 8, 'T': 9, 'X': 10, 'y': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21}\n", "116323\n", "72548\n", "72548\n", "72548\n", "Размер обучающего набора: (433256, 28, 28), Метки: (433256,)\n", "Размер тестового набора: (72548, 28, 28), Метки: (72548,)\n" ] } ], "source": [ "import numpy as np\n", "from emnist import extract_training_samples, extract_test_samples\n", "\n", "# Ваш whitelist символов\n", "whitelist = 'ABCEHKMOPTXy0123456789'\n", "\n", "# Список всех классов в EMNIST\n", "emnist_classes = (\n", " \"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n", " \"abcdefghijklmnopqrstuvwxyz\"\n", ")\n", "\n", "\n", "# Функция для фильтрации данных по whitelist\n", "def filter_data(X, y, whitelist, emnist_classes):\n", " \"\"\"\n", " Фильтрует данные на основе whitelist символов.\n", " \"\"\"\n", " # Создаём словарь для быстрого сопоставления: символ -> индекс в whitelist\n", " whitelist_mapping = {char: idx for idx, char in enumerate(whitelist)}\n", " print(whitelist_mapping)\n", " print(len(y))\n", " # Создаём список индексов, которые соответствуют whitelist\n", " filtered_indices = [\n", " i for i, label in enumerate(y) if emnist_classes[label] in whitelist_mapping\n", " ]\n", " print(len(filtered_indices))\n", " # Отфильтрованные данные\n", " X_filtered = X[filtered_indices]\n", " y_filtered = np.array([whitelist_mapping[emnist_classes[label]] for label in y[filtered_indices]])\n", " print(len(X_filtered))\n", " print(len(y_filtered))\n", " return X_filtered, y_filtered\n", "\n", "# Загрузка данных из подмножества 'byclass'\n", "X_train_byclass, y_train_byclass = extract_training_samples('byclass')\n", "X_test_byclass, y_test_byclass = extract_test_samples('byclass')\n", "print(y_test_byclass[0])\n", "# Фильтрация данных\n", "X_train, y_train = filter_data(X_train_byclass, y_train_byclass, whitelist, emnist_classes)\n", "X_test, y_test = filter_data(X_test_byclass, y_test_byclass, whitelist, emnist_classes)\n", "\n", "# Проверка размеров наборов данных\n", "print(f\"Размер обучающего набора: {X_train.shape}, Метки: {y_train.shape}\")\n", "print(f\"Размер тестового набора: {X_test.shape}, Метки: {y_test.shape}\")\n" ] }, { "cell_type": "code", "execution_count": 317, "metadata": {}, "outputs": [], "source": [ "class EMNISTDataset(Dataset):\n", " def __init__(self, X, y, transform=None):\n", " self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(1) # Добавляем канал\n", " self.y = torch.tensor(y, dtype=torch.long)\n", " self.transform = transform\n", "\n", " def __len__(self):\n", " return len(self.X)\n", "\n", " def __getitem__(self, idx):\n", " image = self.X[idx]\n", " label = self.y[idx]\n", " if self.transform:\n", " image = self.transform(image)\n", " return image, label\n", "\n", "# Нормализация данных\n", "transform = transforms.Compose([\n", " transforms.Normalize((0.5,), (0.5,)) # Нормализация в диапазон [-1, 1]\n", "])\n", "\n", "# Создание датасетов\n", "train_dataset = EMNISTDataset(X_train, y_train, transform=transform)\n", "test_dataset = EMNISTDataset(X_test, y_test, transform=transform)\n", "\n", "# DataLoader для обучения и тестирования\n", "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n", "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n", "\n" ] }, { "cell_type": "code", "execution_count": 318, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "def compute_output_size(input_size, kernel_size, stride, padding):\n", " return (input_size - kernel_size + 2 * padding) // stride + 1\n", "\n", "class CNN(nn.Module):\n", " def __init__(self, num_classes):\n", " super(CNN, self).__init__()\n", " # Сверточные слои\n", " self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)\n", " self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)\n", " self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n", " \n", " # Вычисляем размеры после сверточных слоев\n", " height = compute_output_size(28, kernel_size=3, stride=1, padding=1) # conv1\n", " height = compute_output_size(height, kernel_size=3, stride=1, padding=1) # conv2\n", " height = compute_output_size(height, kernel_size=2, stride=2, padding=0) # pool\n", " width = height # Для квадратного входа\n", "\n", " # Полносвязные слои\n", " self.fc1 = nn.Linear(64 * height * width, 128)\n", " self.fc2 = nn.Linear(128, num_classes)\n", " self.dropout = nn.Dropout(0.5)\n", "\n", " def forward(self, x):\n", " x = F.relu(self.conv1(x))\n", " x = self.pool(F.relu(self.conv2(x)))\n", " x = x.view(x.size(0), -1)\n", " x = F.relu(self.fc1(x))\n", " x = self.dropout(x)\n", " x = self.fc2(x)\n", " return x\n", "\n", " \n", "\n", "\n", "\n", "# Инициализация модели\n", "num_classes = len(whitelist)\n", "model = CNN(num_classes)\n" ] }, { "cell_type": "code", "execution_count": 319, "metadata": {}, "outputs": [], "source": [ "import torch.optim as optim\n", "\n", "# Устройство (GPU или CPU)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model.to(device)\n", "\n", "# Определение функции потерь и оптимизатора\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", "\n", "from tqdm import tqdm\n", "import torch\n", "\n", "# Цикл обучения с tqdm\n", "def train(model, loader, optimizer, criterion, device):\n", " model.train()\n", " total_loss = 0\n", " correct = 0\n", " loop = tqdm(loader, desc=\"Обучение\", leave=False) # Прогресс-бар для обучения\n", " for images, labels in loop:\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " # Обнуление градиентов\n", " optimizer.zero_grad()\n", "\n", " # Прямой проход\n", " outputs = model(images)\n", " loss = criterion(outputs, labels)\n", "\n", " # Обратный проход и обновление весов\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # Логирование\n", " total_loss += loss.item()\n", " _, predicted = torch.max(outputs, 1)\n", " correct += (predicted == labels).sum().item()\n", "\n", " # Обновление tqdm\n", " loop.set_postfix(loss=loss.item())\n", "\n", " accuracy = 100 * correct / len(loader.dataset)\n", " return total_loss / len(loader), accuracy\n", "\n", "# Цикл оценки с tqdm\n", "def evaluate(model, loader, criterion, device):\n", " model.eval()\n", " total_loss = 0\n", " correct = 0\n", " loop = tqdm(loader, desc=\"Оценка\", leave=False) # Прогресс-бар для оценки\n", " with torch.no_grad():\n", " for images, labels in loop:\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " # Прямой проход\n", " outputs = model(images)\n", " loss = criterion(outputs, labels)\n", "\n", " # Логирование\n", " total_loss += loss.item()\n", " _, predicted = torch.max(outputs, 1)\n", " correct += (predicted == labels).sum().item()\n", "\n", " # Обновление tqdm\n", " loop.set_postfix(loss=loss.item())\n", "\n", " accuracy = 100 * correct / len(loader.dataset)\n", " return total_loss / len(loader), accuracy\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 320, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 320, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device" ] }, { "cell_type": "code", "execution_count": 321, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Эпоха 1/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.4385, Train Accuracy: 86.80%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1677, Test Accuracy: 93.46%\n", "Эпоха 2/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.2405, Train Accuracy: 91.25%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1499, Test Accuracy: 93.91%\n", "Эпоха 3/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.2105, Train Accuracy: 92.21%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1608, Test Accuracy: 93.23%\n", "Эпоха 4/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.1978, Train Accuracy: 92.65%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1394, Test Accuracy: 94.31%\n", "Эпоха 5/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.1887, Train Accuracy: 92.92%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1417, Test Accuracy: 94.46%\n", "Эпоха 6/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.1841, Train Accuracy: 93.05%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1473, Test Accuracy: 94.38%\n", "Эпоха 7/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.1794, Train Accuracy: 93.18%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1418, Test Accuracy: 94.52%\n", "Эпоха 8/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.1759, Train Accuracy: 93.30%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1432, Test Accuracy: 94.42%\n", "Эпоха 9/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.1728, Train Accuracy: 93.41%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1406, Test Accuracy: 94.53%\n", "Эпоха 10/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Train Loss: 0.1710, Train Accuracy: 93.46%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ " Test Loss: 0.1397, Test Accuracy: 94.47%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r" ] } ], "source": [ "# Цикл обучения\n", "num_epochs = 10\n", "for epoch in range(num_epochs):\n", " print(f\"Эпоха {epoch + 1}/{num_epochs}\")\n", " \n", " # Обучение\n", " train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)\n", " print(f\" Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%\")\n", "\n", " # Оценка\n", " test_loss, test_acc = evaluate(model, test_loader, criterion, device)\n", " print(f\" Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%\")\n" ] }, { "cell_type": "code", "execution_count": 322, "metadata": {}, "outputs": [], "source": [ "torch.save(model.state_dict(), \"letter_recognition_model1.pth\")\n" ] }, { "cell_type": "code", "execution_count": 337, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAGtCAYAAAAxsILFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAc7UlEQVR4nO3daXiU5dmH8f+ThUAgGAUDKhAUMAiiiCweqBBErYhIkMUCIrgUVESqFXDB4ta6W1vbUqEuIGhkETAQWqBUseKGVDDSeFBWg0pZEsCQAEnu94NvUmImkLlzMRQ4f8fBl5m5nvuZJHrmYWZuAuecEwAA1RR1tE8AAHB8ICgAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoEdK0aVMFQaAgCDR69OhDPvaZZ54pe2xMTEyEzvD4kpmZWfY1vPzyy6t1rD59+qhWrVrKyckpd3tqaqqCINDDDz9creNX1caNGxUEgZo2bRqR9aytXLlSzz77rAYOHKizzz5bUVFRCoJA06ZNq9L8/v379bvf/U6XXHKJTjnlFNWsWVONGjVSjx499NZbb5V7bHFxsVq2bKnk5GQVFBQciaeDEAjKUTB9+nTt37+/0vtfeeWVCJ7N8Sc3N1c/+9nPFARBtY+1ZMkSzZ07V3feeacaNWpkcHYnrkcffVRjxoxRenq61q5dq3B2fcrJydEFF1yg0aNH66uvvtLFF1+stLQ0JScna9myZZo5c2a5x0dHR+vxxx/X5s2b9fTTT1s/FVSCoERY+/bttWPHDs2bNy/k/cuXL1d2drY6dOgQ4TM7fowaNUpbt27VbbfdVu1j3X333apZs6buu+8+gzM7sV100UV64IEHNGvWLK1bt05du3at0lxBQYGuuOIKrVmzRg8//LC++eYbZWRkKD09XR988IG2bdum8ePHV5jr16+f2rRpo6eeekrfffed9dNBCAQlwm6++WZJlV+FvPzyy+Ueh/DMmTNH06dP1z333KOOHTtW61iLFy9WVlaW0tLSVK9ePaMzPHHdd999+tWvfqW+ffvqrLPOqvLcE088oezsbA0fPlwTJkxQbGxsufvj4+PVtm3bkLM333yzCgoKNGnSpOqcOqqIoERYmzZt1L59ey1atEhbtmwpd9/333+vGTNmqFGjRrryyisPeZyioiL9+c9/Vmpqqk455RTFxcXpzDPP1O23366vv/663GOHDRtW9npCVf6Ueu211xQEgYYNG1Zh/cWLFys+Pl61a9fW0qVLy933ySefaOzYserYsaMaNmyoGjVqqEGDBurVq5eWLFkS5les6rZv367bbrtNKSkpevTRR6t9vN///veSFPL5H8rBX7f8/Hzdf//9at68ueLi4tSwYUMNHTq0wvf+YPPnz1fXrl2VkJCgk046SZdeemmlV7QHy83N1YQJE9S2bVslJCQoPj5ebdq00eOPP669e/eWe+xzzz2nIAh09tlna8+ePRWONXnyZAVBoMaNG2v79u1hPX9LBw4c0MSJEyVJY8aMCXt+8ODBiomJ0UsvvaSioiLr08OPOUREcnKyk+Tef/9998c//tFJco8//ni5x7z88stOknvwwQfdhg0bnCQXHR1d4Vi7d+92qampTpKrU6eO69q1q+vXr59LSUlxkly9evXcypUryx4/efJkN3To0HJ/GjRo4CS5n/zkJxXuK/Xqq686SeVuc865RYsWuVq1arn4+Hi3dOnSCufXvXt3FxUV5dq0aeOuvvpq179/f9euXTsnyUlyL7zwQsivUdeuXZ0kN2HChKp/YQ/Sr18/FxUV5f7xj3+UO//u3buHfayCggIXFxfnYmNj3d69e8M639J109LS3HnnnecSExNdr169XO/evV1SUpKT5JKTk11eXl6FYz7//PNlX6eOHTu6gQMHuvbt2ztJ7p577imb/bEvv/zSNW7c2Elyp512mrvqqqtcr169yr7Pbdu2rbDetdde6yS5n/70p+Vu//zzz13NmjVdTEyM++CDD8rdV/pzKclt2LDh8F/IQyj9+r3++uuVPuajjz5yktzpp5/unHNu9erV7uGHH3bDhw9348aNc/Pnz3fFxcWHXKf067d8+fJqnS8Oj6BEyMFBycvLc7Vq1XLNmzcv95iLL77YBUHg1q1bd8igDBo0yEly11xzjdu6dWu5+37zm984Sa5FixauqKio0vMp/Y/573//e6WPCRWUg2NS2WxmZqb75ptvKty+fPlyV7duXRcbG+tycnIqPSefoLz55ptOkhs9enSF8/cJypIlS5wk16FDh0ofc7iglAZ7165dZfft3LnTtW3b1klyv/71r8vNrVq1ykVHR7uoqCg3c+bMcvdNmzbNBUEQMih79+51zZo1c5Lc+PHj3b59+8ruy8/PdwMHDnSS3E033VRuLjc31zVt2tRJchMnTnTO/fDLSosWLZwk98wzz1R4zpEOyqRJk8riOm7cuLKvwcF/LrjgArdp06ZKj3HXXXc5Se6xxx6r1vni8AhKhBwcFOecGzx4sJPk3n33Xeecc9nZ2U6SS01Ndc65SoOyZs0aFwSBO/30093u3btDrnX11Vc7SS4jI6PS8/EJSmlMateuXXbe4br//vudJPeHP/yhwn1DhgxxKSkp7sUXXwzrmN9++6075ZRTXLNmzVx+fn6F8/cJyjPPPOMkuRtvvLHSxxwuKLVr1w4Z1vT0dCfJXXbZZeVuv/XWW50kd/3114dcr3fv3iGDMnHixLJfMELZs2ePS0pKcjExMW7nzp3l7vvkk09cjRo1XFxcnPvnP//pBgwY4CS5Xr16uZKSkgrHysnJcSkpKS4lJSXkLwXhqEpQnnjiCSfJxcbGOklu5MiR7quvvnK7du1yixcvdmeffbaT5M4991y3f//+kMd46aWXnCTXp0+fap0vDo8PORwlN998s6ZPn65XXnlFXbt2LXuR/nAvxmdmZso5px49eighISHkY1JTU5WZmanly5frmmuuMTnfxYsXq3fv3iooKNCcOXMO+w6dHTt2aMGCBcrKylJubq4OHDggSVq7dq0k6auvvqowM3XqVK9zGz58uHJzczV79mzFx8d7HePHtm7dKknVejG+ffv2Ou200yrcfs4550hShddR3n33XUnSDTfcEPJ4Q4cODflayoIFCyRJ119/fci5OnXqqH379srMzNSnn35a7vW5Dh066Nlnn9Vdd92l1NRU7dq1S8nJyZoyZUrIt12fccYZys7ODrnOkeD+/63FBw4c0MCBA8te15Kkyy+/XIsXL1ZKSoqysrKUnp6uIUOGVDhG6few9HuKI4egHCXdunXTmWeeqVmzZumFF17Q1KlTVbduXfXr1++Qc+vXr5f0w7vBSt8RVplt27aZnOuKFSs0Y8aMsg+ITZs2TWlpaZU+fvLkybr77ruVn59f6WN2795tcm5TpkxRRkaGbr/9dqWmppocU5J27dolSapbt673MZo0aRLy9tJjFhYWlru99IOTZ555Zsi5ym4v/ZkYMmRIyP+hHizUz8SoUaM0f/58LVq0SEEQKD09XSeffPIhjxMpB//SNGLEiAr3N2nSRD179tTs2bO1ZMmSkM+/9Oudm5t75E4UkgjKUVP6LqAJEyZo6NCh+u677zR8+HDVqlXrkHMlJSWSpLZt2+r8888/5GM7depkcq5ffvml4uPjlZmZqfvuu0+zZ8/Wyy+/rFtuuaXCYz/77DONGDFC0dHReuqpp9SrVy81adJE8fHxCoJAkyZN0ogRI8L6UNuhzJkzR5L06aefVghK6WcPPvvss7L70tPT1bBhw8MeNzExUVL1whcVFZk3UZb+TFx11VVq0KDBIR+bnJxc4ba1a9fqww8/lPTDFcEnn3yiiy66yP5EPRz89uLK3mpcevu3334b8v7SXw7+VyJ5PCMoR9GwYcP0yCOPKCMjQ1LVPnvSuHFjSdLFF19c7vL/SIqPj1dGRoYuu+wyJScn68ILL9To0aPVpUsXtWjRotxjZ86cKeecRo0apbFjx1Y4VulfeVlbsWJFpffl5eXpvffek1TxqqAySUlJkn74q7tIOeOMM7Ru3Tpt3LhRrVu3rnD/xo0bQ841btxY2dnZuuWWWw57hftjhYWFGjBggPbs2aPBgwdr1qxZGjNmjDp37qz27dv7PA1T7dq1UxAEcs5p+/btZT//Byt9W3OdOnVCHqP0e3i42KL6+BzKUdSkSRP17t1b9erV00UXXVSlK4oePXpIkt55550q/8+xuvr166fLLrtMktSqVSs9++yzys/P16BBg8peGym1c+dOSaF/Ey4sLNTs2bNNz23u3LlyP7y5pMKfV199VZLUvXv3stuqug9Wu3btJElr1qwxPd9DKX1davr06SHvr+w1ptKfiRkzZoS95ujRo/X555+rW7dumjp1qp577jnt379fAwYMUF5eXtjHs9awYUNdcsklkhTyM0wHDhwo+2Whsg+yZmVlSZIuvPDCI3SWKEVQjrK3335b27dvL/srh8O54IIL1LdvX3399de67rrrQv7Wmp+fr+nTp5u9CPnjF2dHjhypnj17asWKFfrlL39Z7r7SF5ynTJlS7gNzhYWFuuOOO7Rhw4ZK17nxxhvVsmXLiF15HUrnzp0VFxenVatWRWxzwVGjRik6OlozZswo+6u8Uunp6Zo7d27IueHDhys5OVkzZ87UuHHjQn5Q8bvvvtPkyZPL3fbGG29o0qRJatCggd544w1FRUVp5MiR6tevnzZs2BDyinnLli1q2bKlWrZsecgPZ1qaMGGCpB8+Mf/RRx+V3V5UVKRf/OIXWr9+vRISEnTTTTeFnF++fLkklf1ShCPoaLy17ET047cNH87hPtjYvXt3J8nVqFHDdejQwQ0YMMD179/fdejQwdWoUcNJcv/6178qPb7v51BKbd261TVo0MBFRUWVewtxbm5u2XOtV6+eS0tLc3379nVJSUkuISHBjR49utJjVveDjZWdv8/bhp377wf/MjMzQ95/uLcNh3qOzv33exvqA4pPP/102ecrOnXq5AYNGuQ6dOjgJLm777670rmsrKyyz5QkJia6Ll26uEGDBrm0tDTXqlUrFwSBa9CgQdnjs7OzXZ06dVxUVJT729/+Vu5YeXl57qyzzgr5IdTqfA5l/vz5rlOnTmV/EhISnCTXrFmzcreH8thjjzlJLiYmxnXu3Nldd911Zc+3Vq1abv78+SHn/vOf/7iYmBh3+umnuwMHDoR1vggfVyjHoISEBC1atEhvvPGGLr/8cm3evFlz5szR0qVLVVBQoMGDB2vOnDlq1qzZETuHpKQkvfbaa3LOaciQIWXvoElMTNSKFSt0xx13KDExUQsXLtSHH36oK6+8UitXrqx0z6X/RXfeeaekH7ZSiZQxY8Zo3rx5uuSSS5SVlaV33nlHsbGxmjVrlu66665K51q3bq3Vq1fr6aef1jnnnKPVq1dr5syZ+vjjj1W7dm3de++9ZVc9BQUF6t+/v77//ns99NBDFX5zP+mkkzRjxgzFxcVp7Nix+vTTT02e27Zt2/Txxx+X/Sm9klq3bl2520MZP368/vrXv+qKK65Qdna2MjIyVFxcrGHDhmnlypXq2bNnyLlp06apqKhII0aM4J+CiIDAOaO32wDHGeeczjvvPK1du1Y5OTmqX7/+0T4lhME5p/PPP1///ve/tX79+iq9uw/VwxUKUIkgCPT8889r3759evLJJ4/26SBMs2bN0hdffKFx48YRkwjhCgU4jD59+ugvf/mL1q5dyz+ydYwoLi5W69atVVBQoOzs7MN+vgs2CAoAwAR/5QUAMEFQAAAmCAoAwARBAQCYqPInfUL92wgAgBNDVd6/xRUKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATFR5t2EAsBAdHe01V1xcbHwmsMYVCgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggs0hAXipX7++19zgwYO95ubNmxf2zKZNm7zWcs55zZ3ouEIBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACXYbBk5wMTF+/xsYOnSo19yTTz7pNderV6+wZ3zPccuWLV5zJzquUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCC3YaB40gQBGHPNG3a1GutLl26eM3t37/fa+79998Pe2bHjh1ea8EPVygAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwwW7DwHHk1FNPDXvmwQcf9Frriiuu8Jr75ptvvOYyMjLCniksLPRaC364QgEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATLA5JEKKi4sLe6Z+/fpea23fvt1rbt++fV5zx4Lo6GivuUsvvTTsmR49enit5fMzIknLli3zmtu8ebPXHCKHKxQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYCJxzrkoPDIIjfS44ApKSkrzmnnrqqbBnrr32Wq+15s2b5zV36623es2VlJR4zUXS+eef7zU3ZcqUsGfatGnjtdYXX3zhNde5c2evub1793rNwUZVUsEVCgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEzEHO0TQNX47vbco0ePiM0lJiZ6rXXuued6zR0LO2DHxcV5zd17771ecykpKWHPFBcXe621ePFirzl2DT5+cYUCADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAE+w2fIw49dRTveYefPDBiK3nu2vte++95zXnu14knXHGGV5zaWlpXnM+uxuvX7/ea61ly5Z5zeH4xRUKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCzSEjrGnTpl5zjz32mNdc8+bNveZKSkrCnlmwYIHXWhMnTvSaiyTf79sjjzziNVe7dm2vufz8/LBnfL/+Cxcu9JrD8YsrFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJhgt+FqiIkJ/8vXp08fr7XS0tK85nzt3r077Jn58+d7rbVlyxavOV8+37e+fft6reX7ffPZ7VmS3n333bBnfL9vzjmvORy/uEIBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACXYbjrC6det6zcXFxRmfyaHFx8eHPeO7s25iYqLXnO8uuQkJCWHPdOnSxWutWrVqec358nluAwcO9Fpr1apVXnPvv/++19yOHTvCnmFH5MjiCgUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmAlfF7TiDIDjS53LMiYkJf7Pm+++/32utBx54wGsukrsUl5SUeM3t27fPay4nJ8drLjY2NuyZ0047zWutSO8S7fM9KC4u9lorLy/Pa+7tt9/2mhs/fnzYMz47FCO0qqSCKxQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwET4uxuiTFFRUdgzEydO9Fpr8+bNXnMPPfSQ15zPZog+my5K/hsotmjRwmvuWODzsyVVbQO/H/Pd5PHrr7/2mlu4cKHX3Pfff+81h8jhCgUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmAlfF7UmDIDjS54JD8N2Rt3Xr1l5zXbt2DXsmMTHRay3fc+zWrZvX3Mknn+w152Pjxo1ec6+//rrXnM8uxVlZWV5rrVq1ymtu06ZNXnPFxcVec7BRlVRwhQIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAAT7DaMkGJiYsKe8f0ZqVevntfcm2++6TXXpUuXsGd8d7r97W9/6zX3wAMPeM1V8T/nckpKSrzW8p3DsYndhgEAEUNQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAAT4W8pixNCUVFRxNZKSkrymuvUqZPXnM+OvAsXLvRaa+LEiV5zBw4c8JoDjiauUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAE2wOCTPx8fFec9ddd53XXM2aNb3m8vPzw55ZtmyZ11qbN2/2mgOORVyhAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwAS7DSOkqKjwf9fo3r2711o33HCD11xxcbHX3JIlS8KemTNnjtdaRUVFXnPAsYgrFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJhgt2GEFB0dHfZMu3btvNZq3Lix19zOnTu95qZPnx72zKZNm7zWAk4kXKEAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABLsNH+eCIPCaa9WqVdgzffr08VorNjbWa27Lli1ec6tXrw57pk6dOl5r5efne80VFRV5zQFHE1coAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJNoc8ztWrV89r7uc//3nYMykpKV5r+SopKfGau+mmm8KeKSws9Fpr0qRJXnPffvut1xxwNHGFAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABPsNnycS0xM9Jrr2LFj2DNxcXFeawVB4DXXtm1br7lzzjkn7JkXX3zRa628vDyvOeBYxBUKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATLDb8HEuJyfHa+6tt94Ke+bee+/1WqtmzZpeczt37vSaW7hwYdgzf/rTn7zWKigo8JoDjkVcoQAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEuw0f5woLC73mpk6dGvbMnj17vNaqW7eu19zq1au95j744IOwZ7Zt2+a1FnAi4QoFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADAROOdclR4YBEf6XHCMi4mJ7F6jJSUlEZ0DTmRVSQVXKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADDBbsMAgMNit2EAQMQQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDARExVH1jFTYkBACcorlAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAm/g+EF0MJ5N2MfgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "import random\n", "\n", "def visualize_test_sample(X_test, y_test, index, whitelist):\n", " \"\"\"\n", " Визуализация тестового примера с его меткой.\n", " \n", " :param X_test: Тензор с тестовыми изображениями.\n", " :param y_test: Тензор с метками тестовых изображений.\n", " :param index: Индекс примера для отображения.\n", " :param whitelist: Список символов для сопоставления меток.\n", " \"\"\"\n", " image = X_test[index] # Убираем канал и переводим в numpy\n", " label_idx = y_test[index].item() # Получаем индекс метки\n", " label_char = whitelist[label_idx] # Соответствующий символ из whitelist\n", "\n", " # Отображение изображения с меткой\n", " plt.figure(figsize=(5, 5))\n", " plt.imshow(image, cmap='gray')\n", " plt.title(f\"Метка: {label_char} (Index: {label_idx})\", fontsize=16)\n", " plt.axis('off')\n", " plt.show()\n", "\n", "\n", "# Пример вызова\n", "visualize_test_sample(X_test, y_test, index=random.choice([i for i in range(len(X_test))]), whitelist=whitelist)\n" ] }, { "cell_type": "code", "execution_count": 216, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "10000" ] }, "execution_count": 216, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(X_test_all[2])" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 2 }