683 lines
30 KiB
Plaintext
683 lines
30 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 316,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"18\n",
|
|||
|
"{'A': 0, 'B': 1, 'C': 2, 'E': 3, 'H': 4, 'K': 5, 'M': 6, 'O': 7, 'P': 8, 'T': 9, 'X': 10, 'y': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21}\n",
|
|||
|
"697932\n",
|
|||
|
"433256\n",
|
|||
|
"433256\n",
|
|||
|
"433256\n",
|
|||
|
"{'A': 0, 'B': 1, 'C': 2, 'E': 3, 'H': 4, 'K': 5, 'M': 6, 'O': 7, 'P': 8, 'T': 9, 'X': 10, 'y': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21}\n",
|
|||
|
"116323\n",
|
|||
|
"72548\n",
|
|||
|
"72548\n",
|
|||
|
"72548\n",
|
|||
|
"Размер обучающего набора: (433256, 28, 28), Метки: (433256,)\n",
|
|||
|
"Размер тестового набора: (72548, 28, 28), Метки: (72548,)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"import numpy as np\n",
"import torch\n",
"from emnist import extract_training_samples, extract_test_samples\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from torchvision import transforms\n",
|
|||
|
"\n",
|
|||
"# Whitelist of characters to recognise; filter_data uses a character's position in this string as its class index\n",
"whitelist = 'ABCEHKMOPTXy0123456789'\n",
"\n",
"# All EMNIST classes: filter_data looks a character up as emnist_classes[label]\n",
"emnist_classes = (\n",
" \"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n",
" \"abcdefghijklmnopqrstuvwxyz\"\n",
")\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
"# Restrict a dataset to the whitelisted characters\n",
"def filter_data(X, y, whitelist, emnist_classes, verbose=True):\n",
"    \"\"\"\n",
"    Keep only the samples whose class is in ``whitelist`` and re-index their labels.\n",
"\n",
"    Args:\n",
"        X: NumPy array of samples, indexed along axis 0 in step with ``y``.\n",
"        y: integer labels; label ``i`` denotes the character ``emnist_classes[i]``.\n",
"        whitelist: characters to keep; a character's position in this sequence\n",
"            becomes its new label.\n",
"        emnist_classes: sequence mapping an original label to its character.\n",
"        verbose: when True (the default) print the mapping and the sample counts\n",
"            before and after filtering.\n",
"\n",
"    Returns:\n",
"        Tuple ``(X_filtered, y_filtered)``: the kept samples and an int64 array\n",
"        with their whitelist indices, both in the original sample order.\n",
"    \"\"\"\n",
"    y = np.asarray(y)\n",
"    if y.size == 0:\n",
"        # An empty list arrives as float64, which cannot be used as an index.\n",
"        y = y.astype(np.int64)\n",
"\n",
"    # Character -> position in whitelist (the new label space).\n",
"    whitelist_mapping = {char: idx for idx, char in enumerate(whitelist)}\n",
"\n",
"    # Lookup table from original label to new label; -1 marks a dropped class.\n",
"    # Indexing it with the whole label array replaces two Python-level passes\n",
"    # over every sample with a single vectorised lookup.\n",
"    label_lookup = np.full(len(emnist_classes), -1, dtype=np.int64)\n",
"    for original_label, char in enumerate(emnist_classes):\n",
"        if char in whitelist_mapping:\n",
"            label_lookup[original_label] = whitelist_mapping[char]\n",
"\n",
"    remapped = label_lookup[y]\n",
"    keep = remapped >= 0\n",
"\n",
"    X_filtered = X[keep]\n",
"    # Stays int64 even when nothing matches (np.array([]) would be float64).\n",
"    y_filtered = remapped[keep]\n",
"\n",
"    if verbose:\n",
"        print(whitelist_mapping)\n",
"        print(len(y))\n",
"        print(int(keep.sum()))\n",
"        print(len(X_filtered))\n",
"        print(len(y_filtered))\n",
"    return X_filtered, y_filtered\n",
|
|||
|
"\n",
|
|||
"# Load the data from the 'byclass' subset\n",
"X_train_byclass, y_train_byclass = extract_training_samples('byclass')\n",
"X_test_byclass, y_test_byclass = extract_test_samples('byclass')\n",
"# Debug output: raw label of the first test sample\n",
"print(y_test_byclass[0])\n",
"# Filter the data: keep whitelisted characters and remap labels to whitelist indices\n",
"X_train, y_train = filter_data(X_train_byclass, y_train_byclass, whitelist, emnist_classes)\n",
"X_test, y_test = filter_data(X_test_byclass, y_test_byclass, whitelist, emnist_classes)\n",
"\n",
"# Check the sizes of the resulting datasets\n",
"print(f\"Размер обучающего набора: {X_train.shape}, Метки: {y_train.shape}\")\n",
"print(f\"Размер тестового набора: {X_test.shape}, Метки: {y_test.shape}\")\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 317,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
"class EMNISTDataset(Dataset):\n",
"    \"\"\"\n",
"    In-memory dataset of single-channel images with integer class labels.\n",
"\n",
"    Args:\n",
"        X: array-like of images; a channel axis is inserted at position 1, so an\n",
"            (N, H, W) input yields items of shape (1, H, W).\n",
"        y: array-like of integer labels, one per image.\n",
"        transform: optional callable applied to each image tensor on access.\n",
"\n",
"    Images stored as uint8 are rescaled from [0, 255] to [0, 1]; any other dtype\n",
"    is only cast to float32. Without the rescale, Normalize((0.5,), (0.5,))\n",
"    would map raw pixels to roughly [-1, 509] instead of [-1, 1].\n",
"    NOTE(review): weights trained on the unscaled pipeline need retraining.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, X, y, transform=None):\n",
"        images = torch.tensor(X)\n",
"        if images.dtype == torch.uint8:\n",
"            # Raw 8-bit pixels: cast, then scale in place to avoid a second float copy.\n",
"            images = images.to(torch.float32).div_(255.0)\n",
"        else:\n",
"            images = images.to(torch.float32)\n",
"        self.X = images.unsqueeze(1)  # add the channel axis\n",
"        self.y = torch.tensor(y, dtype=torch.long)\n",
"        self.transform = transform\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of stored images.\"\"\"\n",
"        return len(self.X)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Return ``(image, label)`` for ``idx``, with ``transform`` applied to the image if set.\"\"\"\n",
"        image = self.X[idx]\n",
"        label = self.y[idx]\n",
"        if self.transform is not None:\n",
"            image = self.transform(image)\n",
"        return image, label\n",
|
|||
|
"\n",
|
|||
"# Normalisation applied to each image on access: (x - 0.5) / 0.5\n",
"transform = transforms.Compose([\n",
" transforms.Normalize((0.5,), (0.5,)) # Maps inputs in [0, 1] to [-1, 1]. NOTE(review): confirm the dataset really yields values in [0, 1]\n",
"])\n",
"\n",
"# Build the datasets\n",
"train_dataset = EMNISTDataset(X_train, y_train, transform=transform)\n",
"test_dataset = EMNISTDataset(X_test, y_test, transform=transform)\n",
"\n",
"# DataLoaders for training and testing\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 318,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import torch.nn as nn\n",
|
|||
|
"import torch.nn.functional as F\n",
|
|||
|
"\n",
|
|||
"def compute_output_size(input_size, kernel_size, stride, padding):\n",
"    \"\"\"\n",
"    Return the output length along one axis of a convolution or pooling window.\n",
"\n",
"    Computes (input_size - kernel_size + 2 * padding) // stride + 1.\n",
"    \"\"\"\n",
"    padded_extent = input_size + 2 * padding\n",
"    window_steps = (padded_extent - kernel_size) // stride\n",
"    return window_steps + 1\n",
|
|||
|
"\n",
|
|||
"class CNN(nn.Module):\n",
"    \"\"\"\n",
"    Two 3x3 convolutions, one 2x2 max-pool and a two-layer classifier head.\n",
"\n",
"    Args:\n",
"        num_classes: number of output logits.\n",
"        input_size: side length of the square, single-channel input images.\n",
"            Defaults to 28, the value that used to be hard-coded.\n",
"\n",
"    ``forward`` returns the raw output of the last linear layer (no softmax).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_classes, input_size=28):\n",
"        super().__init__()\n",
"        # Convolutional layers\n",
"        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)\n",
"        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)\n",
"        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"\n",
"        # Follow the feature-map side length through conv1 -> conv2 -> pool so\n",
"        # that fc1 matches the flattened feature size.\n",
"        height = compute_output_size(input_size, kernel_size=3, stride=1, padding=1)  # conv1\n",
"        height = compute_output_size(height, kernel_size=3, stride=1, padding=1)  # conv2\n",
"        height = compute_output_size(height, kernel_size=2, stride=2, padding=0)  # pool\n",
"        width = height  # square input\n",
"\n",
"        # Fully connected layers\n",
"        self.fc1 = nn.Linear(64 * height * width, 128)\n",
"        self.fc2 = nn.Linear(128, num_classes)\n",
"        self.dropout = nn.Dropout(0.5)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map a batch of images to per-class logits.\"\"\"\n",
"        x = F.relu(self.conv1(x))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        # Flatten everything except the batch axis; unlike view(), flatten()\n",
"        # also accepts non-contiguous tensors.\n",
"        x = x.flatten(1)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.dropout(x)\n",
"        x = self.fc2(x)\n",
"        return x\n",
|
|||
|
"\n",
|
|||
|
" \n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
"# Model initialisation: one output per whitelisted character\n",
"num_classes = len(whitelist)\n",
"model = CNN(num_classes)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 319,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
"import torch\n",
"import torch.optim as optim\n",
"from tqdm import tqdm\n",
"\n",
"# Device (GPU if available, otherwise CPU)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device)\n",
"\n",
"# Loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
|
|||
|
"\n",
|
|||
"# Training pass with a tqdm progress bar\n",
"def train(model, loader, optimizer, criterion, device):\n",
"    \"\"\"\n",
"    Run one training pass over ``loader``, updating ``model`` through ``optimizer``.\n",
"\n",
"    Args:\n",
"        model: module to train; switched to training mode.\n",
"        loader: iterable yielding ``(images, labels)`` batches.\n",
"        optimizer: optimizer stepped once per batch.\n",
"        criterion: callable ``criterion(outputs, labels)`` returning the batch loss.\n",
"        device: device the batches are moved to before the forward pass.\n",
"\n",
"    Returns:\n",
"        ``(mean_batch_loss, accuracy_percent)``: the loss averaged over the batches\n",
"        seen, and the percentage of correctly classified samples among the samples\n",
"        seen. Both are NaN when the loader yields no samples.\n",
"    \"\"\"\n",
"    model.train()\n",
"    total_loss = 0.0\n",
"    correct = 0\n",
"    num_batches = 0\n",
"    num_samples = 0\n",
"    loop = tqdm(loader, desc=\"Обучение\", leave=False)  # progress bar for training\n",
"    for images, labels in loop:\n",
"        images, labels = images.to(device), labels.to(device)\n",
"\n",
"        # Reset gradients accumulated by the previous step\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Forward pass\n",
"        outputs = model(images)\n",
"        loss = criterion(outputs, labels)\n",
"\n",
"        # Backward pass and weight update\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # Bookkeeping: read the scalar loss once and reuse it below\n",
"        batch_loss = loss.item()\n",
"        total_loss += batch_loss\n",
"        _, predicted = torch.max(outputs, 1)\n",
"        correct += (predicted == labels).sum().item()\n",
"        num_batches += 1\n",
"        # Count what was actually seen: len(loader.dataset) is wrong when the\n",
"        # loader uses a sampler or drop_last.\n",
"        num_samples += labels.size(0)\n",
"\n",
"        loop.set_postfix(loss=batch_loss)\n",
"\n",
"    if num_samples == 0:\n",
"        return float(\"nan\"), float(\"nan\")\n",
"    accuracy = 100 * correct / num_samples\n",
"    return total_loss / num_batches, accuracy\n",
|
|||
|
"\n",
|
|||
"# Evaluation pass with a tqdm progress bar\n",
"def evaluate(model, loader, criterion, device):\n",
"    \"\"\"\n",
"    Evaluate ``model`` on ``loader`` without computing gradients.\n",
"\n",
"    Args:\n",
"        model: module to evaluate; switched to evaluation mode.\n",
"        loader: iterable yielding ``(images, labels)`` batches.\n",
"        criterion: callable ``criterion(outputs, labels)`` returning the batch loss.\n",
"        device: device the batches are moved to before the forward pass.\n",
"\n",
"    Returns:\n",
"        ``(mean_batch_loss, accuracy_percent)``: the loss averaged over the batches\n",
"        seen, and the percentage of correctly classified samples among the samples\n",
"        seen. Both are NaN when the loader yields no samples.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    total_loss = 0.0\n",
"    correct = 0\n",
"    num_batches = 0\n",
"    num_samples = 0\n",
"    loop = tqdm(loader, desc=\"Оценка\", leave=False)  # progress bar for evaluation\n",
"    with torch.no_grad():\n",
"        for images, labels in loop:\n",
"            images, labels = images.to(device), labels.to(device)\n",
"\n",
"            # Forward pass\n",
"            outputs = model(images)\n",
"            loss = criterion(outputs, labels)\n",
"\n",
"            # Bookkeeping: read the scalar loss once and reuse it below\n",
"            batch_loss = loss.item()\n",
"            total_loss += batch_loss\n",
"            _, predicted = torch.max(outputs, 1)\n",
"            correct += (predicted == labels).sum().item()\n",
"            num_batches += 1\n",
"            # Count what was actually seen: len(loader.dataset) is wrong when\n",
"            # the loader uses a sampler or drop_last.\n",
"            num_samples += labels.size(0)\n",
"\n",
"            loop.set_postfix(loss=batch_loss)\n",
"\n",
"    if num_samples == 0:\n",
"        return float(\"nan\"), float(\"nan\")\n",
"    accuracy = 100 * correct / num_samples\n",
"    return total_loss / num_batches, accuracy\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 320,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"device(type='cuda')"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 320,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Rebuilds `device` with the same expression the training-setup cell uses; the bare name below displays it\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"device"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 321,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Эпоха 1/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.4385, Train Accuracy: 86.80%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1677, Test Accuracy: 93.46%\n",
|
|||
|
"Эпоха 2/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.2405, Train Accuracy: 91.25%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1499, Test Accuracy: 93.91%\n",
|
|||
|
"Эпоха 3/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.2105, Train Accuracy: 92.21%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1608, Test Accuracy: 93.23%\n",
|
|||
|
"Эпоха 4/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.1978, Train Accuracy: 92.65%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1394, Test Accuracy: 94.31%\n",
|
|||
|
"Эпоха 5/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.1887, Train Accuracy: 92.92%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1417, Test Accuracy: 94.46%\n",
|
|||
|
"Эпоха 6/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.1841, Train Accuracy: 93.05%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1473, Test Accuracy: 94.38%\n",
|
|||
|
"Эпоха 7/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.1794, Train Accuracy: 93.18%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1418, Test Accuracy: 94.52%\n",
|
|||
|
"Эпоха 8/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.1759, Train Accuracy: 93.30%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1432, Test Accuracy: 94.42%\n",
|
|||
|
"Эпоха 9/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.1728, Train Accuracy: 93.41%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1406, Test Accuracy: 94.53%\n",
|
|||
|
"Эпоха 10/10\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Train Loss: 0.1710, Train Accuracy: 93.46%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Test Loss: 0.1397, Test Accuracy: 94.47%\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\r"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Training loop: one train pass and one test-set evaluation per epoch\n",
"num_epochs = 10\n",
"for epoch in range(num_epochs):\n",
" print(f\"Эпоха {epoch + 1}/{num_epochs}\")\n",
" \n",
" # Training\n",
" train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)\n",
" print(f\" Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%\")\n",
"\n",
" # Evaluation\n",
" test_loss, test_acc = evaluate(model, test_loader, criterion, device)\n",
" print(f\" Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%\")\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 322,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
"# Save the model's state_dict (not the module object) to a file in the working directory\n",
"torch.save(model.state_dict(), \"letter_recognition_model1.pth\")\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 337,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAGtCAYAAAAxsILFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAc7UlEQVR4nO3daXiU5dmH8f+ThUAgGAUDKhAUMAiiiCweqBBErYhIkMUCIrgUVESqFXDB4ta6W1vbUqEuIGhkETAQWqBUseKGVDDSeFBWg0pZEsCQAEnu94NvUmImkLlzMRQ4f8fBl5m5nvuZJHrmYWZuAuecEwAA1RR1tE8AAHB8ICgAABMEBQBggqAAAEwQFACACYICADBBUAAAJggKAMAEQQEAmCAoEdK0aVMFQaAgCDR69OhDPvaZZ54pe2xMTEyEzvD4kpmZWfY1vPzyy6t1rD59+qhWrVrKyckpd3tqaqqCINDDDz9creNX1caNGxUEgZo2bRqR9aytXLlSzz77rAYOHKizzz5bUVFRCoJA06ZNq9L8/v379bvf/U6XXHKJTjnlFNWsWVONGjVSjx499NZbb5V7bHFxsVq2bKnk5GQVFBQciaeDEAjKUTB9+nTt37+/0vtfeeWVCJ7N8Sc3N1c/+9nPFARBtY+1ZMkSzZ07V3feeacaNWpkcHYnrkcffVRjxoxRenq61q5dq3B2fcrJydEFF1yg0aNH66uvvtLFF1+stLQ0JScna9myZZo5c2a5x0dHR+vxxx/X5s2b9fTTT1s/FVSCoERY+/bttWPHDs2bNy/k/cuXL1d2drY6dOgQ4TM7fowaNUpbt27VbbfdVu1j3X333apZs6buu+8+gzM7sV100UV64IEHNGvWLK1bt05du3at0lxBQYGuuOIKrVmzRg8//LC++eYbZWRkKD09XR988IG2bdum8ePHV5jr16+f2rRpo6eeekrfffed9dNBCAQlwm6++WZJlV+FvPzyy+Ueh/DMmTNH06dP1z333KOOHTtW61iLFy9WVlaW0tLSVK9ePaMzPHHdd999+tWvfqW+ffvqrLPOqvLcE088oezsbA0fPlwTJkxQbGxsufvj4+PVtm3bkLM333yzCgoKNGnSpOqcOqqIoERYmzZt1L59ey1atEhbtmwpd9/333+vGTNmqFGjRrryyisPeZyioiL9+c9/Vmpqqk455RTFxcXpzDPP1O23366vv/663GOHDRtW9npCVf6Ueu211xQEgYYNG1Zh/cWLFys+Pl61a9fW0qVLy933ySefaOzYserYsaMaNmyoGjVqqEGDBurVq5eWLFkS5les6rZv367bbrtNKSkpevTRR6t9vN///veSFPL5H8rBX7f8/Hzdf//9at68ueLi4tSwYUMNHTq0wvf+YPPnz1fXrl2VkJCgk046SZdeemmlV7QHy83N1YQJE9S2bVslJCQoPj5ebdq00eOPP669e/eWe+xzzz2nIAh09tlna8+ePRWONXnyZAVBoMaNG2v79u1hPX9LBw4c0MSJEyVJY8aMCXt+8ODBiomJ0UsvvaSioiLr08OPOUREcnKyk+Tef/9998c//tFJco8//ni5x7z88stOknvwwQfdhg0bnCQXHR1d4Vi7d+92qampTpKrU6eO69q1q+vXr59LSUlxkly9evXcypUryx4/efJkN3To0HJ/GjRo4CS5n/zkJxXuK/Xqq686SeVuc865RYsWuVq1arn4+Hi3dOnSCufXvXt3FxUV5dq0aeOuvvpq179/f9euXTsnyUlyL7zwQsivUdeuXZ0kN2HChKp/YQ/Sr18/FxUV5f7xj3+UO//u3buHfayCggIXFxfnYmNj3d69e8M639J109LS3HnnnecSExNdr169XO/evV1SUpKT5JKTk11eXl6FYz7//PNlX6eOHTu6gQMHuvbt2ztJ7p577imb/bEvv/zSNW7c2Elyp512mrvqqqtcr169yr7Pbdu2rbDetdde6y
S5n/70p+Vu//zzz13NmjVdTEyM++CDD8rdV/pzKclt2LDh8F/IQyj9+r3++uuVPuajjz5yktzpp5/unHNu9erV7uGHH3bDhw9348aNc/Pnz3fFxcWHXKf067d8+fJqnS8Oj6BEyMFBycvLc7Vq1XLNmzcv95iLL77YBUHg1q1bd8igDBo0yEly11xzjdu6dWu5+37zm984Sa5FixauqKio0vMp/Y/573//e6WPCRWUg2NS2WxmZqb75ptvKty+fPlyV7duXRcbG+tycnIqPSefoLz55ptOkhs9enSF8/cJypIlS5wk16FDh0ofc7iglAZ7165dZfft3LnTtW3b1klyv/71r8vNrVq1ykVHR7uoqCg3c+bMcvdNmzbNBUEQMih79+51zZo1c5Lc+PHj3b59+8ruy8/PdwMHDnSS3E033VRuLjc31zVt2tRJchMnTnTO/fDLSosWLZwk98wzz1R4zpEOyqRJk8riOm7cuLKvwcF/LrjgArdp06ZKj3HXXXc5Se6xxx6r1vni8AhKhBwcFOecGzx4sJPk3n33Xeecc9nZ2U6SS01Ndc65SoOyZs0aFwSBO/30093u3btDrnX11Vc7SS4jI6PS8/EJSmlMateuXXbe4br//vudJPeHP/yhwn1DhgxxKSkp7sUXXwzrmN9++6075ZRTXLNmzVx+fn6F8/cJyjPPPOMkuRtvvLHSxxwuKLVr1w4Z1vT0dCfJXXbZZeVuv/XWW50kd/3114dcr3fv3iGDMnHixLJfMELZs2ePS0pKcjExMW7nzp3l7vvkk09cjRo1XFxcnPvnP//pBgwY4CS5Xr16uZKSkgrHysnJcSkpKS4lJSXkLwXhqEpQnnjiCSfJxcbGOklu5MiR7quvvnK7du1yixcvdmeffbaT5M4991y3f//+kMd46aWXnCTXp0+fap0vDo8PORwlN998s6ZPn65XXnlFXbt2LXuR/nAvxmdmZso5px49eighISHkY1JTU5WZmanly5frmmuuMTnfxYsXq3fv3iooKNCcOXMO+w6dHTt2aMGCBcrKylJubq4OHDggSVq7dq0k6auvvqowM3XqVK9zGz58uHJzczV79mzFx8d7HePHtm7dKknVejG+ffv2Ou200yrcfs4550hShddR3n33XUnSDTfcEPJ4Q4cODflayoIFCyRJ119/fci5OnXqqH379srMzNSnn35a7vW5Dh066Nlnn9Vdd92l1NRU7dq1S8nJyZoyZUrIt12fccYZys7ODrnOkeD+/63FBw4c0MCBA8te15Kkyy+/XIsXL1ZKSoqysrKUnp6uIUOGVDhG6few9HuKI4egHCXdunXTmWeeqVmzZumFF17Q1KlTVbduXfXr1++Qc+vXr5f0w7vBSt8RVplt27aZnOuKFSs0Y8aMsg+ITZs2TWlpaZU+fvLkybr77ruVn59f6WN2795tcm5TpkxRRkaGbr/9dqWmppocU5J27dolSapbt673MZo0aRLy9tJjFhYWlru99IOTZ555Zsi5ym4v/ZkYMmRIyP+hHizUz8SoUaM0f/58LVq0SEEQKD09XSeffPIhjxMpB//SNGLEiAr3N2nSRD179tTs2bO1ZMmSkM+/9Oudm5t75E4UkgjKUVP6LqAJEyZo6NCh+u677zR8+HDVqlXrkHMlJSWSpLZt2+r8888/5GM7depkcq5ffvml4uPjlZmZqfvuu0+zZ8/Wyy+/rFtuuaXCYz/77DONGDFC0dHReuqpp9SrVy81adJE8fHxCoJAkyZN0ogRI8L6UNuhzJkzR5L06aefVghK6WcPPvvss7L70tPT1bBhw8MeNzExUVL1whcVFZk3UZb+TFx11VVq0KDBIR+bnJxc4ba1a9fqww8/lPTDFcEnn3yiiy66yP5EPRz89uLK3mpcevu3334b8v7SXw7+VyJ5PCMoR9GwYcP0yCOPKCMjQ1LVPnvSuHFjSdLFF19c7vL/SIqPj1dGRoYuu+wyJScn68ILL9To0a
PVpUsXtWjRotxjZ86cKeecRo0apbFjx1Y4VulfeVlbsWJFpffl5eXpvffek1TxqqAySUlJkn74q7tIOeOMM7Ru3Tp
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 500x500 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import random\n",
|
|||
|
"\n",
|
|||
"def visualize_test_sample(X_test, y_test, index, whitelist):\n",
"    \"\"\"\n",
"    Display one test image with its label in the title.\n",
"\n",
"    :param X_test: Test images, indexable by sample (NumPy array or torch tensor).\n",
"    :param y_test: Labels aligned with ``X_test``; each is an index into ``whitelist``.\n",
"    :param index: Index of the sample to display.\n",
"    :param whitelist: Sequence mapping a label index to its character.\n",
"    \"\"\"\n",
"    image = X_test[index]\n",
"    if hasattr(image, \"detach\"):\n",
"        # torch tensor: detach from autograd and copy to host memory for matplotlib\n",
"        image = image.detach().cpu().numpy()\n",
"    # Drop singleton axes such as a channel dimension; imshow with a colormap needs a 2-D array\n",
"    image = np.squeeze(np.asarray(image))\n",
"    label_idx = int(y_test[index])  # label index\n",
"    label_char = whitelist[label_idx]  # matching character from whitelist\n",
"\n",
"    # Show the image with its label\n",
"    plt.figure(figsize=(5, 5))\n",
"    plt.imshow(image, cmap='gray')\n",
"    plt.title(f\"Метка: {label_char} (Index: {label_idx})\", fontsize=16)\n",
"    plt.axis('off')\n",
"    plt.show()\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
"# Example call: show a randomly chosen test sample (randrange avoids building a list of every index)\n",
"visualize_test_sample(X_test, y_test, index=random.randrange(len(X_test)), whitelist=whitelist)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 216,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"10000"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 216,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# NOTE(review): X_test_all is not defined in any cell of this notebook, so this cell raises NameError on a fresh kernel.\n",
"# Its saved output comes from execution count 216, earlier than every other cell here. TODO: define X_test_all or delete this cell.\n",
"len(X_test_all[2])"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": ".venv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.9.13"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|