2466 lines
77 KiB
Plaintext
2466 lines
77 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import os\n",
|
||
"import torch\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.optim as optim\n",
|
||
"from torch.utils.data import DataLoader, Dataset, Subset\n",
|
||
"from torchvision import transforms\n",
|
||
"import random\n",
|
||
"from PIL import Image, ImageDraw, ImageFont, ImageOps, ImageFilter\n",
|
||
"import cv2\n",
|
||
"import numpy as np\n",
|
||
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"from tqdm import tqdm"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 43,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Training configuration shared by the dataset, model and training cells.\n",
|
||
"DATASET_DIR = \"dataset\"\n",
|
||
"SAVE_PATH = \"best_model_10.pth\"\n",
|
||
"BATCH_SIZE = 128\n",
|
||
"EPOCHS = 30\n",
|
||
"LEARNING_RATE = 0.01\n",
|
||
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||
"\n",
|
||
"# Seed every RNG used in this notebook (Python `random` for the synthetic\n",
|
||
"# augmentations, NumPy for pixel noise, torch for transforms, weight init and\n",
|
||
"# shuffling) so Restart & Run All is repeatable (up to non-deterministic CUDA ops).\n",
|
||
"SEED = 42\n",
|
||
"random.seed(SEED)\n",
|
||
"np.random.seed(SEED)\n",
|
||
"torch.manual_seed(SEED)\n",
|
||
"\n",
|
||
"# Character alphabet; a character's position in the string is its class index.\n",
|
||
"CLASSES = \"ABEKMHOPCTYX0123456789\"\n",
|
||
"NUM_CLASSES = len(CLASSES)\n",
|
||
"CLASS_TO_IDX = {char: idx for idx, char in enumerate(CLASSES)}\n",
|
||
"IDX_TO_CLASS = dict(enumerate(CLASSES))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Датасет сохранен в папке 'dataset'.\n",
|
||
"Конфигурационный файл создан: dataset_config.txt\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# --- Synthetic dataset generation settings ---\n",
|
||
"FONT_PATH = \"RoadNumbers2.0.ttf\"       # TrueType font used to render glyphs\n",
|
||
"CONFIG_FILE = \"dataset_config.txt\"     # index file written into DATASET_DIR\n",
|
||
"DATASET_DIR = \"dataset\"                # output folder for the final dataset\n",
|
||
"GOSZNAK_DIR = \"gosznak\"                # source images that get augmented\n",
|
||
"REAL_GOSZNAK_DIR = \"real_gosznak\"      # real images copied without augmentation\n",
|
||
"\n",
|
||
"SYMBOLS = \"ABEKMHOPCTYX0123456789\"     # characters to render\n",
|
||
"IMG_SIZE = (28, 28)                    # (width, height) of every saved image\n",
|
||
"FONT_SIZES = [26, 32, 38, 46, 54, 58]  # font sizes rendered for each symbol\n",
|
||
"SMALL_FONT_SIZES = [26, 32]            # sizes that get reduced blur limits\n",
|
||
"MAX_ROTATION = 10                      # maximum absolute rotation, degrees\n",
|
||
"AUGMENTATIONS = 240                    # passes per (symbol, size) and per gosznak image\n",
|
||
"\n",
|
||
"os.makedirs(DATASET_DIR, exist_ok=True)\n",
|
||
"\n",
|
||
"# Probe the font once so a missing file is reported up front.\n",
|
||
"# NOTE: `font` itself is not used by the functions in this cell.\n",
|
||
"try:\n",
|
||
"    font = ImageFont.truetype(FONT_PATH, FONT_SIZES[0])\n",
|
||
"except OSError:\n",
|
||
"    print(f\"Шрифт {FONT_PATH} не найден. Убедитесь, что он находится в текущей директории.\")\n",
|
||
"\n",
|
||
"def process_gosznak_images(input_dir, output_dir, apply_augmentations=True):\n",
|
||
"    \"\"\"\n",
|
||
"    Convert the images in `input_dir` to IMG_SIZE grayscale samples and save\n",
|
||
"    them under `output_dir/<symbol>/`.\n",
|
||
"\n",
|
||
"    The label is the first character of the file name (e.g. \"A_1.png\" -> \"A\").\n",
|
||
"    Files are processed in sorted order so seeded runs are reproducible.\n",
|
||
"\n",
|
||
"    :param input_dir: Folder with .png/.jpg/.jpeg images (extension case-insensitive).\n",
|
||
"    :param output_dir: Dataset root; one sub-folder per symbol is created.\n",
|
||
"    :param apply_augmentations: If True, save AUGMENTATIONS randomly augmented\n",
|
||
"        copies of every image; otherwise save the resized image once.\n",
|
||
"    :return: List of newline-terminated \"path,symbol\" config lines.\n",
|
||
"    \"\"\"\n",
|
||
"    config_lines = []\n",
|
||
"    for filename in sorted(os.listdir(input_dir)):\n",
|
||
"        filepath = os.path.join(input_dir, filename)\n",
|
||
"        if not os.path.isfile(filepath) or not filename.lower().endswith((\".png\", \".jpg\", \".jpeg\")):\n",
|
||
"            continue\n",
|
||
"\n",
|
||
"        # Close the file handle as soon as the pixels are converted.\n",
|
||
"        with Image.open(filepath) as src:\n",
|
||
"            img = src.convert(\"L\")\n",
|
||
"        img = img.resize(IMG_SIZE, resample=Image.Resampling.LANCZOS)\n",
|
||
"\n",
|
||
"        stem = os.path.splitext(filename)[0]\n",
|
||
"        symbol = stem[0]  # assumes the file name starts with the symbol\n",
|
||
"        symbol_dir = os.path.join(output_dir, symbol)\n",
|
||
"        os.makedirs(symbol_dir, exist_ok=True)\n",
|
||
"\n",
|
||
"        if not apply_augmentations:\n",
|
||
"            # Real samples are stored once, only resized.\n",
|
||
"            file_path = os.path.join(symbol_dir, filename)\n",
|
||
"            img.save(file_path)\n",
|
||
"            config_lines.append(f\"{file_path},{symbol}\\n\")\n",
|
||
"            continue\n",
|
||
"\n",
|
||
"        for i in range(AUGMENTATIONS):\n",
|
||
"            aug_img = img.copy()  # add_random_gaps draws in place\n",
|
||
"\n",
|
||
"            # Random shift (70%).\n",
|
||
"            if random.random() < 0.7:\n",
|
||
"                aug_img = shift_image(aug_img, max_shift=4)\n",
|
||
"\n",
|
||
"            # Random rotation (always applied).\n",
|
||
"            rotation = random.uniform(-MAX_ROTATION, MAX_ROTATION)\n",
|
||
"            aug_img = aug_img.rotate(rotation, expand=False, fillcolor=255)\n",
|
||
"\n",
|
||
"            # Random expand_characters pass (50%).\n",
|
||
"            if random.random() < 0.5:\n",
|
||
"                kernel_size = (random.randint(1, 2), random.randint(1, 2))\n",
|
||
"                aug_img = expand_characters(aug_img, kernel_size=kernel_size, iterations=1)\n",
|
||
"\n",
|
||
"            # Random gaps in the strokes (15%).\n",
|
||
"            if random.random() < 0.15:\n",
|
||
"                aug_img = add_random_gaps(aug_img, num_gaps=random.randint(1, 5), gap_size=random.randint(2, 4))\n",
|
||
"\n",
|
||
"            # Random blur with apply_blur's default limits (85%).\n",
|
||
"            if random.random() < 0.85:\n",
|
||
"                aug_img = apply_blur(aug_img)\n",
|
||
"\n",
|
||
"            # Random Gaussian noise (10%).\n",
|
||
"            if random.random() < 0.1:\n",
|
||
"                aug_img = add_noise(aug_img, intensity=random.randint(10, 20))\n",
|
||
"\n",
|
||
"            # Name outputs after the source file, not only the symbol: with\n",
|
||
"            # \"{symbol}_aug_{i}.png\" every source image of the same symbol\n",
|
||
"            # overwrote the previous one's files and duplicated config lines.\n",
|
||
"            file_path = os.path.join(symbol_dir, f\"{stem}_aug_{i}.png\")\n",
|
||
"            aug_img.save(file_path)\n",
|
||
"            config_lines.append(f\"{file_path},{symbol}\\n\")\n",
|
||
"    return config_lines\n",
|
||
"\n",
|
||
"def expand_characters(img, kernel_size=(2, 2), iterations=1):\n",
|
||
"    \"\"\"\n",
|
||
"    Thicken dark glyph strokes on a light background.\n",
|
||
"\n",
|
||
"    :param img: Grayscale PIL image, expected to be dark text on a light background.\n",
|
||
"    :param kernel_size: (width, height) of the rectangular structuring element.\n",
|
||
"    :param iterations: How many times the morphology is applied.\n",
|
||
"    :return: New PIL image with thicker strokes.\n",
|
||
"    \"\"\"\n",
|
||
"    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)\n",
|
||
"    img_array = np.array(img)\n",
|
||
"    # Erosion is a local minimum filter, so it grows the dark (low-value) text.\n",
|
||
"    # cv2.dilate (local maximum) grows the white background instead and thins\n",
|
||
"    # the characters, the opposite of what this function is meant to do.\n",
|
||
"    expanded = cv2.erode(img_array, kernel, iterations=iterations)\n",
|
||
"    return Image.fromarray(expanded)\n",
|
||
"\n",
|
||
"def add_random_gaps(img, num_gaps=5, gap_size=5):\n",
|
||
"    \"\"\"\n",
|
||
"    Paint `num_gaps` white gap_size x gap_size squares at random positions to\n",
|
||
"    simulate broken or worn strokes.\n",
|
||
"\n",
|
||
"    The image is drawn on in place and also returned.\n",
|
||
"\n",
|
||
"    :param img: Grayscale (\"L\") PIL image.\n",
|
||
"    :param num_gaps: Number of squares to draw.\n",
|
||
"    :param gap_size: Side length of each square in pixels; <= 0 draws nothing.\n",
|
||
"    :return: The same image object.\n",
|
||
"    \"\"\"\n",
|
||
"    if gap_size <= 0:\n",
|
||
"        return img\n",
|
||
"    # Use the image's own size instead of the global IMG_SIZE, and keep the\n",
|
||
"    # square inside the image even when the image is smaller than gap_size.\n",
|
||
"    width, height = img.size\n",
|
||
"    draw = ImageDraw.Draw(img)\n",
|
||
"    for _ in range(num_gaps):\n",
|
||
"        x1 = random.randint(0, max(0, width - gap_size))\n",
|
||
"        y1 = random.randint(0, max(0, height - gap_size))\n",
|
||
"        # PIL rectangles include both corner pixels, so the far corner is\n",
|
||
"        # gap_size - 1 away to get exactly gap_size pixels per side.\n",
|
||
"        draw.rectangle([x1, y1, x1 + gap_size - 1, y1 + gap_size - 1], fill=255)\n",
|
||
"    return img\n",
|
||
"\n",
|
||
"def apply_blur(img, blur_min_limits=None, blur_max_limits=None):\n",
|
||
"    \"\"\"\n",
|
||
"    Apply one randomly chosen blur: box, Gaussian or horizontal motion blur.\n",
|
||
"\n",
|
||
"    :param img: PIL image object.\n",
|
||
"    :param blur_min_limits: Lower bounds per blur type, e.g.\n",
|
||
"        {\"box\": 1, \"gaussian\": 1, \"motion\": 4}. Missing keys use these defaults.\n",
|
||
"    :param blur_max_limits: Upper bounds per blur type, e.g.\n",
|
||
"        {\"box\": 4, \"gaussian\": 3, \"motion\": 8}. Missing keys use these defaults.\n",
|
||
"    :return: Blurred image.\n",
|
||
"    \"\"\"\n",
|
||
"    min_limits = {\"box\": 1, \"gaussian\": 1, \"motion\": 4}\n",
|
||
"    max_limits = {\"box\": 4, \"gaussian\": 3, \"motion\": 8}\n",
|
||
"    # Merge over the defaults (on local copies) so callers may override a\n",
|
||
"    # single blur type without triggering a KeyError for the others.\n",
|
||
"    if blur_min_limits:\n",
|
||
"        min_limits.update(blur_min_limits)\n",
|
||
"    if blur_max_limits:\n",
|
||
"        max_limits.update(blur_max_limits)\n",
|
||
"\n",
|
||
"    blur_type = random.choice([\"box\", \"gaussian\", \"motion\"])\n",
|
||
"\n",
|
||
"    if blur_type == \"box\":\n",
|
||
"        radius = random.randint(min_limits[\"box\"], max_limits[\"box\"])\n",
|
||
"        return img.filter(ImageFilter.BoxBlur(radius))\n",
|
||
"\n",
|
||
"    if blur_type == \"gaussian\":\n",
|
||
"        radius = random.uniform(min_limits[\"gaussian\"], max_limits[\"gaussian\"])\n",
|
||
"        return img.filter(ImageFilter.GaussianBlur(radius))\n",
|
||
"\n",
|
||
"    # Motion blur: averaging kernel whose middle row is all ones (horizontal streak).\n",
|
||
"    kernel_size = random.randint(min_limits[\"motion\"], max_limits[\"motion\"])\n",
|
||
"    kernel_motion_blur = np.zeros((kernel_size, kernel_size))\n",
|
||
"    kernel_motion_blur[(kernel_size - 1) // 2, :] = 1\n",
|
||
"    kernel_motion_blur /= kernel_size\n",
|
||
"    return Image.fromarray(cv2.filter2D(np.array(img), -1, kernel_motion_blur))\n",
|
||
"\n",
|
||
"\n",
|
||
"def add_noise(img, intensity=30):\n",
|
||
"    \"\"\"\n",
|
||
"    Add zero-mean Gaussian noise with standard deviation `intensity`.\n",
|
||
"\n",
|
||
"    The noise is truncated to integers and the result is clipped back into the\n",
|
||
"    0..255 uint8 range.\n",
|
||
"    \"\"\"\n",
|
||
"    pixels = np.array(img)\n",
|
||
"    gaussian = np.random.normal(0, intensity, pixels.shape).astype(np.int32)\n",
|
||
"    return Image.fromarray(np.clip(pixels + gaussian, 0, 255).astype(np.uint8))\n",
|
||
"\n",
|
||
"def shift_image(img, max_shift=4):\n",
|
||
"    \"\"\"\n",
|
||
"    Translate the image by a random integer offset in [-max_shift, max_shift]\n",
|
||
"    along each axis; uncovered pixels are filled with 255.\n",
|
||
"\n",
|
||
"    :param img: PIL image object.\n",
|
||
"    :param max_shift: Maximum offset in pixels per axis.\n",
|
||
"    :return: Shifted image of the same size.\n",
|
||
"    \"\"\"\n",
|
||
"    dx, dy = (random.randint(-max_shift, max_shift) for _ in range(2))\n",
|
||
"    # Affine data (a, b, c, d, e, f) samples the input at (a*x + b*y + c,\n",
|
||
"    # d*x + e*y + f) for each output pixel (x, y): a pure translation here.\n",
|
||
"    affine = (1, 0, dx, 0, 1, dy)\n",
|
||
"    return img.transform(img.size, Image.AFFINE, affine, fillcolor=255)\n",
|
||
"\n",
|
||
"def create_augmentations(symbol, font_sizes, max_rotation, augmentations, config_lines):\n",
|
||
"    \"\"\"\n",
|
||
"    Render `symbol` at every font size and save `augmentations` variants of each.\n",
|
||
"\n",
|
||
"    Pass i == 0 saves the clean, centred render; every later pass applies random\n",
|
||
"    shift, rotation, expand_characters, gaps, inversion, blur and noise.\n",
|
||
"    Files go to DATASET_DIR/<symbol>/<symbol>_<font_size>_<i>.png.\n",
|
||
"\n",
|
||
"    :param symbol: Character to render.\n",
|
||
"    :param font_sizes: Font sizes to render.\n",
|
||
"    :param max_rotation: Maximum absolute rotation in degrees.\n",
|
||
"    :param augmentations: Passes per font size, including the clean pass 0.\n",
|
||
"    :param config_lines: List extended in place with one \"path,symbol\" line per image.\n",
|
||
"    \"\"\"\n",
|
||
"    # Loading a TrueType font is loop-invariant: do it once per size instead of\n",
|
||
"    # once per (pass, size) pair.\n",
|
||
"    fonts = {size: ImageFont.truetype(FONT_PATH, size) for size in font_sizes}\n",
|
||
"    symbol_dir = os.path.join(DATASET_DIR, symbol)\n",
|
||
"    os.makedirs(symbol_dir, exist_ok=True)\n",
|
||
"\n",
|
||
"    for i in range(augmentations):\n",
|
||
"        for font_size in font_sizes:\n",
|
||
"            current_font = fonts[font_size]\n",
|
||
"\n",
|
||
"            # White canvas with the glyph drawn in black, centred via its bounding box.\n",
|
||
"            img = Image.new(\"L\", IMG_SIZE, 255)\n",
|
||
"            draw = ImageDraw.Draw(img)\n",
|
||
"            left, top, right, bottom = draw.textbbox((0, 0), symbol, font=current_font)\n",
|
||
"            x = (IMG_SIZE[0] - (right - left)) // 2 - left\n",
|
||
"            y = (IMG_SIZE[1] - (bottom - top)) // 2 - top\n",
|
||
"            draw.text((x, y), symbol, font=current_font, fill=0)\n",
|
||
"\n",
|
||
"            # Pass 0 keeps the clean render for every font size.\n",
|
||
"            if i > 0:\n",
|
||
"                # Random shift (70%).\n",
|
||
"                if random.random() < 0.7:\n",
|
||
"                    img = shift_image(img, max_shift=4)\n",
|
||
"\n",
|
||
"                # Random rotation (always applied).\n",
|
||
"                rotation = random.uniform(-max_rotation, max_rotation)\n",
|
||
"                img = img.rotate(rotation, expand=False, fillcolor=255)\n",
|
||
"\n",
|
||
"                # Random expand_characters pass (50%).\n",
|
||
"                if random.random() < 0.5:\n",
|
||
"                    kernel_size = (random.randint(1, 2), random.randint(1, 2))\n",
|
||
"                    img = expand_characters(img, kernel_size=kernel_size, iterations=1)\n",
|
||
"\n",
|
||
"                # Random gaps in the strokes (15%).\n",
|
||
"                if random.random() < 0.15:\n",
|
||
"                    num_gaps = random.randint(1, 5)\n",
|
||
"                    gap_size = random.randint(2, 4)\n",
|
||
"                    img = add_random_gaps(img, num_gaps=num_gaps, gap_size=gap_size)\n",
|
||
"\n",
|
||
"                # Colour inversion (15%).\n",
|
||
"                if random.random() < 0.15:\n",
|
||
"                    img = ImageOps.invert(img)\n",
|
||
"\n",
|
||
"                # Small font sizes use reduced maximum blur limits.\n",
|
||
"                if font_size in SMALL_FONT_SIZES:\n",
|
||
"                    blur_limits = {\"box\": 2, \"gaussian\": 1, \"motion\": 6}\n",
|
||
"                else:\n",
|
||
"                    blur_limits = {\"box\": 4, \"gaussian\": 3, \"motion\": 8}\n",
|
||
"\n",
|
||
"                # Random blur (85%).\n",
|
||
"                if random.random() < 0.85:\n",
|
||
"                    img = apply_blur(img, blur_max_limits=blur_limits)\n",
|
||
"\n",
|
||
"                # Random Gaussian noise (10%).\n",
|
||
"                if random.random() < 0.1:\n",
|
||
"                    noise_intensity = random.randint(10, 20)\n",
|
||
"                    img = add_noise(img, intensity=noise_intensity)\n",
|
||
"\n",
|
||
"            file_path = os.path.join(symbol_dir, f\"{symbol}_{font_size}_{i}.png\")\n",
|
||
"            img.save(file_path)\n",
|
||
"            config_lines.append(f\"{file_path},{symbol}\\n\")\n",
|
||
"\n",
|
||
"def main():\n",
|
||
"    \"\"\"\n",
|
||
"    Build the character dataset and write its index file.\n",
|
||
"\n",
|
||
"    1. Render and augment every symbol in SYMBOLS with the project font.\n",
|
||
"    2. Add augmented copies of the images in GOSZNAK_DIR.\n",
|
||
"    3. Add the images in REAL_GOSZNAK_DIR without augmentation.\n",
|
||
"    4. Write all \"path,symbol\" lines to DATASET_DIR/CONFIG_FILE.\n",
|
||
"    \"\"\"\n",
|
||
"    config_lines = []\n",
|
||
"\n",
|
||
"    for symbol in SYMBOLS:\n",
|
||
"        create_augmentations(symbol, FONT_SIZES, MAX_ROTATION, AUGMENTATIONS, config_lines)\n",
|
||
"\n",
|
||
"    config_lines += process_gosznak_images(GOSZNAK_DIR, DATASET_DIR, apply_augmentations=True)\n",
|
||
"    config_lines += process_gosznak_images(REAL_GOSZNAK_DIR, DATASET_DIR, apply_augmentations=False)\n",
|
||
"\n",
|
||
"    config_path = os.path.join(DATASET_DIR, CONFIG_FILE)\n",
|
||
"    with open(config_path, \"w\") as config:\n",
|
||
"        config.writelines(config_lines)\n",
|
||
"\n",
|
||
"    print(f\"Датасет сохранен в папке '{DATASET_DIR}'.\")\n",
|
||
"    print(f\"Конфигурационный файл создан: {CONFIG_FILE}\")\n",
|
||
"\n",
|
||
"\n",
|
||
"if __name__ == \"__main__\":\n",
|
||
"    main()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Размер тренировочного набора: 29568\n",
|
||
"Размер тестового набора: 7392\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"class LicensePlateDataset(Dataset):\n",
|
||
"    \"\"\"\n",
|
||
"    Single-character dataset backed by a \"path,label\" index file.\n",
|
||
"\n",
|
||
"    Each non-blank line of `<dataset_dir>/<config_name>` holds an image path and\n",
|
||
"    a character from CLASSES; labels are mapped to indices via CLASS_TO_IDX.\n",
|
||
"    \"\"\"\n",
|
||
"\n",
|
||
"    def __init__(self, dataset_dir, transform=None, config_name=\"dataset_config.txt\"):\n",
|
||
"        \"\"\"\n",
|
||
"        :param dataset_dir: Folder containing the index file.\n",
|
||
"        :param transform: Optional callable applied to each PIL image.\n",
|
||
"        :param config_name: Name of the index file inside `dataset_dir`.\n",
|
||
"        :raises KeyError: If a label is not in CLASS_TO_IDX.\n",
|
||
"        \"\"\"\n",
|
||
"        self.data = []\n",
|
||
"        self.transform = transform\n",
|
||
"        with open(os.path.join(dataset_dir, config_name), \"r\") as file:\n",
|
||
"            for line in file:\n",
|
||
"                line = line.strip()\n",
|
||
"                if not line:  # tolerate blank / trailing lines\n",
|
||
"                    continue\n",
|
||
"                # Split on the last comma only, so paths containing commas survive.\n",
|
||
"                path, label = line.rsplit(\",\", 1)\n",
|
||
"                self.data.append((path, CLASS_TO_IDX[label]))\n",
|
||
"\n",
|
||
"    def __len__(self):\n",
|
||
"        return len(self.data)\n",
|
||
"\n",
|
||
"    def __getitem__(self, idx):\n",
|
||
"        \"\"\"Return (transformed image, class index) for sample `idx`.\"\"\"\n",
|
||
"        img_path, label = self.data[idx]\n",
|
||
"        # Close the file handle once the pixels are converted.\n",
|
||
"        with Image.open(img_path) as src:\n",
|
||
"            image = src.convert(\"L\")\n",
|
||
"        if self.transform:\n",
|
||
"            image = self.transform(image)\n",
|
||
"        return image, label\n",
|
||
"\n",
|
||
"\n",
|
||
"# Training-time augmentation. After ToTensor the images have one channel, so\n",
|
||
"# ColorJitter's saturation/hue would be no-ops; only brightness/contrast are kept.\n",
|
||
"train_transform = transforms.Compose([\n",
|
||
"    transforms.Resize((28, 28), antialias=True),\n",
|
||
"    transforms.ToTensor(),\n",
|
||
"    transforms.RandomResizedCrop(size=(32, 32), scale=(0.8, 1.0), antialias=True),\n",
|
||
"    transforms.ColorJitter(brightness=0.4, contrast=0.4),\n",
|
||
"    transforms.RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5),\n",
|
||
"    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.5),\n",
|
||
"    transforms.RandomErasing(p=0.3, scale=(0.02, 0.1), ratio=(0.3, 3.3)),\n",
|
||
"    transforms.Normalize(mean=[0.5], std=[0.5]),\n",
|
||
"])\n",
|
||
"\n",
|
||
"# Deterministic evaluation pipeline: same geometry (28x28 -> 32x32) and\n",
|
||
"# normalisation as training but no random augmentation, so test metrics and\n",
|
||
"# best-checkpoint selection are not driven by augmentation noise.\n",
|
||
"eval_transform = transforms.Compose([\n",
|
||
"    transforms.Resize((28, 28), antialias=True),\n",
|
||
"    transforms.ToTensor(),\n",
|
||
"    transforms.Resize((32, 32), antialias=True),\n",
|
||
"    transforms.Normalize(mean=[0.5], std=[0.5]),\n",
|
||
"])\n",
|
||
"\n",
|
||
"# Kept so later cells that refer to `transform` keep working.\n",
|
||
"transform = train_transform\n",
|
||
"\n",
|
||
"full_dataset = LicensePlateDataset(DATASET_DIR, transform=train_transform)\n",
|
||
"eval_dataset = LicensePlateDataset(DATASET_DIR, transform=eval_transform)\n",
|
||
"\n",
|
||
"indices = list(range(len(full_dataset)))\n",
|
||
"# NOTE(review): augmented copies of the same rendered glyph / source image can\n",
|
||
"# land on both sides of this random split, which likely inflates test accuracy.\n",
|
||
"train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42)\n",
|
||
"\n",
|
||
"train_dataset = Subset(full_dataset, train_indices)\n",
|
||
"test_dataset = Subset(eval_dataset, test_indices)\n",
|
||
"\n",
|
||
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
|
||
"test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
|
||
"\n",
|
||
"print(f\"Размер тренировочного набора: {len(train_dataset)}\")\n",
|
||
"print(f\"Размер тестового набора: {len(test_dataset)}\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 56,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class ImprovedCNN(nn.Module):\n",
|
||
"    \"\"\"\n",
|
||
"    CNN classifier for single grayscale character images.\n",
|
||
"\n",
|
||
"    Four blocks of Conv2d -> activation -> BatchNorm2d -> 2x2 MaxPool map a\n",
|
||
"    (N, 1, H, W) batch to 256 channels, which are average-pooled to 1x1 and fed\n",
|
||
"    to a two-layer head. For 28x28 and 32x32 inputs the conv stack already ends\n",
|
||
"    at 1x1, so the pooling is an identity there; it only lets larger inputs\n",
|
||
"    reach the 256-feature head without a shape error.\n",
|
||
"\n",
|
||
"    :param num_classes: Number of output logits.\n",
|
||
"    \"\"\"\n",
|
||
"\n",
|
||
"    def __init__(self, num_classes):\n",
|
||
"        super().__init__()\n",
|
||
"        self.conv_layers = nn.Sequential(\n",
|
||
"            nn.Conv2d(1, 32, kernel_size=2, padding=1),\n",
|
||
"            nn.LeakyReLU(negative_slope=0.1),\n",
|
||
"            nn.BatchNorm2d(32),\n",
|
||
"            nn.MaxPool2d(kernel_size=2, stride=2),\n",
|
||
"\n",
|
||
"            nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
|
||
"            nn.ELU(),\n",
|
||
"            nn.BatchNorm2d(64),\n",
|
||
"            nn.MaxPool2d(kernel_size=2, stride=2),\n",
|
||
"\n",
|
||
"            nn.Conv2d(64, 128, kernel_size=3, padding=1),\n",
|
||
"            nn.SiLU(),\n",
|
||
"            nn.BatchNorm2d(128),\n",
|
||
"            nn.MaxPool2d(kernel_size=2, stride=2),\n",
|
||
"\n",
|
||
"            nn.Conv2d(128, 256, kernel_size=4, padding=1),\n",
|
||
"            nn.LeakyReLU(negative_slope=0.1),\n",
|
||
"            nn.BatchNorm2d(256),\n",
|
||
"            nn.MaxPool2d(kernel_size=2, stride=2)\n",
|
||
"        )\n",
|
||
"        # Parameter-free, so state_dict keys (and saved checkpoints) are unchanged.\n",
|
||
"        self.global_pool = nn.AdaptiveAvgPool2d(1)\n",
|
||
"        self.fc_layers = nn.Sequential(\n",
|
||
"            nn.Linear(256, 256),\n",
|
||
"            nn.GELU(),\n",
|
||
"            nn.Dropout(0.5),\n",
|
||
"            nn.Linear(256, num_classes)\n",
|
||
"        )\n",
|
||
"\n",
|
||
"    def forward(self, x):\n",
|
||
"        \"\"\"Return (N, num_classes) logits for a (N, 1, H, W) batch.\"\"\"\n",
|
||
"        x = self.conv_layers(x)\n",
|
||
"        x = self.global_pool(x)\n",
|
||
"        x = torch.flatten(x, 1)\n",
|
||
"        return self.fc_layers(x)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 57,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Model, loss and optimisation setup.\n",
|
||
"model = ImprovedCNN(NUM_CLASSES).to(DEVICE)\n",
|
||
"# Cross-entropy with 0.1 label smoothing (soft targets).\n",
|
||
"criterion = nn.CrossEntropyLoss(label_smoothing=0.1)\n",
|
||
"optimizer = optim.SGD(\n",
|
||
"    model.parameters(),\n",
|
||
"    lr=LEARNING_RATE,\n",
|
||
"    momentum=0.9,\n",
|
||
"    weight_decay=1e-4,\n",
|
||
")\n",
|
||
"# Divide the LR by 10 after 5 scheduler steps without a decrease of the value\n",
|
||
"# passed to scheduler.step() (mode='min', e.g. a loss).\n",
|
||
"scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 58,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch [1/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 1.4878, Train Accuracy: 70.34%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.9624, Test Accuracy: 90.64%\n",
|
||
"Best model saved with test accuracy: 90.64%\n",
|
||
"Epoch [2/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.9608, Train Accuracy: 91.21%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.8277, Test Accuracy: 95.06%\n",
|
||
"Best model saved with test accuracy: 95.06%\n",
|
||
"Epoch [3/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.8817, Train Accuracy: 93.65%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7844, Test Accuracy: 96.32%\n",
|
||
"Best model saved with test accuracy: 96.32%\n",
|
||
"Epoch [4/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.8468, Train Accuracy: 94.64%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7595, Test Accuracy: 96.78%\n",
|
||
"Best model saved with test accuracy: 96.78%\n",
|
||
"Epoch [5/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.8135, Train Accuracy: 95.71%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7514, Test Accuracy: 96.74%\n",
|
||
"Epoch [6/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7922, Train Accuracy: 96.37%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7277, Test Accuracy: 97.38%\n",
|
||
"Best model saved with test accuracy: 97.38%\n",
|
||
"Epoch [7/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7822, Train Accuracy: 96.46%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7196, Test Accuracy: 97.66%\n",
|
||
"Best model saved with test accuracy: 97.66%\n",
|
||
"Epoch [8/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7708, Train Accuracy: 96.68%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7192, Test Accuracy: 97.47%\n",
|
||
"Epoch [9/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7652, Train Accuracy: 96.80%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7154, Test Accuracy: 97.46%\n",
|
||
"Epoch [10/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7518, Train Accuracy: 97.23%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7117, Test Accuracy: 97.71%\n",
|
||
"Best model saved with test accuracy: 97.71%\n",
|
||
"Epoch [11/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7467, Train Accuracy: 97.33%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7145, Test Accuracy: 97.39%\n",
|
||
"Epoch [12/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7387, Train Accuracy: 97.45%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.7060, Test Accuracy: 97.58%\n",
|
||
"Epoch [13/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7375, Train Accuracy: 97.44%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6970, Test Accuracy: 97.94%\n",
|
||
"Best model saved with test accuracy: 97.94%\n",
|
||
"Epoch [14/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7323, Train Accuracy: 97.72%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6920, Test Accuracy: 98.27%\n",
|
||
"Best model saved with test accuracy: 98.27%\n",
|
||
"Epoch [15/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7291, Train Accuracy: 97.62%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6895, Test Accuracy: 98.08%\n",
|
||
"Epoch [16/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7238, Train Accuracy: 97.80%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6878, Test Accuracy: 98.04%\n",
|
||
"Epoch [17/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7195, Train Accuracy: 97.96%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6834, Test Accuracy: 98.17%\n",
|
||
"Epoch [18/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7212, Train Accuracy: 97.80%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6863, Test Accuracy: 98.21%\n",
|
||
"Epoch [19/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7128, Train Accuracy: 98.19%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6845, Test Accuracy: 98.30%\n",
|
||
"Best model saved with test accuracy: 98.30%\n",
|
||
"Epoch [20/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7144, Train Accuracy: 98.06%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6818, Test Accuracy: 98.31%\n",
|
||
"Best model saved with test accuracy: 98.31%\n",
|
||
"Epoch [21/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7103, Train Accuracy: 98.17%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6837, Test Accuracy: 98.39%\n",
|
||
"Best model saved with test accuracy: 98.39%\n",
|
||
"Epoch [22/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7096, Train Accuracy: 98.12%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6801, Test Accuracy: 98.19%\n",
|
||
"Epoch [23/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7069, Train Accuracy: 98.22%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6810, Test Accuracy: 98.15%\n",
|
||
"Epoch [24/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7073, Train Accuracy: 98.20%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6748, Test Accuracy: 98.53%\n",
|
||
"Best model saved with test accuracy: 98.53%\n",
|
||
"Epoch [25/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7019, Train Accuracy: 98.35%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6773, Test Accuracy: 98.19%\n",
|
||
"Epoch [26/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7007, Train Accuracy: 98.31%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6742, Test Accuracy: 98.38%\n",
|
||
"Epoch [27/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.7010, Train Accuracy: 98.31%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6722, Test Accuracy: 98.58%\n",
|
||
"Best model saved with test accuracy: 98.58%\n",
|
||
"Epoch [28/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6999, Train Accuracy: 98.32%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6736, Test Accuracy: 98.53%\n",
|
||
"Epoch [29/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6981, Train Accuracy: 98.41%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6707, Test Accuracy: 98.62%\n",
|
||
"Best model saved with test accuracy: 98.62%\n",
|
||
"Epoch [30/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6976, Train Accuracy: 98.36%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" "
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6715, Test Accuracy: 98.48%\n",
|
||
"Обучение завершено.\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\r"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"def train_one_epoch(model, train_loader, optimizer, criterion, device):\n",
|
||
" model.train()\n",
|
||
" running_loss = 0.0\n",
|
||
" correct = 0\n",
|
||
" total = 0\n",
|
||
"\n",
|
||
" train_bar = tqdm(train_loader, desc=\"Training\", leave=False)\n",
|
||
" for images, labels in train_bar:\n",
|
||
" images, labels = images.to(device), labels.to(device)\n",
|
||
"\n",
|
||
" optimizer.zero_grad()\n",
|
||
" outputs = model(images)\n",
|
||
" loss = criterion(outputs, labels)\n",
|
||
" loss.backward()\n",
|
||
" optimizer.step()\n",
|
||
"\n",
|
||
" running_loss += loss.item()\n",
|
||
" _, predicted = torch.max(outputs, 1)\n",
|
||
" total += labels.size(0)\n",
|
||
" correct += (predicted == labels).sum().item()\n",
|
||
"\n",
|
||
" train_bar.set_postfix(loss=f\"{running_loss / len(train_loader):.4f}\", acc=f\"{100 * correct / total:.2f}%\")\n",
|
||
"\n",
|
||
" train_loss = running_loss / len(train_loader)\n",
|
||
" train_accuracy = 100 * correct / total\n",
|
||
" return train_loss, train_accuracy\n",
|
||
"\n",
|
||
"\n",
|
||
"def evaluate(model, test_loader, criterion, device):\n",
|
||
" model.eval()\n",
|
||
" running_loss = 0.0\n",
|
||
" correct = 0\n",
|
||
" total = 0\n",
|
||
"\n",
|
||
" with torch.no_grad():\n",
|
||
" test_bar = tqdm(test_loader, desc=\"Testing\", leave=False)\n",
|
||
" for images, labels in test_bar:\n",
|
||
" images, labels = images.to(device), labels.to(device)\n",
|
||
"\n",
|
||
" outputs = model(images)\n",
|
||
" loss = criterion(outputs, labels)\n",
|
||
" running_loss += loss.item()\n",
|
||
"\n",
|
||
" _, predicted = torch.max(outputs, 1)\n",
|
||
" total += labels.size(0)\n",
|
||
" correct += (predicted == labels).sum().item()\n",
|
||
"\n",
|
||
" test_loss = running_loss / len(test_loader)\n",
|
||
" test_accuracy = 100 * correct / total\n",
|
||
" return test_loss, test_accuracy\n",
|
||
"\n",
|
||
"\n",
|
||
"best_accuracy = 0.0\n",
|
||
"for epoch in range(EPOCHS):\n",
|
||
" print(f\"Epoch [{epoch + 1}/{EPOCHS}]\")\n",
|
||
" train_loss, train_accuracy = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)\n",
|
||
" print(f\"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%\")\n",
|
||
"\n",
|
||
" test_loss, test_accuracy = evaluate(model, test_loader, criterion, DEVICE)\n",
|
||
" print(f\"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%\")\n",
|
||
"\n",
|
||
" if test_accuracy > best_accuracy:\n",
|
||
" best_accuracy = test_accuracy\n",
|
||
" torch.save(model.state_dict(), SAVE_PATH)\n",
|
||
" print(f\"Best model saved with test accuracy: {best_accuracy:.2f}%\")\n",
|
||
"\n",
|
||
" if scheduler is not None:\n",
|
||
" scheduler.step(test_loss)\n",
|
||
"\n",
|
||
"print(\"Обучение завершено.\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 59,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch [1/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6945, Train Accuracy: 98.48%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6733, Test Accuracy: 98.39%\n",
|
||
"Epoch [2/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6948, Train Accuracy: 98.46%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6739, Test Accuracy: 98.16%\n",
|
||
"Epoch [3/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6932, Train Accuracy: 98.45%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6719, Test Accuracy: 98.54%\n",
|
||
"Epoch [4/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6927, Train Accuracy: 98.49%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6723, Test Accuracy: 98.47%\n",
|
||
"Epoch [5/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6922, Train Accuracy: 98.56%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6768, Test Accuracy: 98.28%\n",
|
||
"Epoch [6/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6823, Train Accuracy: 98.76%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6567, Test Accuracy: 98.96%\n",
|
||
"Best model saved with test accuracy: 98.96%\n",
|
||
"Epoch [7/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6790, Train Accuracy: 98.81%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6565, Test Accuracy: 98.84%\n",
|
||
"Epoch [8/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6764, Train Accuracy: 98.97%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6542, Test Accuracy: 98.89%\n",
|
||
"Epoch [9/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6734, Train Accuracy: 99.06%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6565, Test Accuracy: 98.82%\n",
|
||
"Epoch [10/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6737, Train Accuracy: 99.02%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6530, Test Accuracy: 99.09%\n",
|
||
"Best model saved with test accuracy: 99.09%\n",
|
||
"Epoch [11/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6714, Train Accuracy: 99.03%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6546, Test Accuracy: 98.90%\n",
|
||
"Epoch [12/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6727, Train Accuracy: 99.00%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6534, Test Accuracy: 98.92%\n",
|
||
"Epoch [13/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6726, Train Accuracy: 98.94%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6549, Test Accuracy: 98.76%\n",
|
||
"Epoch [14/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6707, Train Accuracy: 99.05%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6516, Test Accuracy: 99.07%\n",
|
||
"Epoch [15/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6703, Train Accuracy: 99.10%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6541, Test Accuracy: 98.93%\n",
|
||
"Epoch [16/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6719, Train Accuracy: 99.02%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6496, Test Accuracy: 98.99%\n",
|
||
"Epoch [17/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6705, Train Accuracy: 99.03%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6534, Test Accuracy: 98.89%\n",
|
||
"Epoch [18/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6704, Train Accuracy: 99.09%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6491, Test Accuracy: 99.04%\n",
|
||
"Epoch [19/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6696, Train Accuracy: 99.02%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6553, Test Accuracy: 98.73%\n",
|
||
"Epoch [20/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6693, Train Accuracy: 99.12%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6515, Test Accuracy: 99.01%\n",
|
||
"Epoch [21/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6686, Train Accuracy: 99.13%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6479, Test Accuracy: 99.08%\n",
|
||
"Epoch [22/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6675, Train Accuracy: 99.20%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6485, Test Accuracy: 99.04%\n",
|
||
"Epoch [23/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6674, Train Accuracy: 99.23%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6466, Test Accuracy: 99.03%\n",
|
||
"Epoch [24/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6687, Train Accuracy: 99.07%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6466, Test Accuracy: 99.22%\n",
|
||
"Best model saved with test accuracy: 99.22%\n",
|
||
"Epoch [25/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6672, Train Accuracy: 99.20%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6505, Test Accuracy: 98.99%\n",
|
||
"Epoch [26/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6679, Train Accuracy: 99.22%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6492, Test Accuracy: 99.05%\n",
|
||
"Epoch [27/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6690, Train Accuracy: 99.04%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6485, Test Accuracy: 99.09%\n",
|
||
"Epoch [28/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6705, Train Accuracy: 99.00%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6498, Test Accuracy: 98.97%\n",
|
||
"Epoch [29/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6664, Train Accuracy: 99.20%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6498, Test Accuracy: 98.93%\n",
|
||
"Epoch [30/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Train Loss: 0.6663, Train Accuracy: 99.20%\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" "
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss: 0.6512, Test Accuracy: 98.85%\n",
|
||
"Обучение завершено.\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\r"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"for epoch in range(EPOCHS):\n",
|
||
" print(f\"Epoch [{epoch + 1}/{EPOCHS}]\")\n",
|
||
" train_loss, train_accuracy = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)\n",
|
||
" print(f\"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%\")\n",
|
||
"\n",
|
||
" test_loss, test_accuracy = evaluate(model, test_loader, criterion, DEVICE)\n",
|
||
" print(f\"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%\")\n",
|
||
"\n",
|
||
" if test_accuracy > best_accuracy:\n",
|
||
" best_accuracy = test_accuracy\n",
|
||
" torch.save(model.state_dict(), SAVE_PATH)\n",
|
||
" print(f\"Best model saved with test accuracy: {best_accuracy:.2f}%\")\n",
|
||
"\n",
|
||
" if scheduler is not None:\n",
|
||
" scheduler.step(test_loss)\n",
|
||
"\n",
|
||
"print(\"Обучение завершено.\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch [1/30]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"ename": "KeyboardInterrupt",
|
||
"evalue": "",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[1;32mIn[35], line 13\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch [\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mEPOCHS\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m]\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 11\u001b[0m train_bar \u001b[38;5;241m=\u001b[39m tqdm(train_loader, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining\u001b[39m\u001b[38;5;124m\"\u001b[39m, leave\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m---> 13\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m train_bar:\n\u001b[0;32m 14\u001b[0m images, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(DEVICE), labels\u001b[38;5;241m.\u001b[39mto(DEVICE)\n\u001b[0;32m 15\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
|
||
"File \u001b[1;32mc:\\Users\\leonk\\Documents\\code\\number-plate-study\\.venv\\lib\\site-packages\\tqdm\\std.py:1181\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1178\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[0;32m 1180\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m-> 1181\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[0;32m 1182\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[0;32m 1183\u001b[0m \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[0;32m 1184\u001b[0m \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
|
||
"File \u001b[1;32mc:\\Users\\leonk\\Documents\\code\\number-plate-study\\.venv\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:633\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 630\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 631\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m 632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 633\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m 635\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m 636\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m 637\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
|
||
"File \u001b[1;32mc:\\Users\\leonk\\Documents\\code\\number-plate-study\\.venv\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:677\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 675\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m 676\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 677\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m 678\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[0;32m 679\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
|
||
"File \u001b[1;32mc:\\Users\\leonk\\Documents\\code\\number-plate-study\\.venv\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
|
||
"File \u001b[1;32mc:\\Users\\leonk\\Documents\\code\\number-plate-study\\.venv\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
|
||
"File \u001b[1;32mc:\\Users\\leonk\\Documents\\code\\number-plate-study\\.venv\\lib\\site-packages\\torch\\utils\\data\\dataset.py:298\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[1;34m(self, idx)\u001b[0m\n\u001b[0;32m 296\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(idx, \u001b[38;5;28mlist\u001b[39m):\n\u001b[0;32m 297\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindices[i] \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m idx]]\n\u001b[1;32m--> 298\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindices\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\n",
|
||
"Cell \u001b[1;32mIn[32], line 15\u001b[0m, in \u001b[0;36mLicensePlateDataset.__getitem__\u001b[1;34m(self, idx)\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, idx):\n\u001b[0;32m 14\u001b[0m img_path, label \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata[idx]\n\u001b[1;32m---> 15\u001b[0m image \u001b[38;5;241m=\u001b[39m \u001b[43mImage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg_path\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mconvert(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mL\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 16\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform:\n\u001b[0;32m 17\u001b[0m image \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform(image)\n",
|
||
"File \u001b[1;32mc:\\Users\\leonk\\Documents\\code\\number-plate-study\\.venv\\lib\\site-packages\\PIL\\Image.py:3247\u001b[0m, in \u001b[0;36mopen\u001b[1;34m(fp, mode, formats)\u001b[0m\n\u001b[0;32m 3244\u001b[0m filename \u001b[38;5;241m=\u001b[39m fp\n\u001b[0;32m 3246\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m filename:\n\u001b[1;32m-> 3247\u001b[0m fp \u001b[38;5;241m=\u001b[39m \u001b[43mbuiltins\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 3248\u001b[0m exclusive_fp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 3250\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
|
||
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
||
]
|
||
}
|
||
],
|
||
"source": [
"# Train for EPOCHS epochs, evaluate on test_loader after every epoch and\n",
"# checkpoint the weights to SAVE_PATH whenever test accuracy improves.\n",
"# NOTE(review): the best checkpoint is chosen on the test set, so the\n",
"# reported best accuracy is optimistically biased; use a separate\n",
"# validation split for model selection if an unbiased score is needed.\n",
"# tqdm, torch and the config constants come from the first two cells.\n",
"best_accuracy = 0.0\n",
"\n",
"for epoch in range(EPOCHS):\n",
"    model.train()\n",
"    running_loss = 0.0\n",
"    correct_train = 0\n",
"    total_train = 0\n",
"\n",
"    print(f\"Epoch [{epoch+1}/{EPOCHS}]\")\n",
"    train_bar = tqdm(train_loader, desc=\"Training\", leave=False)\n",
"\n",
"    # start=1 so batch_idx equals the number of batches processed so far.\n",
"    for batch_idx, (images, labels) in enumerate(train_bar, start=1):\n",
"        images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
"        optimizer.zero_grad()\n",
"        outputs = model(images)\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        running_loss += loss.item()\n",
"        _, predicted = torch.max(outputs, 1)\n",
"        total_train += labels.size(0)\n",
"        correct_train += (predicted == labels).sum().item()\n",
"\n",
"        # Mean loss over the batches seen so far; dividing by len(train_loader)\n",
"        # here would under-report the loss for most of the epoch.\n",
"        train_bar.set_postfix(loss=f\"{running_loss/batch_idx:.4f}\", acc=f\"{100 * correct_train / total_train:.2f}%\")\n",
"\n",
"    train_accuracy = 100 * correct_train / total_train\n",
"    print(f\"Train Loss: {running_loss/len(train_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%\")\n",
"\n",
"    # Evaluation pass: eval mode and no gradient tracking.\n",
"    model.eval()\n",
"    correct_test = 0\n",
"    total_test = 0\n",
"    test_loss = 0.0\n",
"\n",
"    with torch.no_grad():\n",
"        test_bar = tqdm(test_loader, desc=\"Testing\", leave=False)\n",
"        for images, labels in test_bar:\n",
"            images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
"\n",
"            outputs = model(images)\n",
"            loss = criterion(outputs, labels)\n",
"            test_loss += loss.item()\n",
"\n",
"            _, predicted = torch.max(outputs, 1)\n",
"            total_test += labels.size(0)\n",
"            correct_test += (predicted == labels).sum().item()\n",
"\n",
"    test_accuracy = 100 * correct_test / total_test\n",
"    print(f\"Test Loss: {test_loss/len(test_loader):.4f}, Test Accuracy: {test_accuracy:.2f}%\")\n",
"\n",
"    # Keep only the weights with the highest test accuracy seen so far.\n",
"    if test_accuracy > best_accuracy:\n",
"        best_accuracy = test_accuracy\n",
"        torch.save(model.state_dict(), SAVE_PATH)\n",
"        print(f\"Best model saved with test accuracy: {best_accuracy:.2f}%\")\n",
"\n",
"print(\"Обучение завершено.\")\n"
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.13"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|