forked from ai4luc/CerraData-code-data
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
88 lines (64 loc) · 3.12 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import math
from torch import nn
from torch.nn import functional as F
import torchvision
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``;
    a 1-D input of shape ``(N,)`` becomes ``(N, 1)``.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
        # result of permute/transpose. The trailing size is computed explicitly
        # because -1 cannot be inferred when the batch dimension is 0.
        return x.reshape(x.size(0), math.prod(x.shape[1:]))
class Model(nn.Module):
    """Adapt a classification backbone so it produces ``num_classes`` outputs.

    The backbone's final layer is replaced by a new layer with ``num_classes``
    outputs. The layer to replace is chosen by substring-matching the
    lower-cased class name of ``backbone``:

    * alexnet / mnasnet / mobilenet / vgg / convnext / efficient:
      ``backbone.classifier[-1]`` (replaced by ``nn.Linear``)
    * densenet: ``backbone.classifier`` (replaced by ``nn.Linear``)
    * googlenet / inception / resnet / shufflenet / resnext:
      ``backbone.fc`` (replaced by ``nn.Linear``)
    * squeezenet: ``backbone.classifier[1]`` (replaced by ``nn.Conv2d``)
    * cerranet: ``backbone.classifier[-1]`` (replaced by ``nn.Linear``)

    NOTE(review): if no name matches, the backbone is left unchanged and
    ``num_classes`` / ``get_features`` have no effect.

    Args:
        backbone: network to adapt; must be an ``nn.Module``.
        num_classes: number of outputs of the replacement layer.
        get_features: if True, register a forward hook that appends the input
            of the replacement layer (as a NumPy array) to ``self.features``
            on every forward pass.

    Raises:
        ValueError: if ``backbone`` is not an ``nn.Module``.
    """

    # NOTE(review): 'num_domains' is never assigned in this class.
    __constants__ = ['model_name', 'num_classes', 'num_domains']

    def __init__(self, backbone, num_classes, get_features=False):
        super(Model, self).__init__()
        if not isinstance(backbone, nn.Module):
            raise ValueError('A model must be provided.')
        self.num_classes = num_classes
        self.model_name = backbone.__class__.__name__.lower()
        if any(prefix in self.model_name for prefix in ['alexnet', 'mnasnet', 'mobilenet', 'vgg', 'convnext', 'efficient']):
            feature_dim = backbone.classifier[-1].in_features
            layer = nn.Linear(feature_dim, num_classes)
            if 'efficient' in self.model_name:
                # Uniform weights in +/- 1/sqrt(num_classes), zero bias.
                init_range = 1.0 / math.sqrt(layer.out_features)
                nn.init.uniform_(layer.weight, -init_range, init_range)
                nn.init.zeros_(layer.bias)
            if 'convnext' in self.model_name:
                # Truncated-normal weights (std 0.02), zero bias.
                nn.init.trunc_normal_(layer.weight, std=0.02)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            backbone.classifier[-1] = layer
            if get_features:
                self.add_features_hook(backbone.classifier[-1])
        elif 'densenet' in self.model_name:
            feature_dim = backbone.classifier.in_features
            backbone.classifier = nn.Linear(feature_dim, num_classes)
            if get_features:
                self.add_features_hook(backbone.classifier)
        elif any(prefix in self.model_name for prefix in ['googlenet', 'inception', 'resnet', 'shufflenet', 'resnext']):
            feature_dim = backbone.fc.in_features
            backbone.fc = nn.Linear(feature_dim, num_classes)
            if get_features:
                self.add_features_hook(backbone.fc)
        elif 'squeezenet' in self.model_name:
            # The head is a convolution here; keep its input channels and kernel size.
            in_channels = backbone.classifier[1].in_channels
            kernel_size = backbone.classifier[1].kernel_size
            backbone.classifier[1] = nn.Conv2d(in_channels, num_classes, kernel_size=kernel_size)
            if get_features:
                self.add_features_hook(backbone.classifier[1])
        elif 'cerranet' in self.model_name:
            feature_dim = backbone.classifier[-1].in_features
            backbone.classifier[-1] = nn.Linear(feature_dim, num_classes)
            if get_features:
                self.add_features_hook(backbone.classifier[-1])
        self.backbone = backbone

    def add_features_hook(self, layer):
        """Record the input of ``layer`` on every forward pass.

        Resets ``self.features`` to an empty list, then registers a forward
        hook on ``layer`` that appends the layer's first positional input,
        detached from the autograd graph and copied to CPU, as a NumPy array.
        """
        self.features = []

        def fn(module, inputs, output):
            # Forward hooks receive the positional inputs as a tuple; the
            # replaced head layers take a single tensor. detach() is required
            # because .numpy() refuses tensors that require grad.
            self.features.append(inputs[0].detach().cpu().numpy())

        layer.register_forward_hook(fn)

    def reset_features(self):
        """Discard recorded features by rebinding ``self.features`` to an empty list.

        Safe to call even when no feature hook was registered.
        """
        self.features = []

    def extra_repr(self):
        """Show the backbone's lower-cased class name and the number of classes."""
        s = ('backbone={model_name}, num_classes={num_classes}')
        return s.format(**self.__dict__)

    def forward(self, x):
        """Run ``x`` through the adapted backbone and return its output."""
        x = self.backbone(x)
        return x