-
Notifications
You must be signed in to change notification settings - Fork 10
/
model.py
executable file
·131 lines (118 loc) · 7.28 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023
Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora
GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
Project Website: https://abdur75648.github.io/UTRNet/
Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from modules.feature_extraction import HRNet_FeatureExtractor
from modules.sequence_modeling import BidirectionalLSTM
from modules.dropout_layer import dropout_layer
from modules.prediction import Attention
import torch.nn as nn
# Other CNN Architectures
from modules.feature_extraction import DenseNet_FeatureExtractor, InceptionUNet_FeatureExtractor
from modules.feature_extraction import RCNN_FeatureExtractor, ResNet_FeatureExtractor
from modules.feature_extraction import ResUnet_FeatureExtractor, AttnUNet_FeatureExtractor
from modules.feature_extraction import UNet_FeatureExtractor, UNetPlusPlus_FeatureExtractor
from modules.feature_extraction import VGG_FeatureExtractor
# Other sequential models
from modules.sequence_modeling import LSTM, GRU, MDLSTM
class Model(nn.Module):
    """UTRNet text-recognition model.

    Pipeline: CNN feature extraction -> adaptive average pool (collapse the
    height axis) -> temporal dropout -> recurrent sequence modeling ->
    prediction head (CTC linear layer or attention decoder). The concrete
    module used at each stage is selected by string options on ``opt``.
    """

    def __init__(self, opt):
        """Build the three-stage pipeline described by ``opt``.

        Args:
            opt: option namespace; reads FeatureExtraction, SequenceModeling,
                Prediction, input_channel, output_channel, hidden_size,
                num_class, device (and batch_max_length when Prediction=='Attn').

        Raises:
            ValueError: if any stage option names an unknown module.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {'Feat': opt.FeatureExtraction,
                       'Seq': opt.SequenceModeling,
                       'Pred': opt.Prediction}

        """ FeatureExtraction """
        # All extractors share the (input_channel, output_channel) signature,
        # so a dispatch table replaces the long if/elif chain.
        feature_extractors = {
            'HRNet': HRNet_FeatureExtractor,
            'Densenet': DenseNet_FeatureExtractor,
            'InceptionUnet': InceptionUNet_FeatureExtractor,
            'RCNN': RCNN_FeatureExtractor,
            'ResNet': ResNet_FeatureExtractor,
            'ResUnet': ResUnet_FeatureExtractor,
            'AttnUNet': AttnUNet_FeatureExtractor,
            'UNet': UNet_FeatureExtractor,
            'UnetPlusPlus': UNetPlusPlus_FeatureExtractor,
            'VGG': VGG_FeatureExtractor,
        }
        try:
            extractor_cls = feature_extractors[opt.FeatureExtraction]
        except KeyError:
            raise ValueError('No FeatureExtraction module specified') from None
        self.FeatureExtraction = extractor_cls(opt.input_channel, opt.output_channel)
        self.FeatureExtraction_output = opt.output_channel
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

        """
        Temporal Dropout
        """
        # Five independent dropout modules: training uses only the first;
        # inference averages one sequence-modeling pass per module (see forward()).
        self.dropout1 = dropout_layer(opt.device)
        self.dropout2 = dropout_layer(opt.device)
        self.dropout3 = dropout_layer(opt.device)
        self.dropout4 = dropout_layer(opt.device)
        self.dropout5 = dropout_layer(opt.device)

        """ Sequence modeling"""
        if opt.SequenceModeling == 'LSTM':
            self.SequenceModeling = LSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size)
        elif opt.SequenceModeling == 'GRU':
            self.SequenceModeling = GRU(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size)
        elif opt.SequenceModeling == 'MDLSTM':
            self.SequenceModeling = MDLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size)
        elif opt.SequenceModeling == 'BiLSTM':
            self.SequenceModeling = BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size)
        elif opt.SequenceModeling == 'DBiLSTM':
            # Two stacked bidirectional LSTMs.
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        else:
            raise ValueError('No Sequence Modeling module specified')
        self.SequenceModeling_output = opt.hidden_size

        """ Prediction """
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class, opt.device)
        else:
            raise ValueError('Prediction is neither CTC nor Attn')

    def forward(self, input, text=None, is_train=True):
        """Run recognition on a batch of images.

        Args:
            input: image batch fed to the feature extractor.
            text: target label indices; required when the prediction head is
                the attention decoder, unused for CTC.
            is_train: teacher-forcing flag forwarded to the Attn decoder.

        Returns:
            Per-timestep class scores from the prediction head.

        Raises:
            ValueError: if ``text`` is None but the Attn head needs it.
        """
        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        # [b, c, h, w] -> [b, w, c, h], then pool the trailing (height) axis to 1.
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
        visual_feature = visual_feature.squeeze(3)  # -> [b, w, c]

        """ Temporal Dropout + Sequence modeling stage """
        if self.training:
            # Single stochastic pass during training.
            contextual_feature = self.SequenceModeling(self.dropout1(visual_feature))
        else:
            # At inference, average five passes with independent temporal
            # dropout masks (Monte-Carlo-style ensembling).
            dropouts = (self.dropout1, self.dropout2, self.dropout3,
                        self.dropout4, self.dropout5)
            contextual_feature = sum(
                self.SequenceModeling(drop(visual_feature)) for drop in dropouts
            ) * (1 / len(dropouts))

        """ Prediction stage """
        if self.stages['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())
        else:
            if text is None:
                raise ValueError('Input text (for prediction) to model is None')
            text = text.to(self.opt.device)
            prediction = self.Prediction(contextual_feature, text, is_train,
                                         batch_max_length=self.opt.batch_max_length)
        return prediction