Skip to content

Commit

Permalink
Check hash
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthijsBurgh committed Jun 28, 2024
1 parent 94a2f2f commit 76ba85d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions facenet_pytorch/models/inception_resnet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def load_weights(mdl, name):
def load_weights(mdl: torch.nn.Module, name: str) -> None:
"""Download pretrained state_dict and load into model.
Arguments:
Expand All @@ -304,8 +304,10 @@ def load_weights(mdl, name):
"""
if name == "vggface2":
path = "https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt"
hash_prefix = "281cebca8662831adb987a874bdcb36e73f5b1c6dc5ee5878f305e985625d99b"
elif name == "casia-webface":
path = "https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt"
hash_prefix = "7a67afdbbc995fce5e10128675e318799a70698c2f433ba75dd7eb9a2f096e7d"
else:
msg = "Pretrained models only exist for 'vggface2' and 'casia-webface'"
raise ValueError(msg)

Check warning on line 313 in facenet_pytorch/models/inception_resnet_v1.py

View check run for this annotation

Codecov / codecov/patch

facenet_pytorch/models/inception_resnet_v1.py#L312-L313

Added lines #L312 - L313 were not covered by tests
Expand All @@ -315,7 +317,7 @@ def load_weights(mdl, name):

cached_file = model_dir / Path(path).name
if not cached_file.exists():
download_url_to_file(path, cached_file)
download_url_to_file(path, cached_file, hash_prefix)

state_dict = torch.load(cached_file)
mdl.load_state_dict(state_dict)
Expand Down

0 comments on commit 76ba85d

Please sign in to comment.