Skip to content

Commit

Permalink
final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada committed Nov 22, 2023
1 parent 0bd1b0c commit 61db430
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docker/transformers-all-latest-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
RUN python3 -m pip install --no-cache-dir einops

# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/autoawq-0.1.6+cu118-cp38-cp38-linux_x86_64.whl
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp38-cp38-linux_x86_64.whl

# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
Expand Down
21 changes: 19 additions & 2 deletions src/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from packaging import version

from ..utils import is_torch_available, logging
from ..utils import is_auto_awq_available, is_torch_available, logging


if is_torch_available():
Expand Down Expand Up @@ -574,7 +574,7 @@ def __init__(

self.modules_to_fuse = modules_to_fuse
if do_fuse is None:
self.do_fuse = modules_to_fuse is not None and len(modules_to_fuse > 0)
self.do_fuse = modules_to_fuse is not None and len(modules_to_fuse) > 0
else:
self.do_fuse = do_fuse
self.fuse_max_seq_len = fuse_max_seq_len
Expand Down Expand Up @@ -610,6 +610,23 @@ def post_init(self):
"You cannot enable fused modules without specifying a `fuse_max_seq_len`, make sure to pass a valid `fuse_max_seq_len` for your usecase"
)

if self.do_fuse:
awq_version_supports_fusing = False
MIN_AWQ_VERSION = "0.1.7"
if is_auto_awq_available():
# `version.parse(importlib.metadata.version("awq"))` always returns `<Version('0.1.0')>`
# (for reasons not yet understood), so the package metadata cannot be used for this check.
# Instead, import `awq` and read the version from `awq.__version__`.
import awq

awq_version = awq.__version__
awq_version_supports_fusing = version.parse(awq_version) >= version.parse(MIN_AWQ_VERSION)

if not awq_version_supports_fusing:
raise ValueError(
"You current version of `autoawq` does not support module fusing, please upgrade `autoawq` package."
)

if self.do_fuse and self.modules_to_fuse is not None:
required_keys = [
"hidden_size",
Expand Down

0 comments on commit 61db430

Please sign in to comment.