From 01ad80f820db828ebe68acc0555f177fbf1d4baf Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Thu, 28 Nov 2024 15:05:19 +0000 Subject: [PATCH] Improve `.from_pretrained` type annotations (#34973) * Fix from_pretrained type annotations * Better typing for image processor's `from_pretrained` --- src/transformers/image_processing_base.py | 9 ++++++--- src/transformers/modeling_utils.py | 12 ++++++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py index bc7fd228edcc8c..e73d4a8a56f311 100644 --- a/src/transformers/image_processing_base.py +++ b/src/transformers/image_processing_base.py @@ -19,7 +19,7 @@ import os import warnings from io import BytesIO -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import numpy as np import requests @@ -45,6 +45,9 @@ from PIL import Image +ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin") + + logger = logging.get_logger(__name__) @@ -95,7 +98,7 @@ def _set_processor_class(self, processor_class: str): @classmethod def from_pretrained( - cls, + cls: Type[ImageProcessorType], pretrained_model_name_or_path: Union[str, os.PathLike], cache_dir: Optional[Union[str, os.PathLike]] = None, force_download: bool = False, @@ -103,7 +106,7 @@ def from_pretrained( token: Optional[Union[str, bool]] = None, revision: str = "main", **kwargs, - ): + ) -> ImageProcessorType: r""" Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3f3b3d337d7119..0806c318e10152 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -30,7 +30,7 @@ from dataclasses import dataclass from functools import partial, wraps from multiprocessing import Process -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from zipfile import is_zipfile import torch @@ -170,6 +170,10 @@ def is_local_dist_rank_0(): if is_peft_available(): from .utils import find_adapter_config_file + +SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") + + TORCH_INIT_FUNCTIONS = { "uniform_": nn.init.uniform_, "normal_": nn.init.normal_, @@ -3142,7 +3146,7 @@ def float(self, *args): @classmethod def from_pretrained( - cls, + cls: Type[SpecificPreTrainedModelType], pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, @@ -3152,10 +3156,10 @@ def from_pretrained( local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = "main", - use_safetensors: bool = None, + use_safetensors: Optional[bool] = None, weights_only: bool = True, **kwargs, - ) -> "PreTrainedModel": + ) -> SpecificPreTrainedModelType: r""" Instantiate a pretrained pytorch model from a pre-trained model configuration.