[Awq] Enable the possibility to skip quantization for some target modules (huggingface#27950)

* v1

* add docstring

* add tests

* add awq 0.1.8

* oops

* fix test
younesbelkada authored and Saibo-creator committed Jan 4, 2024
1 parent 30493fa commit b5d8e26
Showing 4 changed files with 42 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docker/transformers-all-latest-gpu/Dockerfile
@@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
RUN python3 -m pip install --no-cache-dir einops

# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp38-cp38-linux_x86_64.whl
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl

# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
3 changes: 3 additions & 0 deletions src/transformers/modeling_utils.py
@@ -3575,6 +3575,9 @@ def from_pretrained(
            if quantization_config is None:
                quantization_config = AwqConfig.from_dict(config.quantization_config)

            if quantization_config.modules_to_not_convert is not None:
                modules_to_not_convert.extend(quantization_config.modules_to_not_convert)

            model, has_been_replaced = replace_with_awq_linear(
                model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
            )
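The new lines above merge any `modules_to_not_convert` entries coming from the model's `AwqConfig` into the list handed to `replace_with_awq_linear`, so the listed layers keep their original `nn.Linear` weights. As a rough illustration of the idea only (not the actual `replace_with_awq_linear` implementation), a name-based skip during layer replacement could look like the sketch below; `make_quant_linear` is a hypothetical factory standing in for the AWQ quantized-linear constructor.

```python
# Illustrative sketch only, assuming a hypothetical `make_quant_linear` factory.
# Recursively swap nn.Linear layers for quantized ones, skipping any submodule
# whose name matches an entry in `modules_to_not_convert`.
from typing import Callable, List

import torch.nn as nn


def replace_linears_except(
    model: nn.Module,
    modules_to_not_convert: List[str],
    make_quant_linear: Callable[[nn.Linear], nn.Module],
) -> nn.Module:
    for name, module in model.named_children():
        if any(skip in name for skip in modules_to_not_convert):
            continue  # leave this module (and its children) in full precision
        if isinstance(module, nn.Linear):
            setattr(model, name, make_quant_linear(module))
        else:
            replace_linears_except(module, modules_to_not_convert, make_quant_linear)
    return model
```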
19 changes: 19 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -564,6 +564,10 @@ class AwqConfig(QuantizationConfigMixin):
            The maximum sequence length to generate when using fusing.
        modules_to_fuse (`dict`, *optional*, defaults to `None`):
            Overwrite the natively supported fusing scheme with the one specified by the user.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require some
            modules to be left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
            Note that you cannot quantize directly with transformers; please refer to the `AutoAWQ` documentation
            for quantizing HF models.
    """

    def __init__(
@@ -576,6 +580,7 @@ def __init__(
        do_fuse: Optional[bool] = None,
        fuse_max_seq_len: Optional[int] = None,
        modules_to_fuse: Optional[dict] = None,
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.AWQ
@@ -586,6 +591,7 @@ def __init__(
        self.version = version
        self.backend = backend
        self.fuse_max_seq_len = fuse_max_seq_len
        self.modules_to_not_convert = modules_to_not_convert

        self.modules_to_fuse = modules_to_fuse
        if do_fuse is None:
@@ -638,6 +644,19 @@ def post_init(self):
                    f"Your current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
                )

        if self.modules_to_not_convert is not None:
            awq_version_supports_non_conversion = False
            MIN_AWQ_VERSION = "0.1.8"
            if is_auto_awq_available():
                awq_version_supports_non_conversion = version.parse(
                    importlib.metadata.version("autoawq")
                ) >= version.parse(MIN_AWQ_VERSION)

            if not awq_version_supports_non_conversion:
                raise ValueError(
                    f"Your current version of `autoawq` does not support module quantization skipping, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
                )

        if self.do_fuse and self.modules_to_fuse is not None:
            required_keys = [
                "hidden_size",
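With the new `modules_to_not_convert` field in place, a checkpoint exported with some layers kept in full precision loads as usual: the value is normally read from the checkpoint's quantization config, and, per the `modeling_utils.py` change above, can also be passed explicitly. A hedged usage sketch, using the repo id from the new test:

```python
# Usage sketch only. The checkpoint below (used by the new test) was quantized with
# AutoAWQ while keeping `k_proj` in full precision. Passing the config explicitly is
# optional, since `modules_to_not_convert` is normally picked up from the checkpoint.
from transformers import AutoModelForCausalLM, AwqConfig

quantization_config = AwqConfig(bits=4, modules_to_not_convert=["k_proj"])

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/opt-125m-awq-no-k-proj",
    quantization_config=quantization_config,
    device_map="cuda:0",
)

# `k_proj` stays a plain torch.nn.Linear, the other projections use AWQ quantized layers.
print(type(model.model.decoder.layers[0].self_attn.k_proj))
print(type(model.model.decoder.layers[0].self_attn.v_proj))
```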
19 changes: 19 additions & 0 deletions tests/quantization/autoawq/test_awq.py
@@ -88,6 +88,7 @@ def test_from_dict(self):
class AwqTest(unittest.TestCase):
    model_name = "TheBloke/Mistral-7B-v0.1-AWQ"
    dummy_transformers_model_name = "bigscience/bloom-560m"
    model_with_no_k_proj_quantized = "hf-internal-testing/opt-125m-awq-no-k-proj"

    input_text = "Hello my name is"

@@ -223,6 +224,24 @@ def test_quantized_model_multi_gpu(self):

        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_quantized_model_no_k_proj_quantized(self):
        """
        Simple test that checks if the quantized model is working properly with a model whose `k_proj`
        module has been kept in its original (unquantized) precision
        """
        dummy_input = torch.LongTensor([[0, 1, 0]]).to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_with_no_k_proj_quantized).to(torch_device)

        self.assertTrue(isinstance(quantized_model.model.decoder.layers[0].self_attn.k_proj, torch.nn.Linear))
        self.assertFalse(isinstance(quantized_model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear))

        EXPECTED_OUTPUT = torch.LongTensor([[0, 1, 0, 50118, 50118, 133, 248, 12, 134, 16, 10, 372, 2031]]).to(
            torch_device
        )

        output = quantized_model.generate(dummy_input, max_new_tokens=10)
        self.assertTrue((EXPECTED_OUTPUT == output).all())


@slow
@require_torch_gpu
