Skip to content

Commit

Permalink
Avoid DeprecationWarning when loading torchvision pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
gmberton committed Jan 3, 2023
1 parent db401f1 commit eab6e11
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ By default training is performed with a ResNet-18 with descriptors dimensionalit

To change the backbone or the output descriptors dimensionality simply run

`$ python3 train.py --dataset_folder path/to/sf-xl/processed --backbone resnet50 --fc_output_dim 128`
`$ python3 train.py --dataset_folder path/to/sf-xl/processed --backbone ResNet50 --fc_output_dim 128`

You can also speed up your training with Automatic Mixed Precision (note that all results/statistics from the paper did not use AMP)

Expand All @@ -62,7 +62,7 @@ If you are a researcher comparing your work against ours, please make sure to fo
## Test
You can test a trained model as such

`$ python3 eval.py --dataset_folder path/to/sf-xl/processed --backbone resnet50 --fc_output_dim 128 --resume_model path/to/best_model.pth`
`$ python3 eval.py --dataset_folder path/to/sf-xl/processed --backbone ResNet50 --fc_output_dim 128 --resume_model path/to/best_model.pth`

You can download plenty of trained models below.

Expand Down
29 changes: 16 additions & 13 deletions model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,21 @@ def forward(self, x):
return x


def get_backbone(backbone_name):
if backbone_name.startswith("resnet"):
if backbone_name == "resnet18":
backbone = torchvision.models.resnet18(pretrained=True)
elif backbone_name == "resnet50":
backbone = torchvision.models.resnet50(pretrained=True)
elif backbone_name == "resnet101":
backbone = torchvision.models.resnet101(pretrained=True)
elif backbone_name == "resnet152":
backbone = torchvision.models.resnet152(pretrained=True)

def get_pretrained_torchvision_model(backbone_name : str) -> torch.nn.Module:
"""This function takes the name of a backbone and returns the corresponding pretrained
model from torchvision. Examples of backbone_name are 'VGG16' or 'ResNet18'
"""
try: # Newer versions of pytorch require to pass weights=weights_module.DEFAULT
weights_module = getattr(__import__('torchvision.models', fromlist=[f"{backbone_name}_Weights"]), f"{backbone_name}_Weights")
model = getattr(torchvision.models, backbone_name.lower())(weights=weights_module.DEFAULT)
except (ImportError, AttributeError): # Older versions of pytorch require to pass pretrained=True
model = getattr(torchvision.models, backbone_name.lower())(pretrained=True)
return model


def get_backbone(backbone_name : str) -> Tuple[torch.nn.Module, int]:
backbone = get_pretrained_torchvision_model(backbone_name)
if backbone_name.startswith("ResNet"):
for name, child in backbone.named_children():
if name == "layer3": # Freeze layers before conv_3
break
Expand All @@ -61,8 +65,7 @@ def get_backbone(backbone_name):
logging.debug(f"Train only layer3 and layer4 of the {backbone_name}, freeze the previous ones")
layers = list(backbone.children())[:-2] # Remove avg pooling and FC layer

elif backbone_name == "vgg16":
backbone = torchvision.models.vgg16(pretrained=True)
elif backbone_name == "VGG16":
layers = list(backbone.features.children())[:-2] # Remove avg pooling and FC layer
for layer in layers[:-5]:
for p in layer.parameters():
Expand Down
4 changes: 2 additions & 2 deletions parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def parse_arguments(is_training: bool = True):
parser.add_argument("--groups_num", type=int, default=8, help="_")
parser.add_argument("--min_images_per_class", type=int, default=10, help="_")
# Model parameters
parser.add_argument("--backbone", type=str, default="resnet18",
choices=["vgg16", "resnet18", "resnet50", "resnet101", "resnet152"], help="_")
parser.add_argument("--backbone", type=str, default="ResNet18",
choices=["VGG16", "ResNet18", "ResNet50", "ResNet101", "ResNet152"], help="_")
parser.add_argument("--fc_output_dim", type=int, default=512,
help="Output dimension of final fully connected layer")
# Training parameters
Expand Down

0 comments on commit eab6e11

Please sign in to comment.