From 286ffaaf0ab981f3530d0ac34d1b172efa5c03db Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 22 Nov 2024 17:13:30 +0100 Subject: [PATCH] [CI] Skip EETQ tests while package is broken with latest transformers (#34854) * CI Skip EETQ tests while package is broken EETQ tries to import the shard_checkpoint function from transformers but the function has been removed. Therefore, trying to use EETQ currently results in an import error. This fix results in EETQ tests being skipped if there is an import error. The issue has been reported to EETQ: https://github.com/NetEase-FuXi/EETQ/issues/34 * Raise helpful error when trying to use eetq * Forget to raise the error in else clause --- src/transformers/quantizers/quantizer_eetq.py | 14 ++++++++++++++ src/transformers/testing_utils.py | 12 +++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_eetq.py b/src/transformers/quantizers/quantizer_eetq.py index 602df62c012a18..7dfce75c373ad7 100644 --- a/src/transformers/quantizers/quantizer_eetq.py +++ b/src/transformers/quantizers/quantizer_eetq.py @@ -53,6 +53,20 @@ def validate_environment(self, *args, **kwargs): "Please install the latest version of eetq from : https://github.com/NetEase-FuXi/EETQ" ) + try: + import eetq # noqa: F401 + except ImportError as exc: + if "shard_checkpoint" in str(exc): + # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed + # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34. + # TODO: Update message once eetq releases a fix + raise ImportError( + "You are using a version of EETQ that is incompatible with the current transformers version. " + "Either downgrade transformers to <= v4.46.3 or, if available, upgrade EETQ to > v1.0.0." + ) from exc + else: + raise + if not is_accelerate_available(): raise ImportError("Loading an EETQ quantized model requires accelerate (`pip install accelerate`)") diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 49c2aefa09260e..25d837ccec0fbe 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1143,7 +1143,17 @@ def require_eetq(test_case): """ Decorator marking a test that requires eetq """ - return unittest.skipUnless(is_eetq_available(), "test requires eetq")(test_case) + eetq_available = is_eetq_available() + if eetq_available: + try: + import eetq # noqa: F401 + except ImportError as exc: + if "shard_checkpoint" in str(exc): + # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed + # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34. + # TODO: Remove once eetq releases a fix and this release is used in CI + eetq_available = False + return unittest.skipUnless(eetq_available, "test requires eetq")(test_case) def require_av(test_case):