From 00cd8021de09b30f093a637b417d6fa35915b4f8 Mon Sep 17 00:00:00 2001 From: ArneBinder Date: Mon, 4 Nov 2024 16:28:32 +0100 Subject: [PATCH] outsource `PieModelHFHubMixin.load_model_file` (#433) * outsource PieModelHFHubMixin.load_weights * fix typing * parametrize weights_file_name * use weights_file_name from loaded model * adjust AutoModel._from_pretrained * fix PieModelHFHubMixin.load_weights * rename PieModelHFHubMixin.load_weights() to load_model_file() * implement PieModelHFHubMixin.save_model_file() for completeness --- src/pytorch_ie/auto.py | 21 +++++------- src/pytorch_ie/core/hf_hub_mixin.py | 51 +++++++++++++++++++---------- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/src/pytorch_ie/auto.py b/src/pytorch_ie/auto.py index a23d8228..8485128b 100644 --- a/src/pytorch_ie/auto.py +++ b/src/pytorch_ie/auto.py @@ -33,14 +33,19 @@ def _from_pretrained( Overwrite this method in case you wish to initialize your model in a different way. """ + config = (config or {}).copy() + config.update(model_kwargs) + class_name = config.pop(cls.config_type_key) + clazz = PyTorchIEModel.by_name(class_name) + model = clazz(**config) + """Load Pytorch pretrained weights and return the loaded model.""" if os.path.isdir(model_id): - print("Loading weights from local directory") - model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) + model_file = os.path.join(model_id, model.weights_file_name) else: model_file = hf_hub_download( repo_id=model_id, - filename=PYTORCH_WEIGHTS_NAME, + filename=model.weights_file_name, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -50,15 +55,7 @@ def _from_pretrained( local_files_only=local_files_only, ) - config = (config or {}).copy() - config.update(model_kwargs) - class_name = config.pop(cls.config_type_key) - clazz = PyTorchIEModel.by_name(class_name) - model = clazz(**config) - - state_dict = torch.load(model_file, map_location=torch.device(map_location)) - model.load_state_dict(state_dict, 
strict=strict) # type: ignore - model.eval() # type: ignore + model.load_model_file(model_file, map_location=map_location, strict=strict) return model diff --git a/src/pytorch_ie/core/hf_hub_mixin.py b/src/pytorch_ie/core/hf_hub_mixin.py index 18c58fd0..5a1aabbb 100644 --- a/src/pytorch_ie/core/hf_hub_mixin.py +++ b/src/pytorch_ie/core/hf_hub_mixin.py @@ -345,9 +345,13 @@ def _from_config(cls: Type[T], config: dict, **kwargs) -> T: return cls(**config) +TModel = TypeVar("TModel", bound="PieModelHFHubMixin") + + class PieModelHFHubMixin(PieBaseHFHubMixin): config_name = MODEL_CONFIG_NAME config_type_key = MODEL_CONFIG_TYPE_KEY + weights_file_name = PYTORCH_WEIGHTS_NAME """ Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model @@ -383,14 +387,25 @@ class PieModelHFHubMixin(PieBaseHFHubMixin): ``` """ - def _save_pretrained(self, save_directory: Path) -> None: + def save_model_file(self, model_file: str) -> None: """Save weights from a Pytorch model to a local directory.""" model_to_save = self.module if hasattr(self, "module") else self # type: ignore - torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME) + torch.save(model_to_save.state_dict(), model_file) + + def load_model_file( + self, model_file: str, map_location: str = "cpu", strict: bool = False + ) -> None: + state_dict = torch.load(model_file, map_location=torch.device(map_location)) + self.load_state_dict(state_dict, strict=strict) # type: ignore + self.eval() # type: ignore + + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights from a Pytorch model to a local directory.""" + self.save_model_file(str(save_directory / self.weights_file_name)) @classmethod def _from_pretrained( - cls: Type[T], + cls: Type[TModel], *, model_id: str, revision: Optional[str], @@ -404,15 +419,22 @@ def _from_pretrained( strict: bool = False, config: Optional[dict] = None, **model_kwargs, - ) -> T: + ) -> TModel: + + 
config = (config or {}).copy() + config.update(model_kwargs) + if cls.config_type_key is not None: + config.pop(cls.config_type_key) + model = cls(**config) + """Load Pytorch pretrained weights and return the loaded model.""" if os.path.isdir(model_id): logger.info("Loading weights from local directory") - model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) + model_file = os.path.join(model_id, model.weights_file_name) else: model_file = hf_hub_download( repo_id=model_id, - filename=PYTORCH_WEIGHTS_NAME, + filename=model.weights_file_name, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -422,19 +444,14 @@ def _from_pretrained( local_files_only=local_files_only, ) - config = (config or {}).copy() - config.update(model_kwargs) - if cls.config_type_key is not None: - config.pop(cls.config_type_key) - model = cls(**config) - - state_dict = torch.load(model_file, map_location=torch.device(map_location)) - model.load_state_dict(state_dict, strict=strict) # type: ignore - model.eval() # type: ignore + model.load_model_file(model_file, map_location=map_location, strict=strict) return model +TTaskModule = TypeVar("TTaskModule", bound="PieTaskModuleHFHubMixin") + + class PieTaskModuleHFHubMixin(PieBaseHFHubMixin): config_name = TASKMODULE_CONFIG_NAME config_type_key = TASKMODULE_CONFIG_TYPE_KEY @@ -447,7 +464,7 @@ def _save_pretrained(self, save_directory): @classmethod def _from_pretrained( - cls: Type[T], + cls: Type[TTaskModule], *, model_id: str, revision: Optional[str], @@ -461,7 +478,7 @@ def _from_pretrained( strict: bool = False, config: Optional[dict] = None, **taskmodule_kwargs, - ) -> T: + ) -> TTaskModule: config = (config or {}).copy() config.update(taskmodule_kwargs) if cls.config_type_key is not None: