Commit

Update awq.py
younesbelkada authored Apr 18, 2024
1 parent 63c5e27 commit 784c377
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions src/transformers/integrations/awq.py
@@ -229,6 +229,8 @@ def fuse_awq_modules(model, quantization_config):
     else:
         raise ValueError("Fusing is only supported for the AutoAWQ backend")
 
+    fused_attention_modules = []
+
     for name, module in model.named_modules():
         if modules_to_not_convert is not None:
             if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert):
@@ -241,7 +243,27 @@ def fuse_awq_modules(model, quantization_config):
         _fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)
 
         # Replace attention layers
-        _fuse_awq_attention_layers(model, module, modules_to_fuse, name, QuantAttentionFused)
+        attention_has_been_fused = _fuse_awq_attention_layers(
+            model, module, modules_to_fuse, name, QuantAttentionFused
+        )
+
+        if attention_has_been_fused:
+            fused_attention_modules.append(name)
+
+    # For AWQ fused + Llama we need to set `config._attn_implementation` = "custom" to avoid unexpected behavior and pass
+    # `None` attention mask to the fused attention modules as now the attention mask is dropped by our models and dealt
+    # by the `AttentionMaskConverter` module.
+    if len(fused_attention_modules) > 0:
+        fused_attention_parent_modules = set(
+            fused_attention_module.split(".")[0] for fused_attention_module in fused_attention_modules
+        )
+        for module_name, module in model.named_modules():
+            if any(
+                module_name in fused_attention_parent_module
+                for fused_attention_parent_module in fused_attention_parent_modules
+            ):
+                if hasattr(module, "config") and hasattr(module.config, "_attn_implementation"):
+                    module.config._attn_implementation = "custom"
     return model
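
To make the new parent-flagging loop easier to follow, here is a small standalone sketch (the module names are made up for illustration) of how the fused attention module names are reduced to their top-level parents and matched against other module names:

```python
# Toy walk-through of the flagging logic added above; module names are hypothetical.
fused_attention_modules = ["model.layers.0.self_attn", "model.layers.1.self_attn"]

# Keep only the top-level parent of each fused module ("model" in this example).
fused_attention_parent_modules = set(
    name.split(".")[0] for name in fused_attention_modules
)

# The real code iterates model.named_modules(); a module is a candidate when its name
# (the root module has the empty string as its name) is contained in one of the parent
# names, and it is only flagged if it exposes a config with `_attn_implementation`.
for module_name in ["", "model", "model.layers.0.self_attn", "lm_head"]:
    is_candidate = any(
        module_name in parent_name for parent_name in fused_attention_parent_modules
    )
    print(repr(module_name), "->", is_candidate)
```

Note that the containment check runs in the direction `module_name in parent`, so only the root module and the top-level parent module itself end up as candidates, and the flag is then only written where a `config` with `_attn_implementation` exists.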


@@ -332,8 +354,10 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
     """
     from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
+    module_has_been_fused = False
+
     if len(modules_to_fuse["attention"]) == 0:
-        return
+        return module_has_been_fused
 
     if hasattr(module, modules_to_fuse["attention"][0]):
         # First, we pack the QKV layers together
@@ -394,6 +418,9 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
             setattr(parent, child_name, fused_attention_layer.to(previous_device))
 
         del q_proj, k_proj, v_proj, o_proj
+        module_has_been_fused = True
+
+    return module_has_been_fused
 
 
 def post_init_awq_exllama_modules(model, exllama_config):
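
For context, a minimal usage sketch (not part of this commit) of how the fused code path is typically exercised end to end; the checkpoint name and fusing parameters below are illustrative assumptions:

```python
# Sketch only: load an AWQ checkpoint with on-the-fly module fusing enabled, then
# inspect the flag that fuse_awq_modules now sets. Checkpoint and parameters are
# assumptions, not prescribed by this commit.
from transformers import AutoModelForCausalLM, AwqConfig

quantization_config = AwqConfig(
    bits=4,
    do_fuse=True,
    fuse_max_seq_len=512,  # required when do_fuse=True
)

model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-OpenOrca-AWQ",  # any AWQ-quantized checkpoint; this one is an example
    quantization_config=quantization_config,
)

# With at least one fused attention module, the top-level config should now report
# the "custom" attention implementation introduced above.
print(getattr(model.config, "_attn_implementation", None))
```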
