diff --git a/models.py b/models.py index 807fd1a..9a45a2d 100644 --- a/models.py +++ b/models.py @@ -15,6 +15,7 @@ def __init__(self, pretrained=False): self.model = resnet50(pretrained) self.embedding_size = 128 + num_classes = 500 self.cnn = nn.Sequential( self.model.conv1, self.model.bn1, @@ -33,7 +34,7 @@ def __init__(self, pretrained=False): # nn.ReLU(), nn.Linear(100352, self.embedding_size)) - self.model.classifier = nn.Linear(self.embedding_size, num_classes) + self.model.classifier = nn.Linear(self.embedding_size, num_classes) def l2_norm(self, input): input_size = input.size()