683 lines
30 KiB
Plaintext
683 lines
30 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 316,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"18\n",
|
||
"{'A': 0, 'B': 1, 'C': 2, 'E': 3, 'H': 4, 'K': 5, 'M': 6, 'O': 7, 'P': 8, 'T': 9, 'X': 10, 'y': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21}\n",
|
||
"697932\n",
|
||
"433256\n",
|
||
"433256\n",
|
||
"433256\n",
|
||
"{'A': 0, 'B': 1, 'C': 2, 'E': 3, 'H': 4, 'K': 5, 'M': 6, 'O': 7, 'P': 8, 'T': 9, 'X': 10, 'y': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21}\n",
|
||
"116323\n",
|
||
"72548\n",
|
||
"72548\n",
|
||
"72548\n",
|
||
"Размер обучающего набора: (433256, 28, 28), Метки: (433256,)\n",
|
||
"Размер тестового набора: (72548, 28, 28), Метки: (72548,)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from emnist import extract_training_samples, extract_test_samples\n",
|
||
"\n",
|
||
"# Ваш whitelist символов\n",
|
||
"whitelist = 'ABCEHKMOPTXy0123456789'\n",
|
||
"\n",
|
||
"# Список всех классов в EMNIST\n",
|
||
"emnist_classes = (\n",
|
||
" \"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n",
|
||
" \"abcdefghijklmnopqrstuvwxyz\"\n",
|
||
")\n",
|
||
"\n",
|
||
"\n",
|
||
"# Функция для фильтрации данных по whitelist\n",
|
||
"def filter_data(X, y, whitelist, emnist_classes):\n",
|
||
" \"\"\"\n",
|
||
" Фильтрует данные на основе whitelist символов.\n",
|
||
" \"\"\"\n",
|
||
" # Создаём словарь для быстрого сопоставления: символ -> индекс в whitelist\n",
|
||
" whitelist_mapping = {char: idx for idx, char in enumerate(whitelist)}\n",
|
||
" print(whitelist_mapping)\n",
|
||
" print(len(y))\n",
|
||
" # Создаём список индексов, которые соответствуют whitelist\n",
|
||
" filtered_indices = [\n",
|
||
" i for i, label in enumerate(y) if emnist_classes[label] in whitelist_mapping\n",
|
||
" ]\n",
|
||
" print(len(filtered_indices))\n",
|
||
" # Отфильтрованные данные\n",
|
||
" X_filtered = X[filtered_indices]\n",
|
||
" y_filtered = np.array([whitelist_mapping[emnist_classes[label]] for label in y[filtered_indices]])\n",
|
||
" print(len(X_filtered))\n",
|
||
" print(len(y_filtered))\n",
|
||
" return X_filtered, y_filtered\n",
|
||
"\n",
|
||
"# Загрузка данных из подмножества 'byclass'\n",
|
||
"X_train_byclass, y_train_byclass = extract_training_samples('byclass')\n",
|
||
"X_test_byclass, y_test_byclass = extract_test_samples('byclass')\n",
|
||
"print(y_test_byclass[0])\n",
|
||
"# Фильтрация данных\n",
|
||
"X_train, y_train = filter_data(X_train_byclass, y_train_byclass, whitelist, emnist_classes)\n",
|
||
"X_test, y_test = filter_data(X_test_byclass, y_test_byclass, whitelist, emnist_classes)\n",
|
||
"\n",
|
||
"# Проверка размеров наборов данных\n",
|
||
"print(f\"Размер обучающего набора: {X_train.shape}, Метки: {y_train.shape}\")\n",
|
||
"print(f\"Размер тестового набора: {X_test.shape}, Метки: {y_test.shape}\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 317,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class EMNISTDataset(Dataset):\n",
|
||
" def __init__(self, X, y, transform=None):\n",
|
||
" self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(1) # Добавляем канал\n",
|
||
" self.y = torch.tensor(y, dtype=torch.long)\n",
|
||
" self.transform = transform\n",
|
||
"\n",
|
||
" def __len__(self):\n",
|
||
" return len(self.X)\n",
|
||
"\n",
|
||
" def __getitem__(self, idx):\n",
|
||
" image = self.X[idx]\n",
|
||
" label = self.y[idx]\n",
|
||
" if self.transform:\n",
|
||
" image = self.transform(image)\n",
|
||
" return image, label\n",
|
||
"\n",
|
||
"# Нормализация данных\n",
|
||
"transform = transforms.Compose([\n",
|
||
" transforms.Normalize((0.5,), (0.5,)) # Нормализация в диапазон [-1, 1]\n",
|
||
"])\n",
|
||
"\n",
|
||
"# Создание датасетов\n",
|
||
"train_dataset = EMNISTDataset(X_train, y_train, transform=transform)\n",
|
||
"test_dataset = EMNISTDataset(X_test, y_test, transform=transform)\n",
|
||
"\n",
|
||
"# DataLoader для обучения и тестирования\n",
|
||
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
|
||
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 318,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"\n",
|
||
"def compute_output_size(input_size, kernel_size, stride, padding):\n",
|
||
" return (input_size - kernel_size + 2 * padding) // stride + 1\n",
|
||
"\n",
|
||
"class CNN(nn.Module):\n",
|
||
" def __init__(self, num_classes):\n",
|
||
" super(CNN, self).__init__()\n",
|
||
" # Сверточные слои\n",
|
||
" self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)\n",
|
||
" self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)\n",
|
||
" self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
|
||
" \n",
|
||
" # Вычисляем размеры после сверточных слоев\n",
|
||
" height = compute_output_size(28, kernel_size=3, stride=1, padding=1) # conv1\n",
|
||
" height = compute_output_size(height, kernel_size=3, stride=1, padding=1) # conv2\n",
|
||
" height = compute_output_size(height, kernel_size=2, stride=2, padding=0) # pool\n",
|
||
" width = height # Для квадратного входа\n",
|
||
"\n",
|
||
" # Полносвязные слои\n",
|
||
" self.fc1 = nn.Linear(64 * height * width, 128)\n",
|
||
" self.fc2 = nn.Linear(128, num_classes)\n",
|
||
" self.dropout = nn.Dropout(0.5)\n",
|
||
"\n",
|
||
" def forward(self, x):\n",
|
||
" x = F.relu(self.conv1(x))\n",
|
||
" x = self.pool(F.relu(self.conv2(x)))\n",
|
||
" x = x.view(x.size(0), -1)\n",
|
||
" x = F.relu(self.fc1(x))\n",
|
||
" x = self.dropout(x)\n",
|
||
" x = self.fc2(x)\n",
|
||
" return x\n",
|
||
"\n",
|
||
" \n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"# Инициализация модели\n",
|
||
"num_classes = len(whitelist)\n",
|
||
"model = CNN(num_classes)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 319,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch.optim as optim\n",
|
||
"\n",
|
||
"# Устройство (GPU или CPU)\n",
|
||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||
"model.to(device)\n",
|
||
"\n",
|
||
"# Определение функции потерь и оптимизатора\n",
|
||
"criterion = nn.CrossEntropyLoss()\n",
|
||
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
|
||
"\n",
|
||
"from tqdm import tqdm\n",
|
||
"import torch\n",
|
||
"\n",
|
||
"# Цикл обучения с tqdm\n",
|
||
"def train(model, loader, optimizer, criterion, device):\n",
|
||
" model.train()\n",
|
||
" total_loss = 0\n",
|
||
" correct = 0\n",
|
||
" loop = tqdm(loader, desc=\"Обучение\", leave=False) # Прогресс-бар для обучения\n",
|
||
" for images, labels in loop:\n",
|
||
" images, labels = images.to(device), labels.to(device)\n",
|
||
"\n",
|
||
" # Обнуление градиентов\n",
|
||
" optimizer.zero_grad()\n",
|
||
"\n",
|
||
" # Прямой проход\n",
|
||
" outputs = model(images)\n",
|
||
" loss = criterion(outputs, labels)\n",
|
||
"\n",
|
||
" # Обратный проход и обновление весов\n",
|
||
" loss.backward()\n",
|
||
" optimizer.step()\n",
|
||
"\n",
|
||
" # Логирование\n",
|
||
" total_loss += loss.item()\n",
|
||
" _, predicted = torch.max(outputs, 1)\n",
|
||
" correct += (predicted == labels).sum().item()\n",
|
||
"\n",
|
||
" # Обновление tqdm\n",
|
||
" loop.set_postfix(loss=loss.item())\n",
|
||
"\n",
|
||
" accuracy = 100 * correct / len(loader.dataset)\n",
|
||
" return total_loss / len(loader), accuracy\n",
|
||
"\n",
|
||
"# Цикл оценки с tqdm\n",
|
||
"def evaluate(model, loader, criterion, device):\n",
|
||
" model.eval()\n",
|
||
" total_loss = 0\n",
|
||
" correct = 0\n",
|
||
" loop = tqdm(loader, desc=\"Оценка\", leave=False) # Прогресс-бар для оценки\n",
|
||
" with torch.no_grad():\n",
|
||
" for images, labels in loop:\n",
|
||
" images, labels = images.to(device), labels.to(device)\n",
|
||
"\n",
|
||
" # Прямой проход\n",
|
||
" outputs = model(images)\n",
|
||
" loss = criterion(outputs, labels)\n",
|
||
"\n",
|
||
" # Логирование\n",
|
||
" total_loss += loss.item()\n",
|
||
" _, predicted = torch.max(outputs, 1)\n",
|
||
" correct += (predicted == labels).sum().item()\n",
|
||
"\n",
|
||
" # Обновление tqdm\n",
|
||
" loop.set_postfix(loss=loss.item())\n",
|
||
"\n",
|
||
" accuracy = 100 * correct / len(loader.dataset)\n",
|
||
" return total_loss / len(loader), accuracy\n",
|
||
"\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 320,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"device(type='cuda')"
|
||
]
|
||
},
|
||
"execution_count": 320,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||
"device"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 321,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Эпоха 1/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.4385, Train Accuracy: 86.80%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1677, Test Accuracy: 93.46%\n",
|
||
"Эпоха 2/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.2405, Train Accuracy: 91.25%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1499, Test Accuracy: 93.91%\n",
|
||
"Эпоха 3/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.2105, Train Accuracy: 92.21%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1608, Test Accuracy: 93.23%\n",
|
||
"Эпоха 4/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.1978, Train Accuracy: 92.65%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1394, Test Accuracy: 94.31%\n",
|
||
"Эпоха 5/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.1887, Train Accuracy: 92.92%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1417, Test Accuracy: 94.46%\n",
|
||
"Эпоха 6/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.1841, Train Accuracy: 93.05%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1473, Test Accuracy: 94.38%\n",
|
||
"Эпоха 7/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.1794, Train Accuracy: 93.18%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1418, Test Accuracy: 94.52%\n",
|
||
"Эпоха 8/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.1759, Train Accuracy: 93.30%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1432, Test Accuracy: 94.42%\n",
|
||
"Эпоха 9/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.1728, Train Accuracy: 93.41%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1406, Test Accuracy: 94.53%\n",
|
||
"Эпоха 10/10\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Train Loss: 0.1710, Train Accuracy: 93.46%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" "
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Test Loss: 0.1397, Test Accuracy: 94.47%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\r"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Цикл обучения\n",
|
||
"num_epochs = 10\n",
|
||
"for epoch in range(num_epochs):\n",
|
||
" print(f\"Эпоха {epoch + 1}/{num_epochs}\")\n",
|
||
" \n",
|
||
" # Обучение\n",
|
||
" train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)\n",
|
||
" print(f\" Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%\")\n",
|
||
"\n",
|
||
" # Оценка\n",
|
||
" test_loss, test_acc = evaluate(model, test_loader, criterion, device)\n",
|
||
" print(f\" Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 322,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"torch.save(model.state_dict(), \"letter_recognition_model1.pth\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 337,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAGtCAYAAAAxsILFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAc7UlEQVR4nO3daXiU5dmH8f+ThUAgGAUDKhAUMAiiiCweqBBErYhIkMUCIrgUVESqFXDB4ta6W1vbUqEuIGhkETAQWqBUseKGVDDSeFBWg0pZEsCQAEnu94NvUmImkLlzMRQ4f8fBl5m5nvuZJHrmYWZuAuecEwAA1RR1tE8AAHB8ICgAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoEdK0aVMFQaAgCDR69OhDPvaZZ54pe2xMTEyEzvD4kpmZWfY1vPzyy6t1rD59+qhWrVrKyckpd3tqaqqCINDDDz9creNX1caNGxUEgZo2bRqR9aytXLlSzz77rAYOHKizzz5bUVFRCoJA06ZNq9L8/v379bvf/U6XXHKJTjnlFNWsWVONGjVSjx499NZbb5V7bHFxsVq2bKnk5GQVFBQciaeDEAjKUTB9+nTt37+/0vtfeeWVCJ7N8Sc3N1c/+9nPFARBtY+1ZMkSzZ07V3feeacaNWpkcHYnrkcffVRjxoxRenq61q5dq3B2fcrJydEFF1yg0aNH66uvvtLFF1+stLQ0JScna9myZZo5c2a5x0dHR+vxxx/X5s2b9fTTT1s/FVSCoERY+/bttWPHDs2bNy/k/cuXL1d2drY6dOgQ4TM7fowaNUpbt27VbbfdVu1j3X333apZs6buu+8+gzM7sV100UV64IEHNGvWLK1bt05du3at0lxBQYGuuOIKrVmzRg8//LC++eYbZWRkKD09XR988IG2bdum8ePHV5jr16+f2rRpo6eeekrfffed9dNBCAQlwm6++WZJlV+FvPzyy+Ueh/DMmTNH06dP1z333KOOHTtW61iLFy9WVlaW0tLSVK9ePaMzPHHdd999+tWvfqW+ffvqrLPOqvLcE088oezsbA0fPlwTJkxQbGxsufvj4+PVtm3bkLM333yzCgoKNGnSpOqcOqqIoERYmzZt1L59ey1atEhbtmwpd9/333+vGTNmqFGjRrryyisPeZyioiL9+c9/Vmpqqk455RTFxcXpzDPP1O23366vv/663GOHDRtW9npCVf6Ueu211xQEgYYNG1Zh/cWLFys+Pl61a9fW0qVLy933ySefaOzYserYsaMaNmyoGjVqqEGDBurVq5eWLFkS5les6rZv367bbrtNKSkpevTRR6t9vN///veSFPL5H8rBX7f8/Hzdf//9at68ueLi4tSwYUMNHTq0wvf+YPPnz1fXrl2VkJCgk046SZdeemmlV7QHy83N1YQJE9S2bVslJCQoPj5ebdq00eOPP669e/eWe+xzzz2nIAh09tlna8+ePRWONXnyZAVBoMaNG2v79u1hPX9LBw4c0MSJEyVJY8aMCXt+8ODBiomJ0UsvvaSioiLr08OPOUREcnKyk+Tef/9998c//tFJco8//ni5x7z88stOknvwwQfdhg0bnCQXHR1d4Vi7d+92qampTpKrU6eO69q1q+vXr59LSUlxkly9evXcypUryx4/efJkN3To0HJ/GjRo4CS5n/zkJxXuK/Xqq686SeVuc865RYsWuVq1arn4+Hi3dOnSCufXvXt3FxUV5dq0aeOuvvpq179/f9euXTsnyUlyL7zwQsivUdeuXZ0kN2HChKp/YQ/Sr18/FxUV5f7xj3+UO//u3buHfayCggIXFxfnYmNj3d69e8M639J109LS3HnnnecSExNdr169XO/evV1SUpKT5JKTk11eXl6FYz7//PNlX6eOHTu6gQMHuvbt2ztJ7p577imb/bEvv/zSNW7c2Elyp512mrvqqqtcr169yr7Pbdu2rbDetdde6yS5n/70p+Vu//zzz13NmjVdTEyM++CDD8rdV/pzKclt2LDh8F/IQyj9+r3++uuVPuajjz5yktzpp5/unHNu9erV7uGHH3bDhw9348aNc/Pnz3fFxcWHXKf067d8+fJqnS8Oj6BEyMFBycvLc7Vq1XLNmzcv95iLL77YBUHg1q1bd8igDBo0yEly11xzjdu6dWu5+37zm984Sa5FixauqKio0vMp/Y/573//e6WPCRWUg2NS2WxmZqb75ptvKty+fPlyV7duXRcbG+tycnIqPSefoLz55ptOkhs9enSF8/cJypIlS5wk16FDh0ofc7iglAZ7165dZfft3LnTtW3b1klyv/71r8vNrVq1ykVHR7uoqCg3c+bMcvdNmzbNBUEQMih79+51zZo1c5Lc+PHj3b59+8ruy8/PdwMHDnSS3E033VRuLjc31zVt2tRJchMnTnTO/fDLSosWLZwk98wzz1R4zpEOyqRJk8riOm7cuLKvwcF/LrjgArdp06ZKj3HXXXc5Se6xxx6r1vni8AhKhBwcFOecGzx4sJPk3n33Xeecc9nZ2U6SS01Ndc65SoOyZs0aFwSBO/30093u3btDrnX11Vc7SS4jI6PS8/EJSmlMateuXXbe4br//vudJPeHP/yhwn1DhgxxKSkp7sUXXwzrmN9++6075ZRTXLNmzVx+fn6F8/cJyjPPPOMkuRtvvLHSxxwuKLVr1w4Z1vT0dCfJXXbZZeVuv/XWW50kd/3114dcr3fv3iGDMnHixLJfMELZs2ePS0pKcjExMW7nzp3l7vvkk09cjRo1XFxcnPvnP//pBgwY4CS5Xr16uZKSkgrHysnJcSkpKS4lJSXkLwXhqEpQnnjiCSfJxcbGOklu5MiR7quvvnK7du1yixcvdmeffbaT5M4991y3f//+kMd46aWXnCTXp0+fap0vDo8PORwlN998s6ZPn65XXnlFXbt2LXuR/nAvxmdmZso5px49eighISHkY1JTU5WZmanly5frmmuuMTnfxYsXq3fv3iooKNCcOXMO+w6dHTt2aMGCBcrKylJubq4OHDggSVq7dq0k6auvvqowM3XqVK9zGz58uHJzczV79mzFx8d7HePHtm7dKknVejG+ffv2Ou200yrcfs4550hShddR3n33XUnSDTfcEPJ4Q4cODflayoIFCyRJ119/fci5OnXqqH379srMzNSnn35a7vW5Dh066Nlnn9Vdd92l1NRU7dq1S8nJyZoyZUrIt12fccYZys7ODrnOkeD+/63FBw4c0MCBA8te15Kkyy+/XIsXL1ZKSoqysrKUnp6uIUOGVDhG6few9HuKI4egHCXdunXTmWeeqVmzZumFF17Q1KlTVbduXfXr1++Qc+vXr5f0w7vBSt8RVplt27aZnOuKFSs0Y8aMsg+ITZs2TWlpaZU+fvLkybr77ruVn59f6WN2795tcm5TpkxRRkaGbr/9dqWmppocU5J27dolSapbt673MZo0aRLy9tJjFhYWlru99IOTZ555Zsi5ym4v/ZkYMmRIyP+hHizUz8SoUaM0f/58LVq0SEEQKD09XSeffPIhjxMpB//SNGLEiAr3N2nSRD179tTs2bO1ZMmSkM+/9Oudm5t75E4UkgjKUVP6LqAJEyZo6NCh+u677zR8+HDVqlXrkHMlJSWSpLZt2+r8888/5GM7depkcq5ffvml4uPjlZmZqfvuu0+zZ8/Wyy+/rFtuuaXCYz/77DONGDFC0dHReuqpp9SrVy81adJE8fHxCoJAkyZN0ogRI8L6UNuhzJkzR5L06aefVghK6WcPPvvss7L70tPT1bBhw8MeNzExUVL1whcVFZk3UZb+TFx11VVq0KDBIR+bnJxc4ba1a9fqww8/lPTDFcEnn3yiiy66yP5EPRz89uLK3mpcevu3334b8v7SXw7+VyJ5PCMoR9GwYcP0yCOPKCMjQ1LVPnvSuHFjSdLFF19c7vL/SIqPj1dGRoYuu+wyJScn68ILL9To0aPVpUsXtWjRotxjZ86cKeecRo0apbFjx1Y4VulfeVlbsWJFpffl5eXpvffek1TxqqAySUlJkn74q7tIOeOMM7Ru3Tpt3LhRrVu3rnD/xo0bQ841btxY2dnZuuWWWw57hftjhYWFGjBggPbs2aPBgwdr1qxZGjNmjDp37qz27dv7PA1T7dq1UxAEcs5p+/btZT//Byt9W3OdOnVCHqP0e3i42KL6+BzKUdSkSRP17t1b9erV00UXXVSlK4oePXpIkt55550q/8+xuvr166fLLrtMktSqVSs9++yzys/P16BBg8peGym1c+dOSaF/Ey4sLNTs2bNNz23u3LlyP7y5pMKfV199VZLUvXv3stuqug9Wu3btJElr1qwxPd9DKX1davr06SHvr+w1ptKfiRkzZoS95ujRo/X555+rW7dumjp1qp577jnt379fAwYMUF5eXtjHs9awYUNdcsklkhTyM0wHDhwo+2Whsg+yZmVlSZIuvPDCI3SWKEVQjrK3335b27dvL/srh8O54IIL1LdvX3399de67rrrQv7Wmp+fr+nTp5u9CPnjF2dHjhypnj17asWKFfrlL39Z7r7SF5ynTJlS7gNzhYWFuuOOO7Rhw4ZK17nxxhvVsmXLiF15HUrnzp0VFxenVatWRWxzwVGjRik6OlozZswo+6u8Uunp6Zo7d27IueHDhys5OVkzZ87UuHHjQn5Q8bvvvtPkyZPL3fbGG29o0qRJatCggd544w1FRUVp5MiR6tevnzZs2BDyinnLli1q2bKlWrZsecgPZ1qaMGGCpB8+Mf/RRx+V3V5UVKRf/OIXWr9+vRISEnTTTTeFnF++fLkklf1ShCPoaLy17ET047cNH87hPtjYvXt3J8nVqFHDdejQwQ0YMMD179/fdejQwdWoUcNJcv/6178qPb7v51BKbd261TVo0MBFRUWVewtxbm5u2XOtV6+eS0tLc3379nVJSUkuISHBjR49utJjVveDjZWdv8/bhp377wf/MjMzQ95/uLcNh3qOzv33exvqA4pPP/102ecrOnXq5AYNGuQ6dOjgJLm777670rmsrKyyz5QkJia6Ll26uEGDBrm0tDTXqlUrFwSBa9CgQdnjs7OzXZ06dVxUVJT729/+Vu5YeXl57qyzzgr5IdTqfA5l/vz5rlOnTmV/EhISnCTXrFmzcreH8thjjzlJLiYmxnXu3Nldd911Zc+3Vq1abv78+SHn/vOf/7iYmBh3+umnuwMHDoR1vggfVyjHoISEBC1atEhvvPGGLr/8cm3evFlz5szR0qVLVVBQoMGDB2vOnDlq1qzZETuHpKQkvfbaa3LOaciQIWXvoElMTNSKFSt0xx13KDExUQsXLtSHH36oK6+8UitXrqx0z6X/RXfeeaekH7ZSiZQxY8Zo3rx5uuSSS5SVlaV33nlHsbGxmjVrlu66665K51q3bq3Vq1fr6aef1jnnnKPVq1dr5syZ+vjjj1W7dm3de++9ZVc9BQUF6t+/v77//ns99NBDFX5zP+mkkzRjxgzFxcVp7Nix+vTTT02e27Zt2/Txxx+X/Sm9klq3bl2520MZP368/vrXv+qKK65Qdna2MjIyVFxcrGHDhmnlypXq2bNnyLlp06apqKhII0aM4J+CiIDAOaO32wDHGeeczjvvPK1du1Y5OTmqX7/+0T4lhME5p/PPP1///ve/tX79+iq9uw/VwxUKUIkgCPT8889r3759evLJJ4/26SBMs2bN0hdffKFx48YRkwjhCgU4jD59+ugvf/mL1q5dyz+ydYwoLi5W69atVVBQoOzs7MN+vgs2CAoAwAR/5QUAMEFQAAAmCAoAwARBAQCYqPInfUL92wgAgBNDVd6/xRUKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATFR5t2EAsBAdHe01V1xcbHwmsMYVCgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggs0hAXipX7++19zgwYO95ubNmxf2zKZNm7zWcs55zZ3ouEIBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACXYbBk5wMTF+/xsYOnSo19yTTz7pNderV6+wZ3zPccuWLV5zJzquUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCC3YaB40gQBGHPNG3a1GutLl26eM3t37/fa+79998Pe2bHjh1ea8EPVygAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwwW7DwHHk1FNPDXvmwQcf9Frriiuu8Jr75ptvvOYyMjLCniksLPRaC364QgEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATLA5JEKKi4sLe6Z+/fpea23fvt1rbt++fV5zx4Lo6GivuUsvvTTsmR49enit5fMzIknLli3zmtu8ebPXHCKHKxQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYCJxzrkoPDIIjfS44ApKSkrzmnnrqqbBnrr32Wq+15s2b5zV36623es2VlJR4zUXS+eef7zU3ZcqUsGfatGnjtdYXX3zhNde5c2evub1793rNwUZVUsEVCgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEzEHO0TQNX47vbco0ePiM0lJiZ6rXXuued6zR0LO2DHxcV5zd17771ecykpKWHPFBcXe621ePFirzl2DT5+cYUCADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAE+w2fIw49dRTveYefPDBiK3nu2vte++95zXnu14knXHGGV5zaWlpXnM+uxuvX7/ea61ly5Z5zeH4xRUKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCzSEjrGnTpl5zjz32mNdc8+bNveZKSkrCnlmwYIHXWhMnTvSaiyTf79sjjzziNVe7dm2vufz8/LBnfL/+Cxcu9JrD8YsrFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJhgt+FqiIkJ/8vXp08fr7XS0tK85nzt3r077Jn58+d7rbVlyxavOV8+37e+fft6reX7ffPZ7VmS3n333bBnfL9vzjmvORy/uEIBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACXYbjrC6det6zcXFxRmfyaHFx8eHPeO7s25iYqLXnO8uuQkJCWHPdOnSxWutWrVqec358nluAwcO9Fpr1apVXnPvv/++19yOHTvCnmFH5MjiCgUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmAlfF7TiDIDjS53LMiYkJf7Pm+++/32utBx54wGsukrsUl5SUeM3t27fPay4nJ8drLjY2NuyZ0047zWutSO8S7fM9KC4u9lorLy/Pa+7tt9/2mhs/fnzYMz47FCO0qqSCKxQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwET4uxuiTFFRUdgzEydO9Fpr8+bNXnMPPfSQ15zPZog+my5K/hsotmjRwmvuWODzsyVVbQO/H/Pd5PHrr7/2mlu4cKHX3Pfff+81h8jhCgUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmAlfF7UmDIDjS54JD8N2Rt3Xr1l5zXbt2DXsmMTHRay3fc+zWrZvX3Mknn+w152Pjxo1ec6+//rrXnM8uxVlZWV5rrVq1ymtu06ZNXnPFxcVec7BRlVRwhQIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAAT7DaMkGJiYsKe8f0ZqVevntfcm2++6TXXpUuXsGd8d7r97W9/6zX3wAMPeM1V8T/nckpKSrzW8p3DsYndhgEAEUNQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAAT4W8pixNCUVFRxNZKSkrymuvUqZPXnM+OvAsXLvRaa+LEiV5zBw4c8JoDjiauUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAE2wOCTPx8fFec9ddd53XXM2aNb3m8vPzw55ZtmyZ11qbN2/2mgOORVyhAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwAS7DSOkqKjwf9fo3r2711o33HCD11xxcbHX3JIlS8KemTNnjtdaRUVFXnPAsYgrFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJhgt2GEFB0dHfZMu3btvNZq3Lix19zOnTu95qZPnx72zKZNm7zWAk4kXKEAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABLsNH+eCIPCaa9WqVdgzffr08VorNjbWa27Lli1ec6tXrw57pk6dOl5r5efne80VFRV5zQFHE1coAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJNoc8ztWrV89r7uc//3nYMykpKV5r+SopKfGau+mmm8KeKSws9Fpr0qRJXnPffvut1xxwNHGFAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABPsNnycS0xM9Jrr2LFj2DNxcXFeawVB4DXXtm1br7lzzjkn7JkXX3zRa628vDyvOeBYxBUKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATLDb8HEuJyfHa+6tt94Ke+bee+/1WqtmzZpeczt37vSaW7hwYdgzf/rTn7zWKigo8JoDjkVcoQAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEuw0f5woLC73mpk6dGvbMnj17vNaqW7eu19zq1au95j744IOwZ7Zt2+a1FnAi4QoFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADAROOdclR4YBEf6XHCMi4mJ7F6jJSUlEZ0DTmRVSQVXKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAmCAoAwARBAQCYICgAABMEBQBggqAAAEwQFACACYICADDBbsMAgMNit2EAQMQQFACACYICADBBUAAAJggKAMAEQQEAmCAoAAATBAUAYIKgAABMEBQAgAmCAgAwQVAAACYICgDARExVH1jFTYkBACcorlAAACYICgDABEEBAJggKAAAEwQFAGCCoAAATBAUAIAJggIAMEFQAAAm/g+EF0MJ5N2MfgAAAABJRU5ErkJggg==",
|
||
"text/plain": [
|
||
"<Figure size 500x500 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import random\n",
|
||
"\n",
|
||
"def visualize_test_sample(X_test, y_test, index, whitelist):\n",
|
||
" \"\"\"\n",
|
||
" Визуализация тестового примера с его меткой.\n",
|
||
" \n",
|
||
" :param X_test: Тензор с тестовыми изображениями.\n",
|
||
" :param y_test: Тензор с метками тестовых изображений.\n",
|
||
" :param index: Индекс примера для отображения.\n",
|
||
" :param whitelist: Список символов для сопоставления меток.\n",
|
||
" \"\"\"\n",
|
||
" image = X_test[index] # Убираем канал и переводим в numpy\n",
|
||
" label_idx = y_test[index].item() # Получаем индекс метки\n",
|
||
" label_char = whitelist[label_idx] # Соответствующий символ из whitelist\n",
|
||
"\n",
|
||
" # Отображение изображения с меткой\n",
|
||
" plt.figure(figsize=(5, 5))\n",
|
||
" plt.imshow(image, cmap='gray')\n",
|
||
" plt.title(f\"Метка: {label_char} (Index: {label_idx})\", fontsize=16)\n",
|
||
" plt.axis('off')\n",
|
||
" plt.show()\n",
|
||
"\n",
|
||
"\n",
|
||
"# Пример вызова\n",
|
||
"visualize_test_sample(X_test, y_test, index=random.choice([i for i in range(len(X_test))]), whitelist=whitelist)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 216,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"10000"
|
||
]
|
||
},
|
||
"execution_count": 216,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"len(X_test_all[2])"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.13"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|