Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tf available and version #2154

Merged
merged 3 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@
from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
from ...utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
from ...utils.doc import add_dynamic_docstring
from ...utils.import_utils import is_onnx_available, is_onnxruntime_available, is_transformers_version
from ...utils.import_utils import (
is_onnx_available,
is_onnxruntime_available,
is_torch_version,
is_transformers_version,
)
from ..base import ExportConfig
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import ModelPatcher, Seq2SeqModelPatcher
Expand Down Expand Up @@ -386,9 +391,8 @@ def is_torch_support_available(self) -> bool:
`bool`: Whether the installed version of PyTorch is compatible with the model.
"""
if is_torch_available():
from ...utils import torch_version
return is_torch_version(">=", self.MIN_TORCH_VERSION.base_version)

return torch_version >= self.MIN_TORCH_VERSION
return False

@property
Expand Down
7 changes: 3 additions & 4 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,17 +851,16 @@ def export(
)

if is_torch_available() and isinstance(model, nn.Module):
from ...utils import torch_version
from ...utils.import_utils import _torch_version

if not is_torch_onnx_support_available():
raise MinimumVersionError(
f"Unsupported PyTorch version, minimum required is {TORCH_MINIMUM_VERSION}, got: {torch_version}"
f"Unsupported PyTorch version, minimum required is {TORCH_MINIMUM_VERSION}, got: {_torch_version}"
)

if not config.is_torch_support_available:
raise MinimumVersionError(
f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION},"
f" got: {torch.__version__}"
f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION}, got: {_torch_version}"
)

export_output = export_pytorch(
Expand Down
2 changes: 2 additions & 0 deletions optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
is_onnxruntime_available,
is_pydantic_available,
is_sentence_transformers_available,
is_tf_available,
is_timm_available,
is_torch_available,
is_torch_onnx_support_available,
is_torch_version,
is_transformers_available,
Expand Down
65 changes: 52 additions & 13 deletions optimum/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Import utilities."""

import importlib.metadata as importlib_metadata
import importlib.metadata
import importlib.util
import inspect
import operator as op
Expand All @@ -23,7 +23,6 @@

import numpy as np
from packaging import version
from transformers.utils import is_torch_available


TORCH_MINIMUM_VERSION = version.parse("1.11.0")
Expand Down Expand Up @@ -64,14 +63,46 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_datasets_available = _is_package_available("datasets")
_diffusers_available, _diffusers_version = _is_package_available("diffusers", return_version=True)
_transformers_available, _transformers_version = _is_package_available("transformers", return_version=True)
_torch_available, _torch_version = _is_package_available("torch", return_version=True)

# importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.)
_onnxruntime_available = _is_package_available("onnxruntime", return_version=False)


# TODO : Remove
torch_version = None
if is_torch_available():
torch_version = version.parse(importlib_metadata.version("torch"))
torch_version = version.parse(importlib.metadata.version("torch")) if _torch_available else None


# Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
# with tensorflow-cpu to make sure it still works!
_tf_available = importlib.util.find_spec("tensorflow") is not None
_tf_version = None
if _tf_available:
candidates = (
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
"tf-nightly-rocm",
"intel-tensorflow",
"intel-tensorflow-avx512",
"tensorflow-rocm",
"tensorflow-macos",
"tensorflow-aarch64",
)
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
for pkg in candidates:
try:
_tf_version = importlib.metadata.version(pkg)
break
except importlib.metadata.PackageNotFoundError:
pass
_tf_available = _tf_version is not None
if _tf_available:
if version.parse(_tf_version) < version.parse("2"):
_tf_available = False
_tf_version = _tf_version or "N/A"


# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
Expand All @@ -91,7 +122,7 @@ def compare_versions(library_or_version: Union[str, version.Version], operation:
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
operation = STR_OPERATION_TO_FUNC[operation]
if isinstance(library_or_version, str):
library_or_version = version.parse(importlib_metadata.version(library_or_version))
library_or_version = version.parse(importlib.metadata.version(library_or_version))
return operation(library_or_version, version.parse(requirement_version))


Expand All @@ -117,15 +148,15 @@ def is_torch_version(operation: str, reference_version: str):
"""
Compare the current torch version to a given reference with an operation.
"""
if not is_torch_available():
if not _torch_available:
return False

import torch

return compare_versions(version.parse(version.parse(torch.__version__).base_version), operation, reference_version)


_is_torch_onnx_support_available = is_torch_available() and is_torch_version(">=", TORCH_MINIMUM_VERSION.base_version)
_is_torch_onnx_support_available = _torch_available and is_torch_version(">=", TORCH_MINIMUM_VERSION.base_version)


def is_torch_onnx_support_available():
Expand Down Expand Up @@ -176,9 +207,17 @@ def is_transformers_available():
return _transformers_available


def is_torch_available():
return _torch_available


def is_tf_available():
return _tf_available


def is_auto_gptq_available():
if _auto_gptq_available:
v = version.parse(importlib_metadata.version("auto_gptq"))
v = version.parse(importlib.metadata.version("auto_gptq"))
if v >= AUTOGPTQ_MINIMUM_VERSION:
return True
else:
Expand All @@ -189,7 +228,7 @@ def is_auto_gptq_available():

def is_gptqmodel_available():
if _gptqmodel_available:
v = version.parse(importlib_metadata.version("gptqmodel"))
v = version.parse(importlib.metadata.version("gptqmodel"))
if v >= GPTQMODEL_MINIMUM_VERSION:
return True
else:
Expand Down Expand Up @@ -260,10 +299,10 @@ def check_if_torch_greater(target_version: str) -> bool:
Returns:
bool: whether the check is True or not.
"""
if not is_torch_available():
if not _torch_available:
return False

return torch_version >= version.parse(target_version)
return version.parse(_torch_version) >= version.parse(target_version)


@contextmanager
Expand Down
3 changes: 1 addition & 2 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from typing import Any, List, Optional, Tuple, Union

import numpy as np
from transformers.utils import is_tf_available, is_torch_available

from ..utils import is_diffusers_version, is_transformers_version
from ..utils import is_diffusers_version, is_tf_available, is_torch_available, is_transformers_version
from .normalized_config import (
NormalizedConfig,
NormalizedEncoderDecoderConfig,
Expand Down
8 changes: 3 additions & 5 deletions tests/exporters/onnx/test_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,18 @@ def _onnx_export(
model.config.pad_token_id = 0

if is_torch_available():
from optimum.utils import torch_version
from optimum.utils.import_utils import _torch_version, _transformers_version

if not onnx_config.is_transformers_support_available:
import transformers

pytest.skip(
"Skipping due to incompatible Transformers version. Minimum required is"
f" {onnx_config.MIN_TRANSFORMERS_VERSION}, got: {transformers.__version__}"
f" {onnx_config.MIN_TRANSFORMERS_VERSION}, got: {_transformers_version}"
)

if not onnx_config.is_torch_support_available:
pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.MIN_TORCH_VERSION}, got: {torch_version}"
f" {onnx_config.MIN_TORCH_VERSION}, got: {_torch_version}"
)

atol = onnx_config.ATOL_FOR_VALIDATION
Expand Down
Loading