-
Notifications
You must be signed in to change notification settings - Fork 15
/
model.py
95 lines (79 loc) · 2.66 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch.nn as nn
from layers.scale_transfer_module import ScaleTransferModule
from layers.multibox import MultiBox
from torchvision.models import densenet169
from layers.detection import Detect
# Different configurations of STDN, keyed by config name (presumably the
# input resolution: '300' or '513').
#
# stdn_in[cfg]: one entry per output scale; each value is the number of
#     leading entries along dim 1 (channels, assuming NCHW) sliced from the
#     final DenseNet feature map before it is handed to the scale-transfer
#     module (see STDN.forward).
# stdn_out[cfg]: one (map_size, channels) pair per output scale. map_size is
#     what STDN.get_out_map_sizes() returns; the whole list is passed to
#     MultiBox as ``num_channels``. The upsampled scales satisfy
#     channels == stdn_in / r**2 for upscale factor r (e.g. 1440 / 2**2 == 360,
#     1664 / 4**2 == 104), which is consistent with a depth-to-space style
#     upsampling -- NOTE(review): confirm against ScaleTransferModule.
stdn_in = {
    '300': [800, 960, 1120, 1280, 1440, 1664],
    '513': [800, 960, 1120, 1280, 1440, 1600, 1664],
}
stdn_out = {
    '300': [(1, 800), (3, 960), (5, 1120), (9, 1280), (18, 360), (36, 104)],
    '513': [(1, 800), (2, 960), (4, 1120), (8, 1280), (16, 1440), (32, 400),
            (64, 104)],
}
class STDN(nn.Module):
    """STDN architecture.

    A DenseNet-169 backbone (``densenet169(pretrained=True)``) produces a
    feature map. Channel slices of it (sizes from ``stdn_in[stdn_config]``)
    go through a ``ScaleTransferModule``. The result feeds a ``MultiBox``
    head. In ``'test'`` mode the head outputs are post-processed with a
    softmax and ``Detect``.

    Args:
        mode: ``'test'`` enables softmax + ``Detect`` post-processing in
            ``forward``; any other value returns the raw head outputs.
        stdn_config: key into the module-level ``stdn_in`` / ``stdn_out``
            tables (``'300'`` or ``'513'``); an unknown key raises KeyError.
        channels: stored as ``self.channels``; not read elsewhere in this class.
        class_count: forwarded to ``MultiBox`` and (test mode) ``Detect``.
        anchors: passed to ``Detect`` in test mode.
        num_anchors: forwarded to ``MultiBox``.
        new_size: forwarded to ``ScaleTransferModule``.
    """

    def __init__(self,
                 mode,
                 stdn_config,
                 channels,
                 class_count,
                 anchors,
                 num_anchors,
                 new_size):
        super(STDN, self).__init__()
        self.mode = mode
        self.stdn_in = stdn_in[stdn_config]
        self.stdn_out = stdn_out[stdn_config]
        self.channels = channels
        self.class_count = class_count
        self.anchors = anchors
        self.num_anchors = num_anchors
        self.new_size = new_size
        # NOTE: init_weights() is intentionally not invoked here. Called with
        # no argument it walks self.modules(), which would include (and
        # overwrite) the pretrained DenseNet weights loaded below.
        self.densenet = densenet169(pretrained=True)
        self.scale_transfer_module = ScaleTransferModule(self.new_size)
        self.multibox = MultiBox(num_channels=self.stdn_out,
                                 num_anchors=self.num_anchors,
                                 class_count=self.class_count)
        if mode == 'test':
            # Softmax over the last axis; assumes class scores are last in
            # MultiBox's class_preds -- TODO confirm.
            self.softmax = nn.Softmax(dim=-1)
            # NOTE(review): positional args presumably (num_classes, top_k,
            # conf_threshold, nms_threshold) -- confirm against
            # layers.detection.Detect.
            self.detect = Detect(class_count, 200, 0.01, 0.45)

    def get_out_map_sizes(self):
        """Return the first element of each ``stdn_out`` pair (output map sizes)."""
        return [size for size, _ in self.stdn_out]

    def init_weights(self, modules=None):
        """Initialize Conv2d and BatchNorm2d layers in place.

        Conv2d weights get ``nn.init.kaiming_normal_``. BatchNorm2d weight and
        bias are set to 1 and 0. Other module types are left untouched, and
        Conv2d biases are not modified.

        Args:
            modules: iterable of modules to initialize. Defaults to
                ``self.modules()``, which covers every submodule including the
                pretrained DenseNet backbone, so the default call discards
                the pretrained weights.
        """
        # Previously the ``modules`` argument was required but ignored.
        if modules is None:
            modules = self.modules()
        for module in modules:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None.
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run the network on input ``x``.

        Args:
            x: input batch for ``densenet.features``; presumably (N, 3, H, W)
                images sized for the chosen config -- TODO confirm.

        Returns:
            In ``'test'`` mode, whatever ``Detect`` returns for the softmaxed
            class predictions, location predictions and anchors. Otherwise the
            tuple ``(class_preds, loc_preds)`` from the MultiBox head.
        """
        features = self.densenet.features(x)
        # One input per output scale: the first ``stop`` entries of dim 1
        # (channels, assuming NCHW) of the same backbone feature map.
        scale_inputs = [features[:, :stop, :, :] for stop in self.stdn_in]
        y = self.scale_transfer_module(scale_inputs)
        class_preds, loc_preds = self.multibox(y)
        if self.mode == 'test':
            return self.detect(
                self.softmax(class_preds),
                loc_preds,
                self.anchors
            )
        return class_preds, loc_preds