# cosplace_network.py (forked from gmberton/CosPlace)
import torch
import logging
import torchvision
from torch import nn
from typing import Tuple

from cosplace_model.layers import Flatten, L2Norm, GeM

# The number of channels in the last convolutional layer, the one before average pooling
CHANNELS_NUM_IN_LAST_CONV = {
    "ResNet18": 512,
    "ResNet50": 2048,
    "ResNet101": 2048,
    "ResNet152": 2048,
    "VGG16": 512,
    "wide_resnet50_2": 2048,
    "densenet121": 1024,
}


class GeoLocalizationNet(nn.Module):
    def __init__(self, backbone: str, fc_output_dim: int):
        """Return a model for GeoLocalization.

        Args:
            backbone (str): which torchvision backbone to use. Must be one of the keys of
                CHANNELS_NUM_IN_LAST_CONV (a ResNet, VGG16, wide_resnet50_2 or densenet121).
            fc_output_dim (int): the output dimension of the last fc layer, equivalent to
                the descriptors dimension.
        """
        super().__init__()
        assert backbone in CHANNELS_NUM_IN_LAST_CONV, f"backbone must be one of {list(CHANNELS_NUM_IN_LAST_CONV.keys())}"
        self.backbone, features_dim = get_backbone(backbone)
        self.aggregation = nn.Sequential(
            L2Norm(),                                # normalize the feature maps
            GeM(),                                   # (B, C, H, W) -> (B, C, 1, 1)
            Flatten(),                               # (B, C, 1, 1) -> (B, C)
            nn.Linear(features_dim, fc_output_dim),  # (B, C) -> (B, fc_output_dim)
            L2Norm()                                 # final descriptors are L2-normalized
        )

    def forward(self, x):
        x = self.backbone(x)
        x = self.aggregation(x)
        return x


def get_pretrained_torchvision_model(backbone_name: str) -> torch.nn.Module:
    """This function takes the name of a backbone and returns the corresponding pretrained
    model from torchvision. Examples of backbone_name are 'VGG16' or 'ResNet18'.
    """
    try:  # Newer versions of pytorch require to pass weights=weights_module.DEFAULT
        weights_module = getattr(__import__('torchvision.models', fromlist=[f"{backbone_name}_Weights"]),
                                 f"{backbone_name}_Weights")
        model = getattr(torchvision.models, backbone_name.lower())(weights=weights_module.DEFAULT)
    except (ImportError, AttributeError):  # Older versions of pytorch require to pass pretrained=True
        model = getattr(torchvision.models, backbone_name.lower())(pretrained=True)
    return model


def get_backbone(backbone_name: str) -> Tuple[torch.nn.Module, int]:
    backbone = get_pretrained_torchvision_model(backbone_name)
    if backbone_name.startswith("ResNet"):
        for name, child in backbone.named_children():
            if name == "layer3":  # Freeze layers before conv_3
                break
            for params in child.parameters():
                params.requires_grad = False
        logging.debug(f"Train only layer3 and layer4 of the {backbone_name}, freeze the previous ones")
        layers = list(backbone.children())[:-2]  # Remove avg pooling and FC layer

    elif backbone_name == "VGG16":
        layers = list(backbone.features.children())[:-2]  # Remove the last ReLU and MaxPool of the features block
        for layer in layers[:-5]:
            for p in layer.parameters():
                p.requires_grad = False
        logging.debug("Train last layers of the VGG-16, freeze the previous ones")

    elif backbone_name == "wide_resnet50_2":
        layers = list(backbone.children())[:-2]  # Remove avg pooling and FC layer
        for layer in layers[:-2]:  # Freeze layers before layer3
            for p in layer.parameters():
                p.requires_grad = False
        logging.debug("Train only layer3 and layer4 of the wide_resnet50_2, freeze the previous ones")

    elif backbone_name == "densenet121":
        # Get the features block
        features = list(backbone.children())[0]
        # Freeze everything except denseblock3, transition3 and denseblock4
        for name, module in features.named_children():
            if name != "denseblock3" and name != "transition3" and name != "denseblock4":
                for p in module.parameters():
                    p.requires_grad = False
        logging.debug("Train only denseblock3, transition3, and denseblock4 of the densenet121, freeze the rest")
        # Define layers
        layers = list(features.children())

    backbone = torch.nn.Sequential(*layers)
    features_dim = CHANNELS_NUM_IN_LAST_CONV[backbone_name]
    return backbone, features_dim
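

# Minimal usage sketch (not part of the original CosPlace pipeline; the ResNet18 backbone,
# 512-dim descriptors and 224x224 inputs below are example assumptions). It shows that the
# model maps a batch of images to fixed-size, L2-normalized descriptors.
if __name__ == "__main__":
    model = GeoLocalizationNet(backbone="ResNet18", fc_output_dim=512)
    model.eval()
    with torch.no_grad():
        dummy_images = torch.rand(2, 3, 224, 224)  # batch of 2 RGB images
        descriptors = model(dummy_images)
    print(descriptors.shape)        # torch.Size([2, 512])
    print(descriptors.norm(dim=1))  # ~1.0 for each descriptor, due to the final L2Norm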