Improve .from_pretrained type annotations (#34973)
* Fix from_pretrained type annotations

* Better typing for image processor's `from_pretrained`
qubvel authored Nov 28, 2024
1 parent 9d6f0dd commit 01ad80f
Showing 2 changed files with 14 additions and 7 deletions.
src/transformers/image_processing_base.py (9 changes: 6 additions & 3 deletions)
@@ -19,7 +19,7 @@
 import os
 import warnings
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import numpy as np
 import requests
@@ -45,6 +45,9 @@
     from PIL import Image
 
 
+ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")
+
+
 logger = logging.get_logger(__name__)
@@ -95,15 +98,15 @@ def _set_processor_class(self, processor_class: str):
 
     @classmethod
     def from_pretrained(
-        cls,
+        cls: Type[ImageProcessorType],
         pretrained_model_name_or_path: Union[str, os.PathLike],
         cache_dir: Optional[Union[str, os.PathLike]] = None,
         force_download: bool = False,
         local_files_only: bool = False,
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
         **kwargs,
-    ):
+    ) -> ImageProcessorType:
         r"""
         Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
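What the new annotations buy: binding a TypeVar to `ImageProcessingMixin` and typing `cls` as `Type[ImageProcessorType]` lets static type checkers infer that `from_pretrained` called on a subclass returns an instance of that subclass rather than the bare mixin. A minimal self-contained sketch of the pattern (the loader body and the subclass usage are illustrative stand-ins, not the transformers source):

    from typing import Type, TypeVar

    ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")


    class ImageProcessingMixin:
        @classmethod
        def from_pretrained(
            cls: Type[ImageProcessorType], pretrained_model_name_or_path: str, **kwargs
        ) -> ImageProcessorType:
            # Stand-in for the real loading logic: instantiate the calling class.
            return cls(**kwargs)


    class CLIPImageProcessor(ImageProcessingMixin):
        pass


    # mypy/pyright now infer `processor: CLIPImageProcessor` here instead of
    # falling back to the base mixin type.
    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")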
src/transformers/modeling_utils.py (12 changes: 8 additions & 4 deletions)
@@ -30,7 +30,7 @@
 from dataclasses import dataclass
 from functools import partial, wraps
 from multiprocessing import Process
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
 from zipfile import is_zipfile
 
 import torch
@@ -170,6 +170,10 @@ def is_local_dist_rank_0():
 if is_peft_available():
     from .utils import find_adapter_config_file
 
+
+SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
+
+
 TORCH_INIT_FUNCTIONS = {
     "uniform_": nn.init.uniform_,
     "normal_": nn.init.normal_,
@@ -3142,7 +3146,7 @@ def float(self, *args):
 
     @classmethod
     def from_pretrained(
-        cls,
+        cls: Type[SpecificPreTrainedModelType],
         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
         *model_args,
         config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
@@ -3152,10 +3156,10 @@ def from_pretrained(
         local_files_only: bool = False,
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
-        use_safetensors: bool = None,
+        use_safetensors: Optional[bool] = None,
         weights_only: bool = True,
         **kwargs,
-    ) -> "PreTrainedModel":
+    ) -> SpecificPreTrainedModelType:
         r"""
         Instantiate a pretrained pytorch model from a pre-trained model configuration.
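The same pattern applied to `PreTrainedModel` replaces the hardcoded `"PreTrainedModel"` return annotation, so calls on a concrete subclass are typed as that subclass; the hunk also corrects `use_safetensors`, whose `None` default made the bare `bool` annotation inaccurate. A hedged usage example (the checkpoint name is illustrative, and it assumes a transformers release that includes this commit):

    from transformers import BertModel

    # Requires network access or a local cache for the checkpoint.
    model = BertModel.from_pretrained("bert-base-uncased")

    # Type checkers now infer `model: BertModel` rather than the former blanket
    # `PreTrainedModel`, so subclass-specific attributes such as `model.encoder`
    # type-check without casts or ignores.
    print(type(model).__name__)  # BertModel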
