-
Notifications
You must be signed in to change notification settings - Fork 1
/
models.py
43 lines (32 loc) · 1.19 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch.nn as nn
import torchvision.models as models
class FlowerNet(nn.Module):
    """ResNet-18 image classifier with a fresh ``num_classes``-way head.

    The convolutional stem and the four residual stages are taken from
    torchvision's ``resnet18``. Its original ``fc`` layer is discarded and
    replaced by a new ``nn.Linear`` producing ``num_classes`` outputs.
    Submodule attribute names mirror torchvision's ResNet (``conv1``,
    ``bn1``, ``layer1`` ... ``layer4``, ``fc``), so ``state_dict`` keys stay
    stable across this refactor.

    Args:
        num_classes: Number of logits produced by ``forward``.
        pretrained: If True, initialise the backbone from
            ``models.ResNet18_Weights.DEFAULT`` (torchvision may need to
            download these on first use); otherwise the backbone is built
            without pretrained weights.
    """

    def __init__(self, num_classes: int = 10, pretrained: bool = False) -> None:
        super().__init__()
        # resnet18(weights=None) is the same as resnet18() with no arguments.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        resnet = models.resnet18(weights=weights)

        # Reuse the backbone layers directly; names match torchvision's ResNet.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        # Size the new head from the backbone's own classifier instead of a
        # hard-coded 512, so it cannot drift out of sync with the feature width.
        self.fc = nn.Linear(
            in_features=resnet.fc.in_features,
            out_features=num_classes,
        )

    def forward(self, inputs):
        """Return class logits for a batch of images.

        Args:
            inputs: Image batch fed to the ResNet stem, expected to be a
                batched ``(N, C, H, W)`` tensor (torchvision's ResNet-18
                ``conv1`` takes ``C == 3``).

        Returns:
            Tensor of shape ``(N, num_classes)`` holding raw logits
            (no softmax is applied).
        """
        outputs = self.conv1(inputs)
        outputs = self.bn1(outputs)
        outputs = self.relu(outputs)
        outputs = self.maxpool(outputs)
        outputs = self.layer1(outputs)
        outputs = self.layer2(outputs)
        outputs = self.layer3(outputs)
        outputs = self.layer4(outputs)
        # Global average pool to (N, C, 1, 1), then flatten to (N, C).
        outputs = self.avgpool(outputs)
        outputs = self.flatten(outputs)
        return self.fc(outputs)