diff --git a/src/pytorch_ie/core/hf_hub_mixin.py b/src/pytorch_ie/core/hf_hub_mixin.py index 4dc2a316..bc49d2cb 100644 --- a/src/pytorch_ie/core/hf_hub_mixin.py +++ b/src/pytorch_ie/core/hf_hub_mixin.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Type, TypeVar, Union -# TODO: fix ignore import requests # type: ignore import torch from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME @@ -19,7 +18,7 @@ MODEL_CONFIG_TYPE_KEY = "model_type" TASKMODULE_CONFIG_TYPE_KEY = "taskmodule_type" -# Generic variable that is either ModelHubMixin or a subclass thereof +# Generic variable that is either PieBaseHFHubMixin or a subclass thereof T = TypeVar("T", bound="PieBaseHFHubMixin") @@ -32,28 +31,24 @@ class PieBaseHFHubMixin: of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions. """ - config_name = MODEL_CONFIG_NAME # @ArneBinder - config_type_key = MODEL_CONFIG_TYPE_KEY # @ArneBinder + config_name = MODEL_CONFIG_NAME + config_type_key = MODEL_CONFIG_TYPE_KEY - def __init__(self, *args, is_from_pretrained: bool = False, **kwargs): # @ArneBinder - super().__init__(*args, **kwargs) # @ArneBinder - self._is_from_pretrained = is_from_pretrained # @ArneBinder + def __init__(self, *args, is_from_pretrained: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self._is_from_pretrained = is_from_pretrained @property - def is_from_pretrained(self): # @ArneBinder - return self._is_from_pretrained # @ArneBinder + def is_from_pretrained(self): + return self._is_from_pretrained - # def _config_name(self) -> Optional[str]: # @ArneBinder - # return None # @ArneBinder - - def _config(self) -> Optional[Dict[str, Any]]: # @ArneBinder - return None # @ArneBinder + def _config(self) -> Optional[Dict[str, Any]]: + return None def save_pretrained( self, save_directory: Union[str, Path], *, - # config: Optional[dict] = None, # config: Optional[dict] = None @ArneBinder repo_id: Optional[str] = None, push_to_hub: bool = False, **kwargs, @@ -64,8 +59,6 @@ def save_pretrained( Args: save_directory (`str` or `Path`): Path to directory in which the model weights and configuration will be saved. - config (`dict`, *optional*): - Model configuration specified as a key/value dictionary. push_to_hub (`bool`, *optional*, defaults to `False`): Whether or not to push your model to the Huggingface Hub after saving it. repo_id (`str`, *optional*): @@ -81,7 +74,7 @@ def save_pretrained( self._save_pretrained(save_directory) # saving config - config = self._config() # @ArneBinder + config = self._config() if isinstance(config, dict): (save_directory / self.config_name).write_text(json.dumps(config, indent=2)) @@ -179,7 +172,7 @@ def from_pretrained( # The value of is_from_pretrained is set to True when the model is loaded from pretrained. # Note that the value may be already available in model_kwargs. - model_kwargs["is_from_pretrained"] = True # @ArneBinder + model_kwargs["is_from_pretrained"] = True return cls._from_pretrained( model_id=str(model_id), @@ -376,7 +369,7 @@ def _from_pretrained( ) -> "PieModelHFHubMixin": """Load Pytorch pretrained weights and return the loaded model.""" if os.path.isdir(model_id): - print("Loading weights from local directory") + logger.info("Loading weights from local directory") model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) else: model_file = hf_hub_download(