import torch.nn as nn


# ------------------------------------------------------------
# CRNN model
# ------------------------------------------------------------
class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-column class scores.

    A CNN turns a single-channel image into a feature map whose height is
    reduced 16x and whose width is reduced 4x.  Each remaining column is
    flattened into one feature vector, projected to 256 dimensions, run
    through a bidirectional LSTM and finally mapped to ``num_classes``
    scores.

    Args:
        num_classes: Size of the last dimension of the output.
        img_height: Height of the input images.  The default of 32
            reproduces the original fixed architecture
            (``linear1.in_features == 512 * 2``).

    Raises:
        ValueError: If ``img_height`` is smaller than 16, because the four
            height-halving pooling layers would leave no feature rows.
    """

    # Number of MaxPool layers that halve the height (2, 2, (2,1), (2,1)).
    _HEIGHT_REDUCTION = 16

    def __init__(self, num_classes: int, img_height: int = 32) -> None:
        super().__init__()

        feature_height = img_height // self._HEIGHT_REDUCTION
        if feature_height < 1:
            raise ValueError(
                f"img_height must be at least {self._HEIGHT_REDUCTION}, "
                f"got {img_height}"
            )

        # CNN part.  Layer order is kept exactly as-is so that the indices
        # inside the Sequential (and therefore state_dict keys) do not change.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Pool only along the height so the width (sequence length)
            # is preserved from here on.
            nn.MaxPool2d((2, 1), (2, 1)),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1), (2, 1)),
        )

        # RNN part.
        self.linear1 = nn.Linear(512 * feature_height, 256)
        self.relu = nn.ReLU(inplace=True)
        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
        # The bidirectional LSTM concatenates both directions: 2 * 256 = 512.
        self.linear2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Compute class scores for every column of the feature map.

        Args:
            x: Tensor of shape ``(batch, 1, img_height, width)``; with the
                defaults and a 32x128 input the shapes noted below apply.

        Returns:
            Tensor of shape ``(width // 4, batch, num_classes)`` holding raw
            scores (no softmax is applied).  The sequence-first layout is
            the one ``nn.CTCLoss`` takes for its input.
        """
        conv = self.cnn(x)                 # (batch, 512, 2, 32)
        conv = conv.permute(0, 3, 1, 2)    # (batch, width=32, channels=512, height=2)
        # flatten() falls back to a copy when the permuted tensor cannot be
        # viewed in place (e.g. channels_last memory format), where .view()
        # would raise a RuntimeError.
        conv = conv.flatten(start_dim=2)   # (batch, 32, 512*2)

        out = self.linear1(conv)           # (batch, 32, 256)
        out = self.relu(out)               # (batch, 32, 256)
        out, _ = self.lstm(out)            # (batch, 32, 512) - bidirectional
        out = self.linear2(out)            # (batch, 32, num_classes)
        out = out.permute(1, 0, 2)         # (32, batch, num_classes)
        return out