245 lines
10 KiB
Plaintext
245 lines
10 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"import torch\n",
|
|||
|
"import torch.nn as nn\n",
|
|||
|
"import torch.optim as optim\n",
|
|||
|
"from torch.utils.data import DataLoader, Dataset\n",
|
|||
|
"from torchvision import transforms\n",
|
|||
|
"import random\n",
|
|||
|
"from PIL import Image, ImageDraw, ImageFont, ImageOps, ImageFilter\n",
|
|||
|
"import cv2\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"\n",
|
|||
|
"DATASET_DIR = \"dataset\"\n",
|
|||
|
"SAVE_PATH = \"best_model_8.pth\"\n",
|
|||
|
"BATCH_SIZE = 32\n",
|
|||
|
"EPOCHS = 30\n",
|
|||
|
"LEARNING_RATE = 0.001\n",
|
|||
|
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|||
|
"\n",
|
|||
|
"CLASSES = \"ABEKMHOPCTYX0123456789\"\n",
|
|||
|
"NUM_CLASSES = len(CLASSES)\n",
|
|||
|
"CLASS_TO_IDX = {char: idx for idx, char in enumerate(CLASSES)}\n",
|
|||
|
"IDX_TO_CLASS = {idx: char for char, idx in CLASS_TO_IDX.items()}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Датасет сохранен в папке 'dataset'.\n",
|
|||
|
"Конфигурационный файл создан: dataset_config.txt\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"\n",
|
|||
|
"FONT_PATH = \"RoadNumbers2.0.ttf\"\n",
|
|||
|
"CONFIG_FILE = \"dataset_config.txt\"\n",
|
|||
|
"\n",
|
|||
|
"# Буквы и цифры, используемые в российских номерах\n",
|
|||
|
"SYMBOLS = \"ABEKMHOPCTYX0123456789\"\n",
|
|||
|
"SMALL_FONT_SIZES = [26, 32] # Размеры шрифта с уменьшенным блюром\n",
|
|||
|
"IMG_SIZE = (28, 28)\n",
|
|||
|
"FONT_SIZES = [26, 32, 38, 46, 54, 58]\n",
|
|||
|
"MAX_ROTATION = 15\n",
|
|||
|
"AUGMENTATIONS = 240\n",
|
|||
|
"\n",
|
|||
|
"os.makedirs(DATASET_DIR, exist_ok=True)\n",
|
|||
|
"\n",
|
|||
|
"try:\n",
|
|||
|
" font = ImageFont.truetype(FONT_PATH, FONT_SIZES[0])\n",
|
|||
|
"except IOError:\n",
|
|||
|
" print(f\"Шрифт {FONT_PATH} не найден. Убедитесь, что он находится в текущей директории.\")\n",
|
|||
|
" exit()\n",
|
|||
|
"\n",
|
|||
|
"def expand_characters(img, kernel_size=(2, 2), iterations=1):\n",
|
|||
|
" kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)\n",
|
|||
|
" img_array = np.array(img)\n",
|
|||
|
" expanded = cv2.dilate(img_array, kernel, iterations=iterations)\n",
|
|||
|
" return Image.fromarray(expanded)\n",
|
|||
|
"\n",
|
|||
|
"def add_random_gaps(img, num_gaps=5, gap_size=5):\n",
|
|||
|
" draw = ImageDraw.Draw(img)\n",
|
|||
|
" for _ in range(num_gaps):\n",
|
|||
|
" x1 = random.randint(0, IMG_SIZE[0] - gap_size)\n",
|
|||
|
" y1 = random.randint(0, IMG_SIZE[1] - gap_size)\n",
|
|||
|
" x2 = x1 + gap_size\n",
|
|||
|
" y2 = y1 + gap_size\n",
|
|||
|
" draw.rectangle([x1, y1, x2, y2], fill=255) \n",
|
|||
|
" return img\n",
|
|||
|
"\n",
|
|||
|
"def apply_blur(img, blur_limits=None):\n",
|
|||
|
" if blur_limits is None:\n",
|
|||
|
" blur_limits = {\"box\": 4, \"gaussian\": 3, \"motion\": 8}\n",
|
|||
|
"\n",
|
|||
|
" blur_type = random.choice([\"box\", \"gaussian\", \"motion\"])\n",
|
|||
|
"\n",
|
|||
|
" if blur_type == \"box\":\n",
|
|||
|
" radius = random.randint(1, blur_limits[\"box\"])\n",
|
|||
|
" img = img.filter(ImageFilter.BoxBlur(radius))\n",
|
|||
|
" elif blur_type == \"gaussian\":\n",
|
|||
|
" radius = random.randint(1, blur_limits[\"gaussian\"])\n",
|
|||
|
" img = img.filter(ImageFilter.GaussianBlur(radius))\n",
|
|||
|
" elif blur_type == \"motion\":\n",
|
|||
|
" kernel_size = random.randint(4, blur_limits[\"motion\"])\n",
|
|||
|
" kernel_motion_blur = np.zeros((kernel_size, kernel_size))\n",
|
|||
|
" kernel_motion_blur[int((kernel_size - 1) / 2), :] = 1\n",
|
|||
|
" kernel_motion_blur /= kernel_size\n",
|
|||
|
" img_array = cv2.filter2D(np.array(img), -1, kernel_motion_blur)\n",
|
|||
|
" img = Image.fromarray(img_array)\n",
|
|||
|
" return img\n",
|
|||
|
"\n",
|
|||
|
"def add_noise(img, intensity=30):\n",
|
|||
|
" img_array = np.array(img)\n",
|
|||
|
" noise = np.random.normal(0, intensity, img_array.shape).astype(np.int32)\n",
|
|||
|
" noisy_img = np.clip(img_array + noise, 0, 255).astype(np.uint8)\n",
|
|||
|
" return Image.fromarray(noisy_img)\n",
|
|||
|
"\n",
|
|||
|
"def shift_image(img, max_shift=4):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" Сдвигает изображение на случайное количество пикселей по осям X и Y.\n",
|
|||
|
"\n",
|
|||
|
" :param img: PIL Image объект.\n",
|
|||
|
" :param max_shift: Максимальное смещение в пикселях по каждой оси.\n",
|
|||
|
" :return: Сдвинутое изображение.\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" x_shift = random.randint(-max_shift, max_shift)\n",
|
|||
|
" y_shift = random.randint(-max_shift, max_shift)\n",
|
|||
|
" return img.transform(\n",
|
|||
|
" img.size,\n",
|
|||
|
" Image.AFFINE,\n",
|
|||
|
" (1, 0, x_shift, 0, 1, y_shift),\n",
|
|||
|
" fillcolor=255\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
"def create_augmentations(symbol, font_sizes, max_rotation, augmentations, config_lines):\n",
|
|||
|
" for i in range(augmentations):\n",
|
|||
|
" for font_size in font_sizes:\n",
|
|||
|
" current_font = ImageFont.truetype(FONT_PATH, font_size)\n",
|
|||
|
"\n",
|
|||
|
" # Создаём пустое изображение с белым фоном\n",
|
|||
|
" img = Image.new(\"L\", IMG_SIZE, 255)\n",
|
|||
|
" draw = ImageDraw.Draw(img)\n",
|
|||
|
"\n",
|
|||
|
"            # Measure the rendered text's bounding box with ImageDraw.textbbox\n",
|
|||
|
" text_bbox = draw.textbbox((0, 0), symbol, font=current_font)\n",
|
|||
|
" text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]\n",
|
|||
|
"\n",
|
|||
|
" # Рассчитываем координаты для центрирования текста\n",
|
|||
|
" x = (IMG_SIZE[0] - text_width) // 2 - text_bbox[0]\n",
|
|||
|
" y = (IMG_SIZE[1] - text_height) // 2 - text_bbox[1]\n",
|
|||
|
"\n",
|
|||
|
" # Рисуем текст\n",
|
|||
|
" draw.text((x, y), symbol, font=current_font, fill=0)\n",
|
|||
|
"\n",
|
|||
|
" # Для первого прогона символа с данным размером шрифта сохраняем без аугментаций\n",
|
|||
|
" if i == 0:\n",
|
|||
|
" symbol_dir = os.path.join(DATASET_DIR, symbol)\n",
|
|||
|
" os.makedirs(symbol_dir, exist_ok=True)\n",
|
|||
|
" file_path = os.path.join(symbol_dir, f\"{symbol}_{font_size}_{i}.png\")\n",
|
|||
|
" img.save(file_path)\n",
|
|||
|
"\n",
|
|||
|
" config_lines.append(f\"{file_path},{symbol}\\n\")\n",
|
|||
|
" continue\n",
|
|||
|
"\n",
|
|||
|
" # Случайное смещение символа\n",
|
|||
|
" if random.random() < 0.7: \n",
|
|||
|
" img = shift_image(img, max_shift=4)\n",
|
|||
|
"\n",
|
|||
|
" # Случайный поворот изображения\n",
|
|||
|
" rotation = random.uniform(-max_rotation, max_rotation)\n",
|
|||
|
" img = img.rotate(rotation, expand=False, fillcolor=255)\n",
|
|||
|
"\n",
|
|||
|
" # Случайное расширение символов\n",
|
|||
|
" if random.random() < 0.5:\n",
|
|||
|
" kernel_size = (random.randint(1, 2), random.randint(1, 2))\n",
|
|||
|
" iterations = 1\n",
|
|||
|
" img = expand_characters(img, kernel_size=kernel_size, iterations=iterations)\n",
|
|||
|
"\n",
|
|||
|
" # Случайное добавление разрывов\n",
|
|||
|
" if random.random() < 0.15:\n",
|
|||
|
" num_gaps = random.randint(1, 5)\n",
|
|||
|
" gap_size = random.randint(2, 4)\n",
|
|||
|
" img = add_random_gaps(img, num_gaps=num_gaps, gap_size=gap_size)\n",
|
|||
|
"\n",
|
|||
|
"            # Invert the colors with 15% probability\n",
|
|||
|
" if random.random() < 0.15:\n",
|
|||
|
" img = ImageOps.invert(img)\n",
|
|||
|
"\n",
|
|||
|
" # Определяем лимиты блюра в зависимости от размера шрифта\n",
|
|||
|
" if font_size in SMALL_FONT_SIZES:\n",
|
|||
|
" blur_limits = {\"box\": 2, \"gaussian\": 1, \"motion\": 6} # Уменьшенные лимиты\n",
|
|||
|
" else:\n",
|
|||
|
" blur_limits = {\"box\": 4, \"gaussian\": 3, \"motion\": 8} # Оригинальные лимиты\n",
|
|||
|
"\n",
|
|||
|
" # Случайное добавление блюра\n",
|
|||
|
" if random.random() < 0.85:\n",
|
|||
|
" img = apply_blur(img, blur_limits=blur_limits)\n",
|
|||
|
"\n",
|
|||
|
" # Случайное добавление шума\n",
|
|||
|
" if random.random() < 0.1:\n",
|
|||
|
" noise_intensity = random.randint(10, 20)\n",
|
|||
|
" img = add_noise(img, intensity=noise_intensity)\n",
|
|||
|
"\n",
|
|||
|
" # Сохраняем изображение\n",
|
|||
|
" symbol_dir = os.path.join(DATASET_DIR, symbol)\n",
|
|||
|
" os.makedirs(symbol_dir, exist_ok=True)\n",
|
|||
|
" file_path = os.path.join(symbol_dir, f\"{symbol}_{font_size}_{i}.png\")\n",
|
|||
|
" img.save(file_path)\n",
|
|||
|
"\n",
|
|||
|
" config_lines.append(f\"{file_path},{symbol}\\n\")\n",
|
|||
|
"\n",
|
|||
|
"def main():\n",
|
|||
|
" config_lines = []\n",
|
|||
|
"\n",
|
|||
|
" for symbol in SYMBOLS:\n",
|
|||
|
" create_augmentations(symbol, FONT_SIZES, MAX_ROTATION, AUGMENTATIONS, config_lines)\n",
|
|||
|
"\n",
|
|||
|
" config_path = os.path.join(DATASET_DIR, CONFIG_FILE)\n",
|
|||
|
" with open(config_path, \"w\") as config:\n",
|
|||
|
" config.writelines(config_lines)\n",
|
|||
|
"\n",
|
|||
|
" print(f\"Датасет сохранен в папке '{DATASET_DIR}'.\")\n",
|
|||
|
" print(f\"Конфигурационный файл создан: {CONFIG_FILE}\")\n",
|
|||
|
"\n",
|
|||
|
"if __name__ == \"__main__\":\n",
|
|||
|
" main()\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": ".venv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.9.13"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|