number-plate/model.py

57 lines
2.0 KiB
Python
Raw Normal View History

import torch.nn as nn
# ------------------------------------------------------------
# CRNN model: convolutional feature extractor + bidirectional LSTM
# ------------------------------------------------------------
class CRNN(nn.Module):
    """CNN feature extractor followed by a bidirectional LSTM that emits
    class scores for every column of the input image.

    The convolutional stack uses two 2x2 and two (2, 1) max-pools with
    size-preserving 3x3 convolutions. The feature-map height is therefore
    ``img_height // 16`` and the sequence length is ``width // 4``.

    Args:
        num_classes: Number of scores produced per time step.
        img_height: Height in pixels of the images passed to ``forward``.
            Defaults to 32, the height the model was originally written for.
            ``linear1`` is sized from it, so inputs must have exactly this
            height.

    Raises:
        ValueError: If ``img_height`` is smaller than 16, which would leave a
            zero-height feature map after pooling.
    """

    def __init__(self, num_classes, img_height=32):
        super().__init__()
        # Each of the four max-pools halves the height (with floor), so the
        # final feature-map height is img_height // 16.
        feat_height = img_height // 16
        if feat_height < 1:
            raise ValueError(f"img_height must be at least 16, got {img_height}")

        # CNN part.
        # NOTE: nn.Sequential names its children by position (cnn.0, cnn.1, ...),
        # so this exact layer order determines the state_dict keys. Keep it
        # stable, or previously saved checkpoints will no longer load.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1), (2, 1)),  # halve height only, keep width
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1), (2, 1)),  # halve height only, keep width
        )

        # Sequence (RNN) part.
        # Each time step is one image column: 512 channels x feat_height rows.
        self.linear1 = nn.Linear(512 * feat_height, 256)
        self.relu = nn.ReLU(inplace=True)
        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
        # 512 = 2 directions x 256 hidden units.
        self.linear2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Compute per-time-step class scores for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 1, img_height, width)``.

        Returns:
            Time-major tensor of shape ``(width // 4, batch, num_classes)``.
            No softmax or log-softmax is applied here.
        """
        conv = self.cnn(x)  # (batch, 512, img_height // 16, width // 4)
        # Make the width axis the sequence axis: (batch, T, 512, H').
        conv = conv.permute(0, 3, 1, 2)
        # Merge channels and height into one feature vector per column:
        # (batch, T, 512 * H'). flatten() copies when needed, unlike view(),
        # which raises if the strides left by permute() are incompatible.
        conv = conv.flatten(2)
        out = self.relu(self.linear1(conv))  # (batch, T, 256)
        out, _ = self.lstm(out)  # (batch, T, 512): both LSTM directions
        out = self.linear2(out)  # (batch, T, num_classes)
        return out.permute(1, 0, 2)  # (T, batch, num_classes)