outsource PieModelHFHubMixin.load_model_file (#433)
* outsource PieModelHFHubMixin.load_weights

* fix typing

* parametrize weights_file_name

* use weights_file_name from loaded model

* adjust AutoModel._from_pretrained

* fix PieModelHFHubMixin.load_weights

* rename PieModelHFHubMixin.load_weights() to load_model_file()

* implement PieModelHFHubMixin.save_model_file() for completion
ArneBinder authored Nov 4, 2024
1 parent b62b931 · commit 00cd802
Showing 2 changed files with 43 additions and 29 deletions.
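
In short, the commit funnels all weight I/O through two small hooks on PieModelHFHubMixin, save_model_file() and load_model_file(), keyed on a new weights_file_name class attribute. The sketch below re-creates that pattern in plain PyTorch so it runs stand-alone: the names mirror the diff, but this is illustrative code rather than pytorch-ie itself, and the "pytorch_model.bin" default is only an assumption about the PYTORCH_WEIGHTS_NAME constant.

```python
import torch
from torch import nn

PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"  # assumed default of the library constant


class WeightsIOMixin:
    """Stand-alone illustration of the hooks this commit adds to PieModelHFHubMixin."""

    weights_file_name = PYTORCH_WEIGHTS_NAME  # class-level knob, overridable per model

    def save_model_file(self, model_file: str) -> None:
        # What _save_pretrained() now delegates to: persist only the state dict.
        torch.save(self.state_dict(), model_file)

    def load_model_file(self, model_file: str, map_location: str = "cpu", strict: bool = False) -> None:
        # What _from_pretrained() now delegates to: restore the weights, then eval().
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        self.load_state_dict(state_dict, strict=strict)
        self.eval()


class TinyModel(WeightsIOMixin, nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)


model = TinyModel()
path = "/tmp/" + model.weights_file_name
model.save_model_file(path)   # roughly what save_pretrained() ends up calling
model.load_model_file(path)   # roughly what from_pretrained() ends up calling
```

In the real mixin, _save_pretrained() builds the target path as save_directory / self.weights_file_name and _from_pretrained() resolves the same attribute before downloading, which is what makes the file name overridable per model class.
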
21 changes: 9 additions & 12 deletions src/pytorch_ie/auto.py
@@ -33,14 +33,19 @@ def _from_pretrained(
         Overwrite this method in case you wish to initialize your model in a different way.
         """
-
+        config = (config or {}).copy()
+        config.update(model_kwargs)
+        class_name = config.pop(cls.config_type_key)
+        clazz = PyTorchIEModel.by_name(class_name)
+        model = clazz(**config)
+
         """Load Pytorch pretrained weights and return the loaded model."""
         if os.path.isdir(model_id):
             print("Loading weights from local directory")
-            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
+            model_file = os.path.join(model_id, model.weights_file_name)
         else:
             model_file = hf_hub_download(
                 repo_id=model_id,
-                filename=PYTORCH_WEIGHTS_NAME,
+                filename=model.weights_file_name,
                 revision=revision,
                 cache_dir=cache_dir,
                 force_download=force_download,
@@ -50,15 +55,7 @@ def _from_pretrained(
                 local_files_only=local_files_only,
             )

-        config = (config or {}).copy()
-        config.update(model_kwargs)
-        class_name = config.pop(cls.config_type_key)
-        clazz = PyTorchIEModel.by_name(class_name)
-        model = clazz(**config)
-
-        state_dict = torch.load(model_file, map_location=torch.device(map_location))
-        model.load_state_dict(state_dict, strict=strict)  # type: ignore
-        model.eval()  # type: ignore
+        model.load_model_file(model_file, map_location=map_location, strict=strict)

         return model
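
Because the refactored method constructs the model from its config before touching the weights, AutoModel can ask the instance which file to fetch and then delegate the actual loading to load_model_file(). A usage sketch, with the repo id taken from the project README example and the import path assumed from this repository layout:

```python
from pytorch_ie.auto import AutoModel

# Repo id borrowed from the pytorch-ie README example; any PyTorchIEModel
# repository saved with save_pretrained() should behave the same way.
model = AutoModel.from_pretrained("pie/example-ner-spanclf-conll03")

# _from_pretrained() built the concrete model class via PyTorchIEModel.by_name(),
# looked up model.weights_file_name to locate the checkpoint (locally or through
# hf_hub_download), and finally called model.load_model_file(), which also puts
# the model into eval mode.
print(type(model).__name__, model.weights_file_name)
```
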
51 changes: 34 additions & 17 deletions src/pytorch_ie/core/hf_hub_mixin.py
@@ -345,9 +345,13 @@ def _from_config(cls: Type[T], config: dict, **kwargs) -> T:
         return cls(**config)


+TModel = TypeVar("TModel", bound="PieModelHFHubMixin")
+
+
 class PieModelHFHubMixin(PieBaseHFHubMixin):
     config_name = MODEL_CONFIG_NAME
     config_type_key = MODEL_CONFIG_TYPE_KEY
+    weights_file_name = PYTORCH_WEIGHTS_NAME

     """
     Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
@@ -383,14 +387,25 @@ class PieModelHFHubMixin(PieBaseHFHubMixin):
     ```
     """

-    def _save_pretrained(self, save_directory: Path) -> None:
+    def save_model_file(self, model_file: str) -> None:
         """Save weights from a Pytorch model to a local directory."""
         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
-        torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
+        torch.save(model_to_save.state_dict(), model_file)
+
+    def load_model_file(
+        self, model_file: str, map_location: str = "cpu", strict: bool = False
+    ) -> None:
+        state_dict = torch.load(model_file, map_location=torch.device(map_location))
+        self.load_state_dict(state_dict, strict=strict)  # type: ignore
+        self.eval()  # type: ignore
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Save weights from a Pytorch model to a local directory."""
+        self.save_model_file(str(save_directory / self.weights_file_name))

     @classmethod
     def _from_pretrained(
-        cls: Type[T],
+        cls: Type[TModel],
         *,
         model_id: str,
         revision: Optional[str],
@@ -404,15 +419,22 @@ def _from_pretrained(
         strict: bool = False,
         config: Optional[dict] = None,
         **model_kwargs,
-    ) -> T:
+    ) -> TModel:
+
+        config = (config or {}).copy()
+        config.update(model_kwargs)
+        if cls.config_type_key is not None:
+            config.pop(cls.config_type_key)
+        model = cls(**config)
+
         """Load Pytorch pretrained weights and return the loaded model."""
         if os.path.isdir(model_id):
             logger.info("Loading weights from local directory")
-            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
+            model_file = os.path.join(model_id, model.weights_file_name)
         else:
             model_file = hf_hub_download(
                 repo_id=model_id,
-                filename=PYTORCH_WEIGHTS_NAME,
+                filename=model.weights_file_name,
                 revision=revision,
                 cache_dir=cache_dir,
                 force_download=force_download,
@@ -422,19 +444,14 @@ def _from_pretrained(
                 local_files_only=local_files_only,
             )

-        config = (config or {}).copy()
-        config.update(model_kwargs)
-        if cls.config_type_key is not None:
-            config.pop(cls.config_type_key)
-        model = cls(**config)
-
-        state_dict = torch.load(model_file, map_location=torch.device(map_location))
-        model.load_state_dict(state_dict, strict=strict)  # type: ignore
-        model.eval()  # type: ignore
+        model.load_model_file(model_file, map_location=map_location, strict=strict)

         return model


+TTaskModule = TypeVar("TTaskModule", bound="PieTaskModuleHFHubMixin")
+
+
 class PieTaskModuleHFHubMixin(PieBaseHFHubMixin):
     config_name = TASKMODULE_CONFIG_NAME
     config_type_key = TASKMODULE_CONFIG_TYPE_KEY
@@ -447,7 +464,7 @@ def _save_pretrained(self, save_directory):

     @classmethod
     def _from_pretrained(
-        cls: Type[T],
+        cls: Type[TTaskModule],
         *,
         model_id: str,
         revision: Optional[str],
@@ -461,7 +478,7 @@ def _from_pretrained(
         strict: bool = False,
         config: Optional[dict] = None,
         **taskmodule_kwargs,
-    ) -> T:
+    ) -> TTaskModule:
         config = (config or {}).copy()
         config.update(taskmodule_kwargs)
         if cls.config_type_key is not None:
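
Since both the save path and the load path now resolve weights_file_name from the model class, a subclass can rename its checkpoint file in one place. A minimal sketch; the subclass, the file name, and the import path are illustrative, and a real model would also implement its forward/step logic and be registered so that AutoModel can resolve it by name:

```python
from pytorch_ie.core import PyTorchIEModel


class MyModel(PyTorchIEModel):
    # Read by _save_pretrained() when writing save_directory / weights_file_name
    # and by _from_pretrained() when locating or downloading the checkpoint.
    weights_file_name = "my_weights.bin"

    # ... forward/step logic omitted for brevity ...


# model = MyModel(...)                 # construct as usual
# model.save_pretrained("my_model")    # writes the weights to my_model/my_weights.bin
# MyModel.from_pretrained("my_model")  # looks for my_weights.bin instead of the default
```
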
