
Commit 1be180f: cleanup
ArneBinder committed Sep 13, 2023 (parent 4b04a00)

Showing 1 changed file with 13 additions and 20 deletions.

src/pytorch_ie/core/hf_hub_mixin.py (13 additions, 20 deletions)
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar, Union

# TODO: fix ignore
import requests # type: ignore
import torch
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
Expand All @@ -19,7 +18,7 @@
MODEL_CONFIG_TYPE_KEY = "model_type"
TASKMODULE_CONFIG_TYPE_KEY = "taskmodule_type"

# Generic variable that is either ModelHubMixin or a subclass thereof
# Generic variable that is either PieBaseHFHubMixin or a subclass thereof
T = TypeVar("T", bound="PieBaseHFHubMixin")


Expand All @@ -32,28 +31,24 @@ class PieBaseHFHubMixin:
of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
"""

config_name = MODEL_CONFIG_NAME # @ArneBinder
config_type_key = MODEL_CONFIG_TYPE_KEY # @ArneBinder
config_name = MODEL_CONFIG_NAME
config_type_key = MODEL_CONFIG_TYPE_KEY

def __init__(self, *args, is_from_pretrained: bool = False, **kwargs): # @ArneBinder
super().__init__(*args, **kwargs) # @ArneBinder
self._is_from_pretrained = is_from_pretrained # @ArneBinder
def __init__(self, *args, is_from_pretrained: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self._is_from_pretrained = is_from_pretrained

@property
def is_from_pretrained(self): # @ArneBinder
return self._is_from_pretrained # @ArneBinder
def is_from_pretrained(self):
return self._is_from_pretrained

# def _config_name(self) -> Optional[str]: # @ArneBinder
# return None # @ArneBinder

def _config(self) -> Optional[Dict[str, Any]]: # @ArneBinder
return None # @ArneBinder
def _config(self) -> Optional[Dict[str, Any]]:
return None

def save_pretrained(
self,
save_directory: Union[str, Path],
*,
# config: Optional[dict] = None, # config: Optional[dict] = None @ArneBinder
repo_id: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
Expand All @@ -64,8 +59,6 @@ def save_pretrained(
Args:
save_directory (`str` or `Path`):
Path to directory in which the model weights and configuration will be saved.
config (`dict`, *optional*):
Model configuration specified as a key/value dictionary.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Huggingface Hub after saving it.
repo_id (`str`, *optional*):
Expand All @@ -81,7 +74,7 @@ def save_pretrained(
self._save_pretrained(save_directory)

# saving config
config = self._config() # @ArneBinder
config = self._config()
if isinstance(config, dict):
(save_directory / self.config_name).write_text(json.dumps(config, indent=2))

Expand Down Expand Up @@ -179,7 +172,7 @@ def from_pretrained(

# The value of is_from_pretrained is set to True when the model is loaded from pretrained.
# Note that the value may be already available in model_kwargs.
model_kwargs["is_from_pretrained"] = True # @ArneBinder
model_kwargs["is_from_pretrained"] = True

return cls._from_pretrained(
model_id=str(model_id),
Expand Down Expand Up @@ -376,7 +369,7 @@ def _from_pretrained(
) -> "PieModelHFHubMixin":
"""Load Pytorch pretrained weights and return the loaded model."""
if os.path.isdir(model_id):
print("Loading weights from local directory")
logger.info("Loading weights from local directory")
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
else:
model_file = hf_hub_download(
Expand Down
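The commit converges on a pattern where save_pretrained takes its configuration from the subclass's _config() hook (writing nothing when it returns None) instead of an explicit config argument, and from_pretrained marks loaded objects via the is_from_pretrained flag. The following is a minimal standalone sketch of that pattern; ToyHubMixin, ToyModel, and the simplified save/load methods are illustrative assumptions, not the pytorch_ie API (which additionally handles weights and Hub download/upload).

import json
from pathlib import Path
from typing import Any, Dict, Optional


class ToyHubMixin:
    """Standalone illustration of the mixin pattern above, not the pytorch_ie API."""

    config_name = "config.json"     # stands in for MODEL_CONFIG_NAME
    config_type_key = "model_type"  # stands in for MODEL_CONFIG_TYPE_KEY

    def __init__(self, *args, is_from_pretrained: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_from_pretrained = is_from_pretrained

    @property
    def is_from_pretrained(self) -> bool:
        return self._is_from_pretrained

    def _config(self) -> Optional[Dict[str, Any]]:
        # Subclasses return their serializable configuration here; the base
        # implementation returns None, so save_pretrained writes no config file.
        return None

    def save_pretrained(self, save_directory: str) -> None:
        directory = Path(save_directory)
        directory.mkdir(parents=True, exist_ok=True)
        config = self._config()
        if isinstance(config, dict):
            (directory / self.config_name).write_text(json.dumps(config, indent=2))

    @classmethod
    def from_pretrained(cls, save_directory: str, **kwargs):
        config = json.loads((Path(save_directory) / cls.config_name).read_text())
        config.pop(cls.config_type_key, None)  # the type key is metadata, not an __init__ kwarg
        return cls(**config, is_from_pretrained=True, **kwargs)


class ToyModel(ToyHubMixin):
    def __init__(self, hidden_size: int = 8, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size

    def _config(self) -> Optional[Dict[str, Any]]:
        return {self.config_type_key: "ToyModel", "hidden_size": self.hidden_size}


ToyModel(hidden_size=16).save_pretrained("/tmp/toy-model")
loaded = ToyModel.from_pretrained("/tmp/toy-model")
assert loaded.hidden_size == 16 and loaded.is_from_pretrained

Because the base _config() returns None, a subclass that does not override it simply writes no config file, which is presumably why the explicit config argument to save_pretrained could be dropped in this commit.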
