Exllama kernels support for AWQ models (huggingface#28634)
* added exllama kernels support for awq models

* doc

* style

* Update src/transformers/modeling_utils.py

Co-authored-by: Marc Sun <[email protected]>

* refactor

* moved exllama post init to after device dispatching

* bump autoawq version

* added exllama test

* style

* configurable exllama kernels

* copy exllama_config from gptq

* moved exllama version check to post init

* moved to quantization dockerfile

---------

Co-authored-by: Marc Sun <[email protected]>
IlyasMoutawwakil and SunMarc authored Mar 5, 2024
1 parent 81c8191 commit 4fc708f
Showing 6 changed files with 127 additions and 13 deletions.
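In practice, this commit lets an existing AWQ (GEMM-packed) checkpoint run on the Exllama/ExllamaV2 kernels by passing `version="exllama"` to `AwqConfig`. A minimal usage sketch, mirroring the test added below — it assumes `autoawq >= 0.2.0` is installed and a CUDA device is available, and the model id is only an illustrative placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

# Illustrative placeholder: any GEMM-packed AWQ checkpoint should work here.
model_id = "TheBloke/Mistral-7B-Instruct-v0.1-AWQ"

# Defaults to the ExllamaV2 kernels with max_input_len=2048 and max_batch_size=8.
quantization_config = AwqConfig(version="exllama")

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
).to("cuda")

inputs = tokenizer("Hello my name is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Note that a model loaded this way cannot be saved again: the quantizer change below makes `is_serializable` return `False` for `AWQLinearVersion.EXLLAMA`.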
2 changes: 1 addition & 1 deletion docker/transformers-quantization-latest-gpu/Dockerfile
@@ -43,7 +43,7 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2

# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl

# When installing in editable mode, `transformers` is not recognized as a package.
# this line must be added in order for python to be aware of transformers.
12 changes: 10 additions & 2 deletions src/transformers/integrations/__init__.py
@@ -18,7 +18,11 @@

_import_structure = {
"aqlm": ["replace_with_aqlm_linear"],
"awq": ["fuse_awq_modules", "replace_with_awq_linear"],
"awq": [
"fuse_awq_modules",
"post_init_awq_exllama_modules",
"replace_with_awq_linear",
],
"bitsandbytes": [
"get_keys_to_not_convert",
"replace_8bit_linear",
@@ -82,7 +86,11 @@

if TYPE_CHECKING:
from .aqlm import replace_with_aqlm_linear
from .awq import fuse_awq_modules, replace_with_awq_linear
from .awq import (
fuse_awq_modules,
post_init_awq_exllama_modules,
replace_with_awq_linear,
)
from .bitsandbytes import (
get_keys_to_not_convert,
replace_8bit_linear,
59 changes: 53 additions & 6 deletions src/transformers/integrations/awq.py
@@ -15,7 +15,12 @@
from ..activations import ACT2FN
from ..modeling_utils import PreTrainedModel
from ..utils import is_auto_awq_available, is_torch_available
from ..utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion
from ..utils.quantization_config import (
AwqBackendPackingMethod,
AwqConfig,
AWQLinearVersion,
ExllamaVersion,
)


if is_torch_available():
@@ -91,13 +96,30 @@ def replace_with_awq_linear(
)

if backend == AwqBackendPackingMethod.AUTOAWQ:
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
elif backend == AwqBackendPackingMethod.LLMAWQ:
from awq.quantize.qmodule import WQLinear
if quantization_config.version == AWQLinearVersion.GEMM:
from awq.modules.linear.gemm import WQLinear_GEMM

if backend == AwqBackendPackingMethod.AUTOAWQ:
target_cls = WQLinear_GEMM if quantization_config.version == AWQLinearVersion.GEMM else WQLinear_GEMV
target_cls = WQLinear_GEMM
elif quantization_config.version == AWQLinearVersion.GEMV:
from awq.modules.linear.gemv import WQLinear_GEMV

target_cls = WQLinear_GEMV
elif quantization_config.version == AWQLinearVersion.EXLLAMA:
if quantization_config.exllama_config["version"] == ExllamaVersion.ONE:
from awq.modules.linear.exllama import WQLinear_Exllama

target_cls = WQLinear_Exllama
elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO:
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2

target_cls = WQLinear_ExllamaV2
else:
raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}")
else:
raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}")
else:
from awq.quantize.qmodule import WQLinear

target_cls = WQLinear

for name, module in model.named_children():
@@ -372,3 +394,28 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
setattr(parent, child_name, fused_attention_layer.to(previous_device))

del q_proj, k_proj, v_proj, o_proj


def post_init_awq_exllama_modules(model, exllama_config):
"""
Runs post init for Exllama layers which performs:
- Weights unpacking, reordering and repacking
- Devices scratch space allocation
"""

if exllama_config["version"] == ExllamaVersion.ONE:
from awq.modules.linear.exllama import exllama_post_init

model = exllama_post_init(model)
elif exllama_config["version"] == ExllamaVersion.TWO:
from awq.modules.linear.exllamav2 import exllamav2_post_init

model = exllamav2_post_init(
model,
max_input_len=exllama_config["max_input_len"],
max_batch_size=exllama_config["max_batch_size"],
)
else:
raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")

return model
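
The quantizer (next file) calls this hook from `_process_model_after_weight_loading`, i.e. after the weights have been dispatched to their devices. As a hedged sketch only — `finish_exllama_setup` is a hypothetical helper, not part of this commit — this is how the hook expects its `exllama_config` dict to be shaped, using the same defaults that `AwqConfig.post_init` fills in:

```python
from transformers.integrations import post_init_awq_exllama_modules
from transformers.utils.quantization_config import ExllamaVersion


def finish_exllama_setup(model, version=ExllamaVersion.TWO, max_input_len=2048, max_batch_size=8):
    """Hypothetical wrapper: run the Exllama post-init on an already
    device-dispatched AWQ model whose linear layers were replaced by
    replace_with_awq_linear."""
    exllama_config = {
        "version": version,                # ExllamaVersion.ONE or ExllamaVersion.TWO
        "max_input_len": max_input_len,    # scratch-space sizing, used by the v2 kernels
        "max_batch_size": max_batch_size,  # scratch-space sizing, used by the v2 kernels
    }
    return post_init_awq_exllama_modules(model, exllama_config)
```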
11 changes: 11 additions & 0 deletions src/transformers/quantizers/quantizer_awq.py
@@ -23,6 +23,7 @@
from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_auto_awq_available, is_torch_available, logging
from ..utils.quantization_config import AWQLinearVersion


if is_torch_available():
@@ -98,12 +99,22 @@ def _process_model_after_weight_loading(self, model):
model = fuse_awq_modules(model, self.quantization_config)
model._awq_is_fused = True # TODO: consider storing this flag in model.config instead

if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
from ..integrations import post_init_awq_exllama_modules

model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)

@property
def is_serializable(self):
# AWQ through auto-awq has been always serializable, except if the model is fused.
if self.quantization_config.do_fuse:
logger.warning("You cannot save an AWQ model that uses fused modules!")
return False

if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
logger.warning("You cannot save an AWQ model that uses Exllama backend!")
return False

return True

@property
42 changes: 38 additions & 4 deletions src/transformers/utils/quantization_config.py
@@ -44,6 +44,7 @@ class QuantizationMethod(str, Enum):
class AWQLinearVersion(str, Enum):
GEMM = "gemm"
GEMV = "gemv"
EXLLAMA = "exllama"

@staticmethod
def from_str(version: str):
@@ -52,6 +53,8 @@ def from_str(version: str):
return AWQLinearVersion.GEMM
elif version == "gemv":
return AWQLinearVersion.GEMV
elif version == "exllama":
return AWQLinearVersion.EXLLAMA
else:
raise ValueError(f"Unknown AWQLinearVersion {version}")

@@ -606,7 +609,7 @@ class AwqConfig(QuantizationConfigMixin):
Whether to use zero point quantization.
version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
The version of the quantization algorithm to use. GEMM is better for big batch_size (e.g. >= 8) otherwise,
GEMV is better (e.g. < 8 )
GEMV is better (e.g. < 8 ). GEMM models are compatible with Exllama kernels.
backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
The quantization backend. Some models might be quantized using `llm-awq` backend. This is useful for users
that quantize their own models using `llm-awq` library.
@@ -620,6 +623,10 @@ class AwqConfig(QuantizationConfigMixin):
The list of modules to not quantize, useful for quantizing models that explicitly require to have
some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
Note you cannot quantize directly with transformers, please refer to `AutoAWQ` documentation for quantizing HF models.
exllama_config (`Dict[str, Any]`, *optional*):
You can specify the version of the exllama kernel through the `version` key, the maximum sequence
length through the `max_input_len` key, and the maximum batch size through the `max_batch_size` key.
Defaults to `{"version": 2, "max_input_len": 2048, "max_batch_size": 8}` if unset.
"""

def __init__(
@@ -633,6 +640,7 @@ def __init__(
fuse_max_seq_len: Optional[int] = None,
modules_to_fuse: Optional[dict] = None,
modules_to_not_convert: Optional[List] = None,
exllama_config: Optional[Dict[str, int]] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.AWQ
@@ -644,6 +652,7 @@ def __init__(
self.backend = backend
self.fuse_max_seq_len = fuse_max_seq_len
self.modules_to_not_convert = modules_to_not_convert
self.exllama_config = exllama_config

self.modules_to_fuse = modules_to_fuse
if do_fuse is None:
@@ -667,9 +676,9 @@ def post_init(self):
)

self.version = AWQLinearVersion.from_str(self.version)
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA]:
raise ValueError(
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA] - not recognized version {self.version}"
)

if self.backend == AwqBackendPackingMethod.LLMAWQ:
@@ -724,9 +733,34 @@ def post_init(self):
f"Required fields are missing in the fusing mapping, required fields are {required_keys}"
)

if self.version == AWQLinearVersion.EXLLAMA:
awq_version_supports_exllama = False
MIN_AWQ_VERSION = "0.2.0"
if is_auto_awq_available():
awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
MIN_AWQ_VERSION
)

if not awq_version_supports_exllama:
raise ValueError(
f"You current version of `autoawq` does not support exllama backend, "
f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
)

if self.exllama_config is None:
self.exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8}
else:
if "version" not in self.exllama_config:
raise ValueError("`exllama_config` needs to have a `version` key.")
elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
exllama_version = self.exllama_config["version"]
raise ValueError(
f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
)

def get_loading_attributes(self):
attibutes_dict = copy.deepcopy(self.__dict__)
loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
loading_attibutes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
return loading_attibutes_dict

14 changes: 14 additions & 0 deletions tests/quantization/autoawq/test_awq.py
@@ -192,6 +192,20 @@ def test_quantized_model_bf16(self):
output = quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_BF16)

def test_quantized_model_exllama(self):
"""
Simple test that checks if the quantized model is working properly with exllama backend
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

quantization_config = AwqConfig(version="exllama")
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, quantization_config=quantization_config
).to(torch_device)

output = quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

def test_quantized_model_no_device_map(self):
"""
Simple test that checks if the quantized model is working properly
