Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check transformers version in BLOOM for inference v1 #6766

Merged
merged 8 commits into from
Jan 7, 2025
12 changes: 12 additions & 0 deletions deepspeed/module_inject/containers/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@
class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer):

def __init__(self, **kwargs):
# Check transformers version, error if > 4.43.4 (breaks at 4.44.0)
from importlib.metadata import version
v_transformers = version('transformers')
vers = v_transformers.split('.')
major = int(vers[0])
minor = int(vers[1])
if major > 4 or (major == 4 and minor > 43):
import sys
sys.exit(
f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported."
)

super().__init__(**kwargs)

# All model specific things should be defined here instead of the base class.
Expand Down