Skip to content

Commit

Permalink
[CI] Skip EETQ tests while package is broken with latest transformers (
Browse files Browse the repository at this point in the history
…#34854)

* CI Skip EETQ tests while package is broken

EETQ tries to import the shard_checkpoint function from transformers but
the function has been removed. Therefore, trying to use EETQ currently
results in an import error. This fix results in EETQ tests being skipped
if there is an import error.

The issue has been reported to EETQ:

NetEase-FuXi/EETQ#34

* Raise helpful error when trying to use eetq

* Forget to raise the error in else clause
  • Loading branch information
BenjaminBossan authored Nov 22, 2024
1 parent 861758e commit 286ffaa
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
14 changes: 14 additions & 0 deletions src/transformers/quantizers/quantizer_eetq.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ def validate_environment(self, *args, **kwargs):
"Please install the latest version of eetq from : https://github.com/NetEase-FuXi/EETQ"
)

try:
import eetq # noqa: F401
except ImportError as exc:
if "shard_checkpoint" in str(exc):
# EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed
# shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34.
# TODO: Update message once eetq releases a fix
raise ImportError(
"You are using a version of EETQ that is incompatible with the current transformers version. "
"Either downgrade transformers to <= v4.46.3 or, if available, upgrade EETQ to > v1.0.0."
) from exc
else:
raise

if not is_accelerate_available():
raise ImportError("Loading an EETQ quantized model requires accelerate (`pip install accelerate`)")

Expand Down
12 changes: 11 additions & 1 deletion src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,17 @@ def require_eetq(test_case):
"""
Decorator marking a test that requires eetq
"""
return unittest.skipUnless(is_eetq_available(), "test requires eetq")(test_case)
eetq_available = is_eetq_available()
if eetq_available:
try:
import eetq # noqa: F401
except ImportError as exc:
if "shard_checkpoint" in str(exc):
# EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed
# shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34.
# TODO: Remove once eetq releases a fix and this release is used in CI
eetq_available = False
return unittest.skipUnless(eetq_available, "test requires eetq")(test_case)


def require_av(test_case):
Expand Down

0 comments on commit 286ffaa

Please sign in to comment.