"""CRNN (convolutional recurrent network) model definition."""
import torch.nn as nn


# ------------------------------------------------------------
# CRNN model
# ------------------------------------------------------------
class CRNN(nn.Module):
    """CNN + bidirectional-LSTM model that scores each image column.

    The convolutional stack halves the height four times (two 2x2 pools and
    two (2, 1) pools, so ``H -> H // 16``) and halves the width twice
    (``W -> W // 4``). Each remaining column is flattened into one time step
    and passed through a linear bottleneck, a bidirectional LSTM and a final
    linear projection to ``num_classes`` scores.

    Args:
        num_classes: Number of scores produced per time step.
        img_height: Height of the input images. The default of 32 reproduces
            the original hard-coded ``512 * 2`` flattened feature size. Must
            be at least 16 so the CNN leaves at least one row.
        in_channels: Number of channels of the input images (default 1).

    Raises:
        ValueError: If ``img_height`` is smaller than 16.
    """

    def __init__(self, num_classes, *, img_height=32, in_channels=1):
        super().__init__()

        # The four pooling layers below reduce height by 2 * 2 * 2 * 2 = 16
        # (each pool floors, which composes to img_height // 16).
        feature_height = img_height // 16
        if feature_height < 1:
            raise ValueError(
                f"img_height must be at least 16, got {img_height}"
            )

        # CNN part. Layer order (and therefore state_dict keys) is unchanged.
        # NOTE(review): ReLU precedes BatchNorm here, unlike the more common
        # Conv -> BN -> ReLU ordering; kept as-is so existing checkpoints
        # keep their meaning.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # H / 2, W / 2

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # H / 4, W / 4

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1), (2, 1)),  # H / 8, width unchanged

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1), (2, 1)),  # H / 16, width unchanged
        )

        # RNN part: each column's 512 * feature_height values -> 256 features.
        self.linear1 = nn.Linear(512 * feature_height, 256)
        self.relu = nn.ReLU(inplace=True)
        # Bidirectional, so each time step yields 2 * 256 = 512 features.
        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
        self.linear2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Score every image column.

        Args:
            x: Image batch of shape ``(batch, in_channels, img_height, W)``;
                with the defaults, e.g. ``(batch, 1, 32, 128)`` after resizing.

        Returns:
            Tensor of shape ``(W // 4, batch, num_classes)`` holding the raw
            outputs of the last linear layer (no softmax is applied).
        """
        conv = self.cnn(x)  # (batch, 512, img_height // 16, W // 4)

        # One time step per column: (batch, W // 4, 512 * (img_height // 16)).
        # flatten() keeps the same (channel, row) element order as the former
        # view() but copies if needed instead of depending on strides.
        seq = conv.permute(0, 3, 1, 2).flatten(2)

        out = self.relu(self.linear1(seq))  # (batch, T, 256)
        out, _ = self.lstm(out)  # (batch, T, 512), both directions concatenated
        out = self.linear2(out)  # (batch, T, num_classes)

        # Time-major (T, batch, num_classes). Presumably intended for
        # nn.CTCLoss, which expects this layout (and log-probabilities, so the
        # caller would apply log_softmax) -- confirm against the training code.
        return out.permute(1, 0, 2)