Replace check_if_xxx_greater with is_xxx_version (#2152)
* add version and availability check utils

* replace check_if_transformers_greater with is_transformers_version

* fix style

* fix style

* fix
echarlaix authored Jan 9, 2025
1 parent 600436e commit 605ed7e
Showing 18 changed files with 148 additions and 111 deletions.
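For readers skimming the diff: the new helpers all follow an `is_<package>_version(operation, reference)` signature. Below is a minimal sketch of what such a utility plausibly looks like, assuming a `packaging`-based comparison against the installed distribution; the real implementation in `optimum.utils.import_utils` may cache versions at import time and cover additional packages.

```python
# Minimal sketch (not the actual optimum source): an operator-based version
# check resolved against the installed distribution via importlib.metadata.
import operator
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as installed_version

from packaging.version import parse

_OPERATIONS = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}


def is_transformers_version(operation: str, reference: str) -> bool:
    """Return True if the installed transformers satisfies `<operation> <reference>`."""
    try:
        installed = installed_version("transformers")
    except PackageNotFoundError:
        return False  # an absent package fails every version check
    return _OPERATIONS[operation](parse(installed), parse(reference))
```

Call sites then read as explicit comparisons, e.g. `is_transformers_version(">=", "4.46")` or `is_transformers_version("<", "4.46")`, instead of the boolean-only `check_if_transformers_greater("4.46")`.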
4 changes: 2 additions & 2 deletions optimum/exporters/executorch/__main__.py
@@ -20,7 +20,7 @@
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from transformers.utils import is_torch_available

-from optimum.utils.import_utils import check_if_transformers_greater
+from optimum.utils.import_utils import is_transformers_version

from ...commands.export.executorch import parse_args_executorch
from .convert import export_to_executorch
@@ -95,7 +95,7 @@ def main_export(
```
"""

-if not check_if_transformers_greater("4.46"):
+if is_transformers_version("<", "4.46"):
raise ValueError(
"The minimum Transformers version compatible with ExecuTorch is 4.46.0. Please upgrade to Transformers 4.46.0 or later."
)
4 changes: 2 additions & 2 deletions optimum/exporters/executorch/convert.py
@@ -19,15 +19,15 @@

from transformers.utils import is_torch_available

-from optimum.utils.import_utils import check_if_transformers_greater
+from optimum.utils.import_utils import is_transformers_version

from .recipe_registry import discover_recipes, recipe_registry


if is_torch_available():
from transformers.modeling_utils import PreTrainedModel

-if check_if_transformers_greater("4.46"):
+if is_transformers_version(">=", "4.46"):
from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
)
6 changes: 3 additions & 3 deletions optimum/exporters/onnx/base.py
@@ -44,7 +44,7 @@
from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
from ...utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
from ...utils.doc import add_dynamic_docstring
-from ...utils.import_utils import check_if_transformers_greater, is_onnx_available, is_onnxruntime_available
+from ...utils.import_utils import is_onnx_available, is_onnxruntime_available, is_transformers_version
from ..base import ExportConfig
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import ModelPatcher, Seq2SeqModelPatcher
@@ -156,7 +156,7 @@ class OnnxConfig(ExportConfig, ABC):
),
"mask-generation": OrderedDict({"logits": {0: "batch_size"}}),
"masked-im": OrderedDict(
{"reconstruction" if check_if_transformers_greater("4.29.0") else "logits": {0: "batch_size"}}
{"reconstruction" if is_transformers_version(">=", "4.29.0") else "logits": {0: "batch_size"}}
),
"multiple-choice": OrderedDict({"logits": {0: "batch_size", 1: "num_choices"}}),
"object-detection": OrderedDict(
@@ -375,7 +375,7 @@ def is_transformers_support_available(self) -> bool:
`bool`: Whether the installed version of Transformers is compatible with the model.
"""
-return check_if_transformers_greater(self.MIN_TRANSFORMERS_VERSION)
+return is_transformers_version(">=", self.MIN_TRANSFORMERS_VERSION.base_version)

@property
def is_torch_support_available(self) -> bool:
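A note on the `is_transformers_support_available` change above: `MIN_TRANSFORMERS_VERSION` is a `packaging.version.Version` object, and `.base_version` yields the plain release string with any pre-release or dev suffix stripped, which is presumably what the string-based helper expects. A small illustration of the `packaging` semantics:

```python
from packaging.version import parse

v = parse("4.41.0.dev0")
print(v.base_version)       # "4.41.0": the release segment only, as a plain string
print(v < parse("4.41.0"))  # True: pre-releases sort before the release itself
```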
6 changes: 3 additions & 3 deletions optimum/exporters/onnx/convert.py
@@ -35,9 +35,9 @@
DEFAULT_DUMMY_SHAPES,
ONNX_WEIGHTS_NAME,
TORCH_MINIMUM_VERSION,
-check_if_transformers_greater,
is_diffusers_available,
is_torch_onnx_support_available,
+is_transformers_version,
logging,
require_numpy_strictly_lower,
)
@@ -512,7 +512,7 @@ def export_pytorch(

model_kwargs = model_kwargs or {}
# num_logits_to_keep was added in transformers 4.45 and isn't added as inputs when exporting the model
if check_if_transformers_greater("4.44.99") and "num_logits_to_keep" in signature(model.forward).parameters.keys():
if is_transformers_version(">=", "4.44.99") and "num_logits_to_keep" in signature(model.forward).parameters.keys():
model_kwargs["num_logits_to_keep"] = 0

with torch.no_grad():
@@ -1105,7 +1105,7 @@ def onnx_export_from_model(
if isinstance(atol, dict):
atol = atol[task.replace("-with-past", "")]

if check_if_transformers_greater("4.44.99"):
if is_transformers_version(">=", "4.44.99"):
misplaced_generation_parameters = model.config._get_non_default_generation_parameters()
if (
isinstance(model, GenerationMixin)
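The `"4.44.99"` reference above (and in several files below) looks odd but is a common trick: `packaging` sorts pre-releases before the release, so a check against `">=", "4.45"` would reject 4.45 dev builds, while `">=", "4.44.99"` presumably accepts them. A quick demonstration:

```python
from packaging.version import parse

dev_build = parse("4.45.0.dev0")
print(dev_build >= parse("4.45.0"))   # False: a dev build precedes its release
print(dev_build >= parse("4.44.99"))  # True: hence the pin just below 4.45
```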
24 changes: 9 additions & 15 deletions optimum/exporters/onnx/model_configs.py
@@ -59,9 +59,9 @@
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedVisionConfig,
-check_if_diffusers_greater,
-check_if_transformers_greater,
is_diffusers_available,
+is_diffusers_version,
+is_transformers_version,
logging,
)
from ...utils.normalized_config import NormalizedConfigManager
@@ -310,7 +310,7 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):


# OPT does not take position_ids as input for transformers < v4.46, but needs it for transformers >= v4.46
if check_if_transformers_greater("4.45.99"):
if is_transformers_version(">=", "4.45.99"):

class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
@@ -370,8 +370,7 @@ class Phi3OnnxConfig(PhiOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")

def __init__(self, *args, **kwargs):
-# TODO : replace check_if_transformers_greater with is_transformers_available
-if check_if_transformers_greater("4.46.0") and not check_if_transformers_greater("4.46.1"):
+if is_transformers_version("==", "4.46.0"):
logger.error(
"Found transformers v4.46.0 while trying to exporting a Phi3 model, this specific version of transformers is not supported. "
"Please upgrade to v4.46.1 or higher, or downgrade your transformers version"
@@ -417,7 +416,7 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if check_if_transformers_greater("4.44"):
if is_transformers_version(">=", "4.44"):
super().add_past_key_values(inputs_or_outputs, direction)
else:
if direction not in ["inputs", "outputs"]:
@@ -1437,11 +1436,11 @@ def inputs(self):
common_inputs = super().inputs
common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
common_inputs["txt_ids"] = (
{0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"}
{0: "sequence_length"} if is_diffusers_version(">=", "0.31.0") else {0: "batch_size", 1: "sequence_length"}
)
common_inputs["img_ids"] = (
{0: "packed_height_width"}
if check_if_diffusers_greater("0.31.0")
if is_diffusers_version(">=", "0.31.0")
else {0: "batch_size", 1: "packed_height_width"}
)

@@ -1774,7 +1773,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis.

if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
if check_if_transformers_greater("4.43.0"):
if is_transformers_version(">=", "4.43.0"):
# since https://github.com/huggingface/transformers/pull/31166
common_inputs["cache_position"] = {0: "decoder_sequence_length"}

@@ -2461,12 +2460,7 @@ class Pix2StructOnnxConfig(OnnxSeq2SeqConfigWithPast):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
-# TODO : replace check_if_transformers_greater with is_transformers_available
-if (
-check_if_transformers_greater("4.46.0")
-and not check_if_transformers_greater("4.46.1")
-and self._behavior is ConfigBehavior.DECODER
-):
+if is_transformers_version("==", "4.46.0") and self._behavior is ConfigBehavior.DECODER:
logger.error(
"Found transformers v4.46.0 while trying to exporting a Pix2Struct model, this specific version of transformers is not supported. "
"Please upgrade to v4.46.1 or higher, or downgrade your transformers version"
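Both the Phi3 and Pix2Struct configs above collapse the old two-call range check (at least 4.46.0 but below 4.46.1) into a single exact match. For plain releases the two forms agree, as this sketch illustrates; post-releases such as 4.46.0.post1 would be the edge case where they differ:

```python
from packaging.version import parse

installed = parse("4.46.0")
old_style = installed >= parse("4.46.0") and not installed >= parse("4.46.1")
new_style = installed == parse("4.46.0")
assert old_style == new_style  # holds for plain releases
```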
15 changes: 7 additions & 8 deletions optimum/exporters/onnx/utils.py
@@ -20,14 +20,13 @@
from packaging import version
from transformers.utils import is_tf_available, is_torch_available

-from ...utils import (
-DIFFUSERS_MINIMUM_VERSION,
-ORT_QUANTIZE_MINIMUM_VERSION,
-check_if_diffusers_greater,
+from ...utils import DIFFUSERS_MINIMUM_VERSION, ORT_QUANTIZE_MINIMUM_VERSION, logging
+from ...utils.import_utils import (
+_diffusers_version,
is_diffusers_available,
-logging,
+is_diffusers_version,
+is_transformers_version,
)
-from ...utils.import_utils import _diffusers_version, check_if_transformers_greater
from ..utils import (
_get_submodels_and_export_configs,
)
@@ -52,7 +51,7 @@


if is_diffusers_available():
-if not check_if_diffusers_greater(DIFFUSERS_MINIMUM_VERSION.base_version):
+if not is_diffusers_version(">=", DIFFUSERS_MINIMUM_VERSION.base_version):
raise ImportError(
f"We found an older version of diffusers {_diffusers_version} but we require diffusers to be >= {DIFFUSERS_MINIMUM_VERSION}. "
"Please update diffusers by running `pip install --upgrade diffusers`"
@@ -90,7 +89,7 @@
}


if check_if_transformers_greater("4.45.99"):
if is_transformers_version(">=", "4.45.99"):
MODEL_TYPES_REQUIRING_POSITION_IDS.add("opt")


8 changes: 4 additions & 4 deletions optimum/onnxruntime/modeling_decoder.py
@@ -31,7 +31,7 @@

from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export
from ..onnx.utils import check_model_uses_external_data
-from ..utils import NormalizedConfigManager, check_if_transformers_greater
+from ..utils import NormalizedConfigManager, is_transformers_version
from ..utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST
from ..utils.save_utils import maybe_save_preprocessors
from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN
@@ -43,7 +43,7 @@
if TYPE_CHECKING:
from transformers import PretrainedConfig

if check_if_transformers_greater("4.25.0"):
if is_transformers_version(">=", "4.25.0"):
from transformers.generation import GenerationMixin
else:
from transformers.generation_utils import GenerationMixin # type: ignore # noqa: F401
@@ -149,7 +149,7 @@ def __init__(

self.generation_config = generation_config

if check_if_transformers_greater("4.44.99"):
if is_transformers_version(">=", "4.44.99"):
misplaced_generation_parameters = self.config._get_non_default_generation_parameters()
if len(misplaced_generation_parameters) > 0:
logger.warning(
@@ -562,7 +562,7 @@ def _from_pretrained(
)

# Since transformers 4.44, the bloom model has been updated to use the standard cache format
-use_old_bloom_modeling = not check_if_transformers_greater("4.44")
+use_old_bloom_modeling = not is_transformers_version(">=", "4.44")
for input_name in input_dims.keys():
if input_dims[input_name][0] == "batch_size x num_heads":
use_old_bloom_modeling = True
8 changes: 4 additions & 4 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -51,7 +51,7 @@
from transformers.modeling_outputs import ModelOutput

import onnxruntime as ort
-from optimum.utils import check_if_diffusers_greater
+from optimum.utils import is_diffusers_version

from ..exporters.onnx import main_export
from ..onnx.utils import _get_model_external_data_paths
@@ -75,7 +75,7 @@
)


if check_if_diffusers_greater("0.25.0"):
if is_diffusers_version(">=", "0.25.0"):
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
else:
from diffusers.models.vae import DiagonalGaussianDistribution # type: ignore
@@ -974,7 +974,7 @@ def __init__(self, *args, **kwargs):
)


if check_if_diffusers_greater("0.29.0"):
if is_diffusers_version(">=", "0.29.0"):
from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@@ -1006,7 +1006,7 @@ class ORTStableDiffusion3Img2ImgPipeline(ORTUnavailablePipeline):
MIN_VERSION = "0.29.0"


if check_if_diffusers_greater("0.30.0"):
if is_diffusers_version(">=", "0.30.0"):
from diffusers import FluxPipeline, StableDiffusion3InpaintPipeline

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
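The gates in this file pair each newer diffusers pipeline with an `ORTUnavailablePipeline` fallback carrying a `MIN_VERSION` class attribute (visible above). A sketch of that pattern, with hypothetical names:

```python
# Hypothetical stub illustrating the gating pattern: pipeline classes that need
# a newer diffusers are replaced by a placeholder that raises a clear error.
class UnavailablePipelineStub:
    MIN_VERSION = "0.29.0"  # hypothetical threshold, mirroring the diff above

    def __init__(self, *args, **kwargs):
        raise ImportError(
            f"This pipeline requires diffusers >= {self.MIN_VERSION}. "
            "Please upgrade it with `pip install --upgrade diffusers`."
        )
```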
8 changes: 4 additions & 4 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -43,7 +43,7 @@

from ..exporters.onnx import main_export
from ..onnx.utils import _get_external_data_paths
-from ..utils import check_if_transformers_greater
+from ..utils import is_transformers_version
from ..utils.file_utils import validate_file_exists
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .base import ORTDecoderForSeq2Seq, ORTEncoder
@@ -64,13 +64,13 @@
)


if check_if_transformers_greater("4.25.0"):
if is_transformers_version(">=", "4.25.0"):
from transformers.generation import GenerationMixin
else:
from transformers.generation_utils import GenerationMixin # type: ignore


if check_if_transformers_greater("4.43.0"):
if is_transformers_version(">=", "4.43.0"):
from transformers.cache_utils import EncoderDecoderCache
else:
EncoderDecoderCache = dict
@@ -705,7 +705,7 @@ def show_deprecated_argument(arg_name):
generation_config = GenerationConfig.from_model_config(config)
self.generation_config = generation_config

if check_if_transformers_greater("4.44.99"):
if is_transformers_version(">=", "4.44.99"):
misplaced_generation_parameters = self.config._get_non_default_generation_parameters()
if len(misplaced_generation_parameters) > 0:
logger.warning(
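One detail worth noting in this file: when transformers < 4.43 lacks `EncoderDecoderCache`, the name is aliased to `dict` so later isinstance checks and annotations still resolve. A try/except import, sketched below, would achieve the same effect as the version gate used in the diff:

```python
# Alternative sketch of the fallback-alias pattern used above: bind the missing
# class name to dict on older transformers so downstream code keeps working.
try:
    from transformers.cache_utils import EncoderDecoderCache
except ImportError:  # transformers < 4.43
    EncoderDecoderCache = dict
```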
6 changes: 3 additions & 3 deletions optimum/onnxruntime/trainer.py
@@ -83,7 +83,7 @@
)

from ..utils import logging
-from ..utils.import_utils import check_if_transformers_greater
+from ..utils.import_utils import is_transformers_version
from .training_args import ORTOptimizerNames, ORTTrainingArguments
from .utils import (
is_onnxruntime_training_available,
@@ -93,7 +93,7 @@
if is_apex_available():
from apex import amp

if check_if_transformers_greater("4.33"):
if is_transformers_version(">=", "4.33"):
from transformers.integrations.deepspeed import (
deepspeed_init,
deepspeed_load_checkpoint,
@@ -102,7 +102,7 @@
else:
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled

if check_if_transformers_greater("4.39"):
if is_transformers_version(">=", "4.39"):
from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available

if is_torch_tpu_xla_available():
4 changes: 2 additions & 2 deletions optimum/onnxruntime/trainer_seq2seq.py
@@ -22,7 +22,7 @@
from transformers.trainer_utils import PredictionOutput
from transformers.utils import is_accelerate_available, logging

-from ..utils.import_utils import check_if_transformers_greater
+from ..utils.import_utils import is_transformers_version
from .trainer import ORTTrainer


@@ -33,7 +33,7 @@
"The package `accelerate` is required to use the ORTTrainer. Please install it following https://huggingface.co/docs/accelerate/basic_tutorials/install."
)

if check_if_transformers_greater("4.33"):
if is_transformers_version(">=", "4.33"):
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
else:
from transformers.deepspeed import is_deepspeed_zero3_enabled
6 changes: 3 additions & 3 deletions optimum/onnxruntime/training_args.py
@@ -44,13 +44,13 @@
)
from transformers.utils.generic import strtobool

-from ..utils.import_utils import check_if_transformers_greater
+from ..utils.import_utils import is_transformers_version


if is_torch_available():
import torch

-if is_accelerate_available() and check_if_transformers_greater("4.38.0"):
+if is_accelerate_available() and is_transformers_version(">=", "4.38.0"):
from transformers.trainer_pt_utils import AcceleratorConfig


@@ -481,7 +481,7 @@ def __post_init__(self):
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")

-if is_accelerate_available() and check_if_transformers_greater("4.38.0"):
+if is_accelerate_available() and is_transformers_version(">=", "4.38.0"):
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
if self.accelerator_config is None:
self.accelerator_config = AcceleratorConfig()