Skip to content

Commit

Permalink
fiy typing
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 3, 2024
1 parent 8e71ee5 commit 3b9f10e
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/pytorch_ie/core/hf_hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def _from_config(cls: Type[T], config: dict, **kwargs) -> T:
return cls(**config)


TModel = TypeVar("TModel", bound="PieModelHFHubMixin")


class PieModelHFHubMixin(PieBaseHFHubMixin):
config_name = MODEL_CONFIG_NAME
config_type_key = MODEL_CONFIG_TYPE_KEY
Expand Down Expand Up @@ -397,7 +400,7 @@ def load_weights(

@classmethod
def _from_pretrained(
cls: Type[T],
cls: Type[TModel],
*,
model_id: str,
revision: Optional[str],
Expand All @@ -411,7 +414,7 @@ def _from_pretrained(
strict: bool = False,
config: Optional[dict] = None,
**model_kwargs,
) -> T:
) -> TModel:
"""Load Pytorch pretrained weights and return the loaded model."""
if os.path.isdir(model_id):
logger.info("Loading weights from local directory")
Expand Down Expand Up @@ -440,6 +443,9 @@ def _from_pretrained(
return model


TTaskModule = TypeVar("TTaskModule", bound="PieTaskModuleHFHubMixin")


class PieTaskModuleHFHubMixin(PieBaseHFHubMixin):
config_name = TASKMODULE_CONFIG_NAME
config_type_key = TASKMODULE_CONFIG_TYPE_KEY
Expand All @@ -452,7 +458,7 @@ def _save_pretrained(self, save_directory):

@classmethod
def _from_pretrained(
cls: Type[T],
cls: Type[TTaskModule],
*,
model_id: str,
revision: Optional[str],
Expand All @@ -466,7 +472,7 @@ def _from_pretrained(
strict: bool = False,
config: Optional[dict] = None,
**taskmodule_kwargs,
) -> T:
) -> TTaskModule:
config = (config or {}).copy()
config.update(taskmodule_kwargs)
if cls.config_type_key is not None:
Expand Down

0 comments on commit 3b9f10e

Please sign in to comment.