diff --git a/src/pytorch_ie/auto.py b/src/pytorch_ie/auto.py
index 99052a02..bf35fd03 100644
--- a/src/pytorch_ie/auto.py
+++ b/src/pytorch_ie/auto.py
@@ -1,57 +1,39 @@
 import os
-from typing import Any, Dict, Optional
+from pathlib import Path
+from typing import Any, Dict, Optional, Type, Union
 
 import torch
 from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
 from huggingface_hub.file_download import hf_hub_download
 
 from pytorch_ie.core import PyTorchIEModel, TaskModule
-from pytorch_ie.core.hf_hub_mixin import PyTorchIEModelHubMixin, PyTorchIETaskmoduleModelHubMixin
+from pytorch_ie.core.hf_hub_mixin import PieModelHFHubMixin, PieTaskModuleHFHubMixin
 from pytorch_ie.pipeline import Pipeline
 
 
-class AutoTaskModule(PyTorchIETaskmoduleModelHubMixin):
+class AutoModel(PieModelHFHubMixin):
     @classmethod
     def _from_pretrained(
         cls,
-        model_id,
-        revision,
-        cache_dir,
-        force_download,
-        proxies,
-        resume_download,
-        local_files_only,
-        use_auth_token,
-        **module_kwargs,
-    ) -> TaskModule:
-        class_name = module_kwargs.pop(cls.config_type_key)
-        clazz = TaskModule.by_name(class_name)  # type: ignore
-        taskmodule: TaskModule = clazz(**module_kwargs)
-        taskmodule._post_prepare()
-        return taskmodule
-
-
-class AutoModel(PyTorchIEModelHubMixin):
-    @classmethod
-    def _from_pretrained(
-        cls,
-        model_id,
-        revision,
-        cache_dir,
-        force_download,
-        proxies,
-        resume_download,
-        local_files_only,
-        use_auth_token,
-        map_location="cpu",
-        strict=False,
+        *,
+        model_id: str,
+        revision: Optional[str],
+        cache_dir: Optional[Union[str, Path]],
+        force_download: bool,
+        proxies: Optional[Dict],
+        resume_download: bool,
+        local_files_only: bool,
+        token: Union[str, bool, None],
+        map_location: str = "cpu",
+        strict: bool = False,
+        config: Optional[dict] = None,
         **model_kwargs,
     ) -> PyTorchIEModel:
-        """
-        Overwrite this method in case you wish to initialize your model in a different way.
""" - map_location = torch.device(map_location) + """Load Pytorch pretrained weights and return the loaded model.""" if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) @@ -64,21 +46,50 @@ def _from_pretrained( force_download=force_download, proxies=proxies, resume_download=resume_download, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, ) - class_name = model_kwargs.pop(cls.config_type_key) + config = (config or {}).copy() + config.update(model_kwargs) + class_name = config.pop(cls.config_type_key) clazz = PyTorchIEModel.by_name(class_name) - model = clazz(**model_kwargs) + model = clazz(**config) - state_dict = torch.load(model_file, map_location=map_location) - model.load_state_dict(state_dict, strict=strict) - model.eval() + state_dict = torch.load(model_file, map_location=torch.device(map_location)) + model.load_state_dict(state_dict, strict=strict) # type: ignore + model.eval() # type: ignore return model +class AutoTaskModule(PieTaskModuleHFHubMixin): + @classmethod + def _from_pretrained( # type: ignore + cls, + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", + strict: bool = False, + config: Optional[dict] = None, + **taskmodule_kwargs, + ) -> TaskModule: + config = (config or {}).copy() + config.update(taskmodule_kwargs) + class_name = config.pop(cls.config_type_key) + clazz: Type[TaskModule] = TaskModule.by_name(class_name) + taskmodule = clazz(**config) + taskmodule.post_prepare() + return taskmodule + + class AutoPipeline: @staticmethod def from_pretrained( diff --git a/src/pytorch_ie/core/hf_hub_mixin.py b/src/pytorch_ie/core/hf_hub_mixin.py index f981b7d7..bc49d2cb 100644 --- a/src/pytorch_ie/core/hf_hub_mixin.py +++ b/src/pytorch_ie/core/hf_hub_mixin.py @@ -2,15 +2,14 @@ import logging import os from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Type, TypeVar, Union -# TODO: fix ignore import requests # type: ignore import torch from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from huggingface_hub.file_download import hf_hub_download -from huggingface_hub.hf_api import HfApi, HfFolder -from huggingface_hub.repository import Repository +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import SoftTemporaryDirectory, validate_hf_hub_args logger = logging.getLogger(__name__) @@ -19,12 +18,17 @@ MODEL_CONFIG_TYPE_KEY = "model_type" TASKMODULE_CONFIG_TYPE_KEY = "taskmodule_type" +# Generic variable that is either PieBaseHFHubMixin or a subclass thereof +T = TypeVar("T", bound="PieBaseHFHubMixin") -class PyTorchIEBaseModelHubMixin: + +class PieBaseHFHubMixin: """ - A Generic Base Model Hub Mixin. Define your own mixin for anything by inheriting from this class - and overwriting _from_pretrained and _save_pretrained to define custom logic for saving/loading - your classes. See ``huggingface_hub.PyTorchModelHubMixin`` for an example. + A generic mixin to integrate ANY machine learning framework with the Hub. + + To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models + have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. 
diff --git a/src/pytorch_ie/core/hf_hub_mixin.py b/src/pytorch_ie/core/hf_hub_mixin.py
index f981b7d7..bc49d2cb 100644
--- a/src/pytorch_ie/core/hf_hub_mixin.py
+++ b/src/pytorch_ie/core/hf_hub_mixin.py
@@ -2,15 +2,14 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Type, TypeVar, Union
 
-# TODO: fix ignore
 import requests  # type: ignore
 import torch
 from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
 from huggingface_hub.file_download import hf_hub_download
-from huggingface_hub.hf_api import HfApi, HfFolder
-from huggingface_hub.repository import Repository
+from huggingface_hub.hf_api import HfApi
+from huggingface_hub.utils import SoftTemporaryDirectory, validate_hf_hub_args
 
 logger = logging.getLogger(__name__)
 
@@ -19,12 +18,17 @@
 MODEL_CONFIG_TYPE_KEY = "model_type"
 TASKMODULE_CONFIG_TYPE_KEY = "taskmodule_type"
 
+# Generic variable that is either PieBaseHFHubMixin or a subclass thereof
+T = TypeVar("T", bound="PieBaseHFHubMixin")
 
-class PyTorchIEBaseModelHubMixin:
+
+class PieBaseHFHubMixin:
     """
-    A Generic Base Model Hub Mixin. Define your own mixin for anything by inheriting from this class
-    and overwriting _from_pretrained and _save_pretrained to define custom logic for saving/loading
-    your classes. See ``huggingface_hub.PyTorchModelHubMixin`` for an example.
+    A generic mixin to integrate ANY machine learning framework with the Hub.
+
+    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
+    has to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
+    of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
     """
 
     config_name = MODEL_CONFIG_NAME
@@ -38,334 +42,334 @@ def __init__(self, *args, is_from_pretrained: bool = False, **kwargs):
     def is_from_pretrained(self):
         return self._is_from_pretrained
 
+    def _config(self) -> Optional[Dict[str, Any]]:
+        return None
+
     def save_pretrained(
         self,
-        save_directory: str,
+        save_directory: Union[str, Path],
+        *,
+        repo_id: Optional[str] = None,
         push_to_hub: bool = False,
         **kwargs,
-    ):
+    ) -> Optional[str]:
         """
-        Saving weights in local directory.
-
-        Parameters:
-            save_directory (:obj:`str`):
-                Specify directory in which you want to save weights.
-            config (:obj:`dict`, `optional`):
-                specify config (must be dict) incase you want to save it.
-            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
-                Set it to `True` in case you want to push your weights to huggingface_hub
-            model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`):
-                Repo name in huggingface_hub. If not specified, repo name will be same as `save_directory`
-            kwargs (:obj:`Dict`, `optional`):
-                kwargs will be passed to `push_to_hub`
+        Save weights in local directory.
+
+        Args:
+            save_directory (`str` or `Path`):
+                Path to directory in which the model weights and configuration will be saved.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it.
+            repo_id (`str`, *optional*):
+                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
+                not provided.
+            kwargs:
+                Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
         """
-
-        os.makedirs(save_directory, exist_ok=True)
-
-        config_name = self.config_name
-        config = self._config()
-
-        # saving config
-        if config is not None:
-            path = os.path.join(save_directory, config_name)
-            with open(path, "w") as f:
-                json.dump(config, f)
+        save_directory = Path(save_directory)
+        save_directory.mkdir(parents=True, exist_ok=True)
 
         # saving model weights/files
         self._save_pretrained(save_directory)
 
-        if push_to_hub:
-            return self.push_to_hub(save_directory, **kwargs)
-
-    def _config_name(self) -> Optional[str]:
-        return None
+        # saving config
+        config = self._config()
+        if isinstance(config, dict):
+            (save_directory / self.config_name).write_text(json.dumps(config, indent=2))
 
-    def _config(self) -> Optional[Dict[str, Any]]:
+        if push_to_hub:
+            kwargs = kwargs.copy()  # soft-copy to avoid mutating input
+            if config is not None:  # kwarg for `push_to_hub`
+                kwargs["config"] = config
+            if repo_id is None:
+                repo_id = save_directory.name  # Defaults to `save_directory` name
+            return self.push_to_hub(repo_id=repo_id, **kwargs)
         return None
 
-    def _save_pretrained(self, save_directory):
+    def _save_pretrained(self, save_directory: Path) -> None:
         """
         Overwrite this method in subclass to define how to save your model.
+        Check out our [integration guide](../guides/integrations) for instructions.
+
+        Args:
+            save_directory (`str` or `Path`):
+                Path to directory in which the model weights and configuration will be saved.
""" raise NotImplementedError @classmethod + @validate_hf_hub_args def from_pretrained( - cls, - pretrained_model_name_or_path: str, + cls: Type[T], + pretrained_model_name_or_path: Union[str, Path], + *, force_download: bool = False, resume_download: bool = False, proxies: Optional[Dict] = None, - use_auth_token: Optional[str] = None, - cache_dir: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + cache_dir: Optional[Union[str, Path]] = None, local_files_only: bool = False, + revision: Optional[str] = None, **model_kwargs, - ): - r""" - Instantiate a pretrained pytorch model from a pre-trained model configuration from huggingface-hub. - The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated). To - train the model, you should first set it back in training mode with ``model.train()``. - - Parameters: - pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`): - Can be either: - - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under - a user or organization name, like ``dbmdz/bert-base-german-cased``. - - You can add `revision` by appending `@` at the end of model_id simply like this: ``dbmdz/bert-base-german-cased@main`` - Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id, - since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git. - - A path to a `directory` containing model weights saved using - :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. - - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword - arguments ``config`` and ``state_dict``). - cache_dir (:obj:`Union[str, os.PathLike]`, `optional`): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (:obj:`Dict[str, str], `optional`): - A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (:obj:`str` or `bool`, `optional`): - The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token - generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). - model_kwargs (:obj:`Dict`, `optional`):: - model_kwargs will be passed to the model during initialization - .. note:: - Passing :obj:`use_auth_token=True` is required when you want to use a private model. + ) -> T: + """ + Download a model from the Huggingface Hub and instantiate it. + + Args: + pretrained_model_name_or_path (`str`, `Path`): + - Either the `model_id` (string) of a model hosted on the Hub, e.g. 
+                  `bigscience/bloom`.
+                - Or a path to a `directory` containing model weights saved using
+                  [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
+            revision (`str`, *optional*):
+                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
+                Defaults to the latest commit on `main` branch.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
+                the existing cache.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
+                cached when running `huggingface-cli login`.
+            cache_dir (`str`, `Path`, *optional*):
+                Path to the folder where cached files are stored.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+            model_kwargs (`Dict`, *optional*):
+                Additional kwargs to pass to the model during initialization.
         """
         model_id = pretrained_model_name_or_path
 
-        config_name = cls.config_name
-
-        revision = None
-        if len(model_id.split("@")) == 2:
-            model_id, revision = model_id.split("@")
-
-        if os.path.isdir(model_id) and config_name in os.listdir(model_id):
-            config_file = os.path.join(model_id, config_name)
-        else:
+        config_file: Optional[str] = None
+        if os.path.isdir(model_id):
+            if cls.config_name in os.listdir(model_id):
+                config_file = os.path.join(model_id, cls.config_name)
+            else:
+                logger.warning(f"{cls.config_name} not found in {Path(model_id).resolve()}")
+        elif isinstance(model_id, str):
             try:
                 config_file = hf_hub_download(
-                    repo_id=model_id,
-                    filename=config_name,
+                    repo_id=str(model_id),
+                    filename=cls.config_name,
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     proxies=proxies,
                     resume_download=resume_download,
-                    use_auth_token=use_auth_token,
+                    token=token,
                     local_files_only=local_files_only,
                 )
             except requests.exceptions.RequestException:
-                logger.warning(f"{config_name} not found in HuggingFace Hub")
-                config_file = None
+                logger.warning(f"{cls.config_name} not found in HuggingFace Hub.")
 
-        config = {}
         if config_file is not None:
             with open(config_file, encoding="utf-8") as f:
                 config = json.load(f)
+            model_kwargs.update({"config": config})
 
-        config.update(model_kwargs)
-
-        # the entry may be already in the config, so we overwrite it
-        config["is_from_pretrained"] = True
+        # The value of is_from_pretrained is set to True when the model is loaded from pretrained.
+        # Note that the value may already be available in model_kwargs.
+ model_kwargs["is_from_pretrained"] = True return cls._from_pretrained( - model_id, - revision, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - use_auth_token, - **config, + model_id=str(model_id), + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + **model_kwargs, ) @classmethod def _from_pretrained( - cls, - model_id, - revision, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - use_auth_token, + cls: Type[T], + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Optional[Union[str, bool]], **model_kwargs, - ): - """Overwrite this method in subclass to define how to load your model from pretrained""" + ) -> T: + """Overwrite this method in subclass to define how to load your model from pretrained. + + Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most + args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this + method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location` + parameter to set on which device the model should be loaded. + + Check out our [integration guide](../guides/integrations) for more instructions. + + Args: + model_id (`str`): + ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`). + revision (`str`, *optional*): + Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the + latest commit on `main` branch. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding + the existing cache. + resume_download (`bool`, *optional*, defaults to `False`): + Whether to delete incompletely received files. Will attempt to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `huggingface-cli login`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the local cached file if it exists. + model_kwargs: + Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method. 
+ """ raise NotImplementedError + @validate_hf_hub_args def push_to_hub( self, - repo_path_or_name: Optional[str] = None, - repo_url: Optional[str] = None, - commit_message: Optional[str] = "Add model", - organization: Optional[str] = None, - private: Optional[bool] = None, - api_endpoint: Optional[str] = None, - use_auth_token: Optional[Union[bool, str]] = None, - git_user: Optional[str] = None, - git_email: Optional[str] = None, + repo_id: str, + *, config: Optional[dict] = None, + commit_message: str = "Push model using huggingface_hub.", + private: bool = False, + api_endpoint: Optional[str] = None, + token: Optional[str] = None, + branch: Optional[str] = None, + create_pr: Optional[bool] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, ) -> str: """ - Upload model checkpoint or tokenizer files to the 🤗 Model Hub while synchronizing a local clone of the repo in - :obj:`repo_path_or_name`. - - Parameters: - repo_path_or_name (:obj:`str`, `optional`): - Can either be a repository name for your model or tokenizer in the Hub or a path to a local folder (in - which case the repository will have the name of that local folder). If not specified, will default to - the name given by :obj:`repo_url` and a local directory with that name will be created. - repo_url (:obj:`str`, `optional`): - Specify this in case you want to push to an existing repository in the hub. If unspecified, a new - repository will be created in your namespace (unless you specify an :obj:`organization`) with - :obj:`repo_name`. - commit_message (:obj:`str`, `optional`): - Message to commit while pushing. Will default to :obj:`"add config"`, :obj:`"add tokenizer"` or - :obj:`"add model"` depending on the type of the class. - organization (:obj:`str`, `optional`): - Organization in which you want to push your model or tokenizer (you must be a member of this - organization). - private (:obj:`bool`, `optional`): - Whether or not the repository created should be private (requires a paying subscription). - api_endpoint (:obj:`str`, `optional`): - The API endpoint to use when pushing the model to the hub. - use_auth_token (:obj:`bool` or :obj:`str`, `optional`): - The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token - generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to - :obj:`True` if :obj:`repo_url` is not specified. - git_user (``str``, `optional`): - will override the ``git config user.name`` for committing and pushing files to the hub. - git_email (``str``, `optional`): - will override the ``git config user.email`` for committing and pushing files to the hub. - config (:obj:`dict`, `optional`): - Configuration object to be saved alongside the model weights. + Upload model checkpoint to the Hub. + + Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use + `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more + details. + Args: + repo_id (`str`): + ID of the repository to push to (example: `"username/my-model"`). + config (`dict`, *optional*): + Configuration object to be saved alongside the model weights. + commit_message (`str`, *optional*): + Message to commit while pushing. + private (`bool`, *optional*, defaults to `False`): + Whether the repository created should be private. 
+            api_endpoint (`str`, *optional*):
+                The API endpoint to use when pushing the model to the hub.
+            token (`str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
+                cached when running `huggingface-cli login`.
+            branch (`str`, *optional*):
+                The git branch on which to push the model. This defaults to `"main"`.
+            create_pr (`boolean`, *optional*):
+                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
+            allow_patterns (`List[str]` or `str`, *optional*):
+                If provided, only files matching at least one pattern are pushed.
+            ignore_patterns (`List[str]` or `str`, *optional*):
+                If provided, files matching any of the patterns are not pushed.
+            delete_patterns (`List[str]` or `str`, *optional*):
+                If provided, remote files matching any of the patterns will be deleted from the repo.
+
+        Returns:
+            The URL of the commit of your model in the given repository.
         """
-
-        if use_auth_token is None and repo_url is None:
-            token = HfFolder.get_token()
-            if token is None:
-                raise ValueError(
-                    "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
-                    "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
-                    "token as the `use_auth_token` argument."
-                )
-        elif isinstance(use_auth_token, str):
-            token = use_auth_token
-        else:
-            token = None
-
-        if repo_path_or_name is None:
-            if repo_url is None:
-                raise ValueError("You need to specify a `repo_path_or_name` or a `repo_url`.")
-            repo_path_or_name = repo_url.split("/")[-1]
-
-        # If no URL is passed and there's no path to a directory containing files, create a repo
-        if repo_url is None and not os.path.exists(repo_path_or_name):
-            repo_name = Path(repo_path_or_name).name
-            repo_url = HfApi(endpoint=api_endpoint).create_repo(
-                repo_name,
-                token=token,
-                organization=organization,
-                private=private,
-                repo_type=None,
-                exist_ok=True,
-            )
-
-        repo = Repository(
-            repo_path_or_name,
-            clone_from=repo_url,
-            use_auth_token=use_auth_token,
-            git_user=git_user,
-            git_email=git_email,
-        )
-        repo.git_pull(rebase=True)
-
-        # Save the files in the cloned repo
-        self.save_pretrained(repo_path_or_name, config=config)
-
-        # Commit and push!
-        repo.git_add()
-        repo.git_commit(commit_message)
-        return repo.git_push()
+        api = HfApi(endpoint=api_endpoint, token=token)
+        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
+
+        # Push the files to the repo in a single commit
+        with SoftTemporaryDirectory() as tmp:
+            saved_path = Path(tmp) / repo_id
+            self.save_pretrained(saved_path, config=config)
+            return api.upload_folder(
+                repo_id=repo_id,
+                repo_type="model",
+                folder_path=saved_path,
+                commit_message=commit_message,
+                revision=branch,
+                create_pr=create_pr,
+                allow_patterns=allow_patterns,
+                ignore_patterns=ignore_patterns,
+                delete_patterns=delete_patterns,
+            )
+
+
+class PieModelHFHubMixin(PieBaseHFHubMixin):
+    """
+    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
+    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
+    you should first set it back in training mode with `model.train()`.
+
+    Example:
+
+    ```python
+    >>> import torch
+    >>> import torch.nn as nn
+    >>> from huggingface_hub import PyTorchModelHubMixin
+
+    >>> class MyModel(nn.Module, PyTorchModelHubMixin):
+    ...     def __init__(self):
+    ...         super().__init__()
+    ...         self.param = nn.Parameter(torch.rand(3, 4))
+    ...         self.linear = nn.Linear(4, 5)
+
+    ...     def forward(self, x):
+    ...         return self.linear(x + self.param)
+    >>> model = MyModel()
+
+    # Save model weights to local directory
+    >>> model.save_pretrained("my-awesome-model")
+
+    # Push model weights to the Hub
+    >>> model.push_to_hub("my-awesome-model")
+
+    # Download and initialize weights from the Hub
+    >>> model = MyModel.from_pretrained("username/my-awesome-model")
+    ```
+    """
+
+    config_name = MODEL_CONFIG_NAME
+    config_type_key = MODEL_CONFIG_TYPE_KEY
 
-class PyTorchIEModelHubMixin(PyTorchIEBaseModelHubMixin):
-    config_name = MODEL_CONFIG_NAME
-    config_type_key = MODEL_CONFIG_TYPE_KEY
-
-    def __init__(self, *args, **kwargs):
-        """
-        Mix this class with your torch-model class for ease process of saving & loading from huggingface-hub
-
-        Example::
-
-            >>> from huggingface_hub import PyTorchModelHubMixin
-
-            >>> class MyModel(nn.Module, PyTorchModelHubMixin):
-            ...    def __init__(self, **kwargs):
-            ...        super().__init__()
-            ...        self.config = kwargs.pop("config", None)
-            ...        self.layer = ...
-            ...    def forward(self, ...)
-            ...        return ...
-
-            >>> model = MyModel()
-            >>> model.save_pretrained("mymodel", push_to_hub=False)  # Saving model weights in the directory
-            >>> model.push_to_hub("mymodel", "model-1")  # Pushing model-weights to hf-hub
-
-            >>> # Downloading weights from hf-hub & model will be initialized from those weights
-            >>> model = MyModel.from_pretrained("username/mymodel@main")
-        """
-        super().__init__(*args, **kwargs)
-
-    def _save_pretrained(self, save_directory):
-        """
-        Overwrite this method in case you don't want to save complete model, rather some specific layers
-        """
-        path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
-        model_to_save = self.module if hasattr(self, "module") else self
-        torch.save(model_to_save.state_dict(), path)
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Save weights from a PyTorch model to a local directory."""
+        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
+        torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
 
     @classmethod
     def _from_pretrained(
         cls,
-        model_id,
-        revision,
-        cache_dir,
-        force_download,
-        proxies,
-        resume_download,
-        local_files_only,
-        use_auth_token,
-        map_location="cpu",
-        strict=False,
+        *,
+        model_id: str,
+        revision: Optional[str],
+        cache_dir: Optional[Union[str, Path]],
+        force_download: bool,
+        proxies: Optional[Dict],
+        resume_download: bool,
+        local_files_only: bool,
+        token: Union[str, bool, None],
+        map_location: str = "cpu",
+        strict: bool = False,
+        config: Optional[dict] = None,
         **model_kwargs,
-    ):
-        """
-        Overwrite this method in case you wish to initialize your model in a different way.
- """ - map_location = torch.device(map_location) - + ) -> "PieModelHFHubMixin": + """Load Pytorch pretrained weights and return the loaded model.""" if os.path.isdir(model_id): - print("Loading weights from local directory") + logger.info("Loading weights from local directory") model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) else: model_file = hf_hub_download( @@ -376,66 +380,59 @@ def _from_pretrained( force_download=force_download, proxies=proxies, resume_download=resume_download, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, ) + + config = (config or {}).copy() + config.update(model_kwargs) if cls.config_type_key is not None: - model_kwargs.pop(cls.config_type_key) - model = cls(**model_kwargs) + config.pop(cls.config_type_key) + model = cls(**config) - state_dict = torch.load(model_file, map_location=map_location) - model.load_state_dict(state_dict, strict=strict) - model.eval() + state_dict = torch.load(model_file, map_location=torch.device(map_location)) + model.load_state_dict(state_dict, strict=strict) # type: ignore + model.eval() # type: ignore return model -class PyTorchIETaskmoduleModelHubMixin(PyTorchIEBaseModelHubMixin): +class PieTaskModuleHFHubMixin(PieBaseHFHubMixin): config_name = TASKMODULE_CONFIG_NAME config_type_key = TASKMODULE_CONFIG_TYPE_KEY def __init__(self, *args, **kwargs): - """ - Mix this class with your torch-model class for ease process of saving & loading from huggingface-hub - - Example:: - - >>> from huggingface_hub import PyTorchModelHubMixin - - >>> class MyModel(nn.Module, PyTorchModelHubMixin): - ... def __init__(self, **kwargs): - ... super().__init__() - ... self.config = kwargs.pop("config", None) - ... self.layer = ... - ... def forward(self, ...) - ... return ... 
diff --git a/src/pytorch_ie/core/model.py b/src/pytorch_ie/core/model.py
index 1dad278a..8536893c 100644
--- a/src/pytorch_ie/core/model.py
+++ b/src/pytorch_ie/core/model.py
@@ -2,11 +2,11 @@
 
 from pytorch_lightning import LightningModule
 
-from pytorch_ie.core.hf_hub_mixin import PyTorchIEModelHubMixin
+from pytorch_ie.core.hf_hub_mixin import PieModelHFHubMixin
 from pytorch_ie.core.registrable import Registrable
 
 
-class PyTorchIEModel(PyTorchIEModelHubMixin, LightningModule, Registrable):
+class PyTorchIEModel(PieModelHFHubMixin, LightningModule, Registrable):
     def _config(self) -> Dict[str, Any]:
         config = super()._config() or {}
         config[self.config_type_key] = PyTorchIEModel.name_for_object_class(self)
diff --git a/src/pytorch_ie/core/taskmodule.py b/src/pytorch_ie/core/taskmodule.py
index 38d46f9d..3f055352 100644
--- a/src/pytorch_ie/core/taskmodule.py
+++ b/src/pytorch_ie/core/taskmodule.py
@@ -23,7 +23,7 @@
 from tqdm import tqdm
 
 from pytorch_ie.core.document import Annotation, Document
-from pytorch_ie.core.hf_hub_mixin import PyTorchIETaskmoduleModelHubMixin
+from pytorch_ie.core.hf_hub_mixin import PieTaskModuleHFHubMixin
 from pytorch_ie.core.registrable import Registrable
 from pytorch_ie.data import Dataset, IterableDataset
 
@@ -146,7 +146,7 @@ def __len__(self) -> int:
 
 class TaskModule(
     ABC,
-    PyTorchIETaskmoduleModelHubMixin,
+    PieTaskModuleHFHubMixin,
     HyperparametersMixin,
     Registrable,
     Generic[
@@ -206,6 +206,10 @@ def _assert_is_prepared(self, msg: Optional[str] = None):
             f"{msg or ''} Required attributes that are not set: {str(attributes_not_prepared)}"
         )
 
+    def post_prepare(self):
+        self._assert_is_prepared()
+        self._post_prepare()
+
     def prepare(self, documents: Sequence[DocumentType]) -> None:
         if self.is_prepared:
             if len(self.PREPARED_ATTRIBUTES) > 0:
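
Reviewer note: a minimal sketch of the new `post_prepare` contract (class and attribute names are illustrative, and the abstract encode/decode methods a real TaskModule subclass must implement are omitted):

    from pytorch_ie.core import TaskModule

    class MyTaskModule(TaskModule):
        PREPARED_ATTRIBUTES = ["label_to_id"]  # filled in by prepare(documents)

        def _post_prepare(self):
            # Derived state that is not stored in the config: post_prepare() first
            # asserts that all PREPARED_ATTRIBUTES are set, then rebuilds this.
            self.id_to_label = {v: k for k, v in self.label_to_id.items()}

Loading via AutoTaskModule.from_pretrained(...) now goes through taskmodule.post_prepare() (assert + _post_prepare) instead of reaching into the private _post_prepare() directly.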