Skip to content

Commit

Permalink
implement PieModelHFHubMixin.save_model_file() for completion
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 3, 2024
1 parent 4ae3722 commit 86eddae
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/pytorch_ie/core/hf_hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,10 @@ class PieModelHFHubMixin(PieBaseHFHubMixin):
```
"""

def _save_pretrained(self, save_directory: Path) -> None:
def save_model_file(self, model_file: str) -> None:
"""Save weights from a Pytorch model to a local directory."""
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
torch.save(model_to_save.state_dict(), save_directory / self.weights_file_name)
torch.save(model_to_save.state_dict(), model_file)

def load_model_file(
self, model_file: str, map_location: str = "cpu", strict: bool = False
Expand All @@ -399,6 +399,10 @@ def load_model_file(
self.load_state_dict(state_dict, strict=strict) # type: ignore
self.eval() # type: ignore

def _save_pretrained(self, save_directory: Path) -> None:
"""Save weights from a Pytorch model to a local directory."""
self.save_model_file(str(save_directory / self.weights_file_name))

@classmethod
def _from_pretrained(
cls: Type[TModel],
Expand Down

0 comments on commit 86eddae

Please sign in to comment.