Skip to content

Commit

Permalink
outsource PieModelHFHubMixin.load_weights
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 3, 2024
1 parent b62b931 commit 8e71ee5
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/pytorch_ie/core/hf_hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,13 @@ def _save_pretrained(self, save_directory: Path) -> None:
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)

def load_weights(
self, model_file: str, map_location: str = "cpu", strict: bool = False
) -> None:
state_dict = torch.load(model_file, map_location=torch.device(map_location))
model.load_state_dict(state_dict, strict=strict) # type: ignore
model.eval() # type: ignore

@classmethod
def _from_pretrained(
cls: Type[T],
Expand Down Expand Up @@ -428,9 +435,7 @@ def _from_pretrained(
config.pop(cls.config_type_key)
model = cls(**config)

state_dict = torch.load(model_file, map_location=torch.device(map_location))
model.load_state_dict(state_dict, strict=strict) # type: ignore
model.eval() # type: ignore
model.load_weights(model_file, map_location=map_location, strict=strict)

return model

Expand Down

0 comments on commit 8e71ee5

Please sign in to comment.