number-plate-study/why28.ipynb

245 lines
10 KiB
Plaintext
Raw Normal View History

2024-11-28 21:24:59 +01:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from torchvision import transforms\n",
"import random\n",
"from PIL import Image, ImageDraw, ImageFont, ImageOps, ImageFilter\n",
"import cv2\n",
"import numpy as np\n",
"\n",
"DATASET_DIR = \"dataset\"\n",
"SAVE_PATH = \"best_model_8.pth\"\n",
"BATCH_SIZE = 32\n",
"EPOCHS = 30\n",
"LEARNING_RATE = 0.001\n",
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"CLASSES = \"ABEKMHOPCTYX0123456789\"\n",
"NUM_CLASSES = len(CLASSES)\n",
"CLASS_TO_IDX = {char: idx for idx, char in enumerate(CLASSES)}\n",
"IDX_TO_CLASS = {idx: char for char, idx in CLASS_TO_IDX.items()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Датасет сохранен в папке 'dataset'.\n",
"Конфигурационный файл создан: dataset_config.txt\n"
]
}
],
"source": [
"\n",
"FONT_PATH = \"RoadNumbers2.0.ttf\"\n",
"CONFIG_FILE = \"dataset_config.txt\"\n",
"\n",
"# Буквы и цифры, используемые в российских номерах\n",
"SYMBOLS = \"ABEKMHOPCTYX0123456789\"\n",
"SMALL_FONT_SIZES = [26, 32] # Размеры шрифта с уменьшенным блюром\n",
"IMG_SIZE = (28, 28)\n",
"FONT_SIZES = [26, 32, 38, 46, 54, 58]\n",
"MAX_ROTATION = 15\n",
"AUGMENTATIONS = 240\n",
"\n",
"os.makedirs(DATASET_DIR, exist_ok=True)\n",
"\n",
"try:\n",
" font = ImageFont.truetype(FONT_PATH, FONT_SIZES[0])\n",
"except IOError:\n",
" print(f\"Шрифт {FONT_PATH} не найден. Убедитесь, что он находится в текущей директории.\")\n",
" exit()\n",
"\n",
"def expand_characters(img, kernel_size=(2, 2), iterations=1):\n",
" kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)\n",
" img_array = np.array(img)\n",
" expanded = cv2.dilate(img_array, kernel, iterations=iterations)\n",
" return Image.fromarray(expanded)\n",
"\n",
"def add_random_gaps(img, num_gaps=5, gap_size=5):\n",
" draw = ImageDraw.Draw(img)\n",
" for _ in range(num_gaps):\n",
" x1 = random.randint(0, IMG_SIZE[0] - gap_size)\n",
" y1 = random.randint(0, IMG_SIZE[1] - gap_size)\n",
" x2 = x1 + gap_size\n",
" y2 = y1 + gap_size\n",
" draw.rectangle([x1, y1, x2, y2], fill=255) \n",
" return img\n",
"\n",
"def apply_blur(img, blur_limits=None):\n",
" if blur_limits is None:\n",
" blur_limits = {\"box\": 4, \"gaussian\": 3, \"motion\": 8}\n",
"\n",
" blur_type = random.choice([\"box\", \"gaussian\", \"motion\"])\n",
"\n",
" if blur_type == \"box\":\n",
" radius = random.randint(1, blur_limits[\"box\"])\n",
" img = img.filter(ImageFilter.BoxBlur(radius))\n",
" elif blur_type == \"gaussian\":\n",
" radius = random.randint(1, blur_limits[\"gaussian\"])\n",
" img = img.filter(ImageFilter.GaussianBlur(radius))\n",
" elif blur_type == \"motion\":\n",
" kernel_size = random.randint(4, blur_limits[\"motion\"])\n",
" kernel_motion_blur = np.zeros((kernel_size, kernel_size))\n",
" kernel_motion_blur[int((kernel_size - 1) / 2), :] = 1\n",
" kernel_motion_blur /= kernel_size\n",
" img_array = cv2.filter2D(np.array(img), -1, kernel_motion_blur)\n",
" img = Image.fromarray(img_array)\n",
" return img\n",
"\n",
"def add_noise(img, intensity=30):\n",
" img_array = np.array(img)\n",
" noise = np.random.normal(0, intensity, img_array.shape).astype(np.int32)\n",
" noisy_img = np.clip(img_array + noise, 0, 255).astype(np.uint8)\n",
" return Image.fromarray(noisy_img)\n",
"\n",
"def shift_image(img, max_shift=4):\n",
" \"\"\"\n",
" Сдвигает изображение на случайное количество пикселей по осям X и Y.\n",
"\n",
" :param img: PIL Image объект.\n",
" :param max_shift: Максимальное смещение в пикселях по каждой оси.\n",
" :return: Сдвинутое изображение.\n",
" \"\"\"\n",
" x_shift = random.randint(-max_shift, max_shift)\n",
" y_shift = random.randint(-max_shift, max_shift)\n",
" return img.transform(\n",
" img.size,\n",
" Image.AFFINE,\n",
" (1, 0, x_shift, 0, 1, y_shift),\n",
" fillcolor=255\n",
" )\n",
"\n",
"def create_augmentations(symbol, font_sizes, max_rotation, augmentations, config_lines):\n",
" for i in range(augmentations):\n",
" for font_size in font_sizes:\n",
" current_font = ImageFont.truetype(FONT_PATH, font_size)\n",
"\n",
" # Создаём пустое изображение с белым фоном\n",
" img = Image.new(\"L\", IMG_SIZE, 255)\n",
" draw = ImageDraw.Draw(img)\n",
"\n",
" # Получаем размеры текста с использованием Font.getbbox\n",
" text_bbox = draw.textbbox((0, 0), symbol, font=current_font)\n",
" text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]\n",
"\n",
" # Рассчитываем координаты для центрирования текста\n",
" x = (IMG_SIZE[0] - text_width) // 2 - text_bbox[0]\n",
" y = (IMG_SIZE[1] - text_height) // 2 - text_bbox[1]\n",
"\n",
" # Рисуем текст\n",
" draw.text((x, y), symbol, font=current_font, fill=0)\n",
"\n",
" # Для первого прогона символа с данным размером шрифта сохраняем без аугментаций\n",
" if i == 0:\n",
" symbol_dir = os.path.join(DATASET_DIR, symbol)\n",
" os.makedirs(symbol_dir, exist_ok=True)\n",
" file_path = os.path.join(symbol_dir, f\"{symbol}_{font_size}_{i}.png\")\n",
" img.save(file_path)\n",
"\n",
" config_lines.append(f\"{file_path},{symbol}\\n\")\n",
" continue\n",
"\n",
" # Случайное смещение символа\n",
" if random.random() < 0.7: \n",
" img = shift_image(img, max_shift=4)\n",
"\n",
" # Случайный поворот изображения\n",
" rotation = random.uniform(-max_rotation, max_rotation)\n",
" img = img.rotate(rotation, expand=False, fillcolor=255)\n",
"\n",
" # Случайное расширение символов\n",
" if random.random() < 0.5:\n",
" kernel_size = (random.randint(1, 2), random.randint(1, 2))\n",
" iterations = 1\n",
" img = expand_characters(img, kernel_size=kernel_size, iterations=iterations)\n",
"\n",
" # Случайное добавление разрывов\n",
" if random.random() < 0.15:\n",
" num_gaps = random.randint(1, 5)\n",
" gap_size = random.randint(2, 4)\n",
" img = add_random_gaps(img, num_gaps=num_gaps, gap_size=gap_size)\n",
"\n",
" # Инверсия цветов с вероятностью 25%\n",
" if random.random() < 0.15:\n",
" img = ImageOps.invert(img)\n",
"\n",
" # Определяем лимиты блюра в зависимости от размера шрифта\n",
" if font_size in SMALL_FONT_SIZES:\n",
" blur_limits = {\"box\": 2, \"gaussian\": 1, \"motion\": 6} # Уменьшенные лимиты\n",
" else:\n",
" blur_limits = {\"box\": 4, \"gaussian\": 3, \"motion\": 8} # Оригинальные лимиты\n",
"\n",
" # Случайное добавление блюра\n",
" if random.random() < 0.85:\n",
" img = apply_blur(img, blur_limits=blur_limits)\n",
"\n",
" # Случайное добавление шума\n",
" if random.random() < 0.1:\n",
" noise_intensity = random.randint(10, 20)\n",
" img = add_noise(img, intensity=noise_intensity)\n",
"\n",
" # Сохраняем изображение\n",
" symbol_dir = os.path.join(DATASET_DIR, symbol)\n",
" os.makedirs(symbol_dir, exist_ok=True)\n",
" file_path = os.path.join(symbol_dir, f\"{symbol}_{font_size}_{i}.png\")\n",
" img.save(file_path)\n",
"\n",
" config_lines.append(f\"{file_path},{symbol}\\n\")\n",
"\n",
"def main():\n",
" config_lines = []\n",
"\n",
" for symbol in SYMBOLS:\n",
" create_augmentations(symbol, FONT_SIZES, MAX_ROTATION, AUGMENTATIONS, config_lines)\n",
"\n",
" config_path = os.path.join(DATASET_DIR, CONFIG_FILE)\n",
" with open(config_path, \"w\") as config:\n",
" config.writelines(config_lines)\n",
"\n",
" print(f\"Датасет сохранен в папке '{DATASET_DIR}'.\")\n",
" print(f\"Конфигурационный файл создан: {CONFIG_FILE}\")\n",
"\n",
"if __name__ == \"__main__\":\n",
" main()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}