Commit
Add: weight clipping to AWQ
rahul-tuli committed Sep 18, 2024
1 parent b92a433 commit 849a240
Showing 2 changed files with 130 additions and 1 deletion.
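For context, the new apply_clip option is exposed as a field on AWQModifier, so clipping can be toggled when the modifier is configured. Below is a minimal sketch of that configuration; the import path mirrors the file changed in this commit, and anything beyond the fields visible in the diff (defaults, recipe wiring) is an assumption for illustration.

# Illustrative sketch only: construct the modifier with the new flag.
# Anything not shown in the diff below is assumed, not confirmed by this commit.
from llmcompressor.modifiers.awq.base import AWQModifier

modifier = AWQModifier(
    bits=4,            # bit width used by pseudo-quantization during the search
    symmetric=True,    # symmetric fake quantization
    duo_scaling=True,  # derive scales from both activations and weights
    apply_clip=True,   # new in this commit: clip weights after smoothing/scaling
)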
122 changes: 121 additions & 1 deletion src/llmcompressor/modifiers/awq/base.py
@@ -10,11 +10,13 @@
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.pytorch.utils import (
clear_memory,
pseudo_quantize_tensor,
tensor_forward_with_input_args,
)
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.pytorch.module import (
get_layer,
get_layers,
get_matching_layer,
get_parent_by_name,
@@ -48,13 +50,15 @@ class AWQMapping:
:param smooth_layer: PyTorch module storing the activation layer
:param balance_layers: list of PyTorch modules that smooth_layer feeds into, must be
balanced to offset the smoothing of smooth_layer
:param balance_names: optional list of names of the balance_layers
:param parent: parent module of the balance_layers
:param parent_name: name of the parent module
"""

smooth_name: str
smooth_layer: Module
balance_layers: List[Module]
balance_names: Optional[List[str]] = None
parent: Optional[Module] = None
parent_name: Optional[str] = None

@@ -107,6 +111,7 @@ class AWQModifier(Modifier):
:param symmetric: whether to use symmetric quantization
:param duo_scaling: whether to use duo scaling, which uses both input activations
and weights to determine the scaling factor
:param apply_clip: whether to apply clipping to the weights after scaling
"""

mappings: List[Tuple] = DEFAULT_AWQ_MAPPINGS
@@ -118,6 +123,7 @@ class AWQModifier(Modifier):
bits: int = 4
symmetric: bool = True
duo_scaling: bool = True
apply_clip: bool = True

hooks_: Optional[List] = None
resolved_mappings_: Optional[List] = None
@@ -202,14 +208,16 @@ def _resolve_mappings(self, model: Module) -> List:
to_smooth_layers = get_layers(to_smooth, model)
for layer_name, smooth_layer in to_smooth_layers.items():
if layer_name not in self.ignore:
balance_layers = []
balance_layers, balance_names = [], []
for balance_suffix in to_balance:
# find the submodule that matches the activation layer
balance_name, balance_layer = get_matching_layer(
balance_suffix, layer_name, model
)
if balance_layer:
balance_layers.append(balance_layer)
balance_names.append(balance_name)

# each mapping can contain multiple layers to balance, but only
# one layer to smooth

@@ -226,6 +234,7 @@
layer_name,
smooth_layer,
balance_layers,
balance_names=balance_names,
parent=parent,
parent_name=parent_name,
)
@@ -314,6 +323,7 @@ def _apply_smoothing(self, model: Module):
for mapping in tqdm(self.resolved_mappings_):
smooth_layer = mapping.smooth_layer
balance_layers = mapping.balance_layers
balance_names = mapping.balance_names

activations = self.scales_[mapping.smooth_name].inps

@@ -392,6 +402,15 @@ def smooth(module):
smooth(layer)
smooth(smooth_layer)

if self.apply_clip:
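# grid-search per-group clipping thresholds for the balance layers, then clamp their weights in place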
clip_list = self._search_best_clip(
balance_layers=balance_layers,
balance_names=balance_names,
input_feat=inp,
)

_apply_clip(model, clip_list)

# clear out allocated smoothing scales
torch.cuda.empty_cache()

@@ -598,3 +617,104 @@ def _forward_input_with_kwargs(
inputs=inputs,
input_kwargs=kwargs,
)[0]

@torch.no_grad()
def _search_best_clip(self, balance_layers, balance_names, input_feat):
clip_list = []
avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

for name, layer in zip(balance_names, balance_layers):
# the q/k projections feed the attention-score batched matmul (qk bmm), so they are hard to clip precisely and are skipped
if any(_ in name for _ in avoid_clipping):
continue

max_val = self._compute_best_clip(layer.weight, input_feat)
clip_list.append((name, max_val))

return clip_list

@torch.no_grad()
def _compute_best_clip(
self,
w: torch.Tensor,
input_feat: torch.Tensor,
n_grid=20,
max_shrink=0.5,
n_sample_token=512,
):
assert w.dim() == 2
org_w_shape = w.shape
# w [co, ci] -> [co, 1, n_group, group size]
# input_feat [n_token, ci] -> [1, n_token, n_group, group size]
group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)

# Compute input feature step size (minimum 1)
step_size = max(1, input_feat.shape[1] // n_sample_token)
input_feat = input_feat[:, ::step_size]

w = w.reshape(org_w_shape[0], 1, -1, group_size)

oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM
assert org_w_shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []

for i_b in range(org_w_shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1

best_max_val = org_max_val.clone()
min_errs = torch.ones_like(org_max_val) * 1e9
input_feat = input_feat.to(w.device)
org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group

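# shrink the per-group max over n_grid steps; keep the threshold whose fake-quantized output has the lowest MSE against the original output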
for i_s in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s / n_grid)
min_val = -max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = pseudo_quantize_tensor(
w=cur_w,
symmetric=self.symmetric,
group_size=group_size,
bit_width=self.bits,
)[0]
cur_out = (input_feat * q_w).sum(dim=-1)

# co, 1, n_group, 1
err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
del cur_w
del cur_out
cur_best_idx = err < min_errs
min_errs[cur_best_idx] = err[cur_best_idx]
best_max_val[cur_best_idx] = max_val[cur_best_idx]
best_max_val_all.append(best_max_val)

best_max_val = torch.cat(best_max_val_all, dim=0)

clear_memory(input_feat)
clear_memory(org_out)

return best_max_val.squeeze(1)


@torch.no_grad()
def _apply_clip(module, clip_list: List[Tuple[str, torch.Tensor]]):
"""
Apply clipping to the weights of the given module
:post-condition: the weights of the module are clipped to the given maximum values
:param module: module to apply clipping to
:param clip_list: list of tuples containing the name of the layer and the maximum
value to clip the weights to
"""
for name, max_val in clip_list:
_, layer = get_layer(target=name, module=module)
assert isinstance(layer, torch.nn.Linear)
max_val = max_val.to(layer.weight.device)
org_shape = layer.weight.shape
layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
layer.weight.data = layer.weight.data.reshape(org_shape)
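To make the clipping search easier to follow in isolation, here is a self-contained sketch of the same idea: for each output-channel group, shrink the clipping threshold over a grid, fake-quantize the clamped weights, and keep the threshold that minimizes the mean squared error of the group's output. The round-to-nearest quantizer below stands in for pseudo_quantize_tensor and is an assumption; the sketch also omits the output-channel batching and token subsampling that _compute_best_clip uses, so it is illustrative rather than the library implementation.

import torch


def fake_quantize(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # Toy symmetric round-to-nearest quantizer over the last (group) dimension;
    # a stand-in for pseudo_quantize_tensor, not the library implementation.
    max_val = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
    q_max = 2 ** (bits - 1) - 1
    scale = max_val / q_max
    return (w / scale).round().clamp(-q_max - 1, q_max) * scale


@torch.no_grad()
def search_best_clip(w, x, group_size=128, n_grid=20, max_shrink=0.5):
    # w: [out_features, in_features], x: [n_tokens, in_features]
    co = w.shape[0]
    w = w.reshape(co, 1, -1, group_size)          # [co, 1, n_group, group_size]
    x = x.reshape(1, x.shape[0], -1, group_size)  # [1, n_tok, n_group, group_size]

    org_max = w.abs().amax(dim=-1, keepdim=True)  # [co, 1, n_group, 1]
    org_out = (x * w).sum(dim=-1)                 # [co, n_tok, n_group]

    best_max = org_max.clone()
    best_err = torch.full_like(org_max, float("inf"))
    for i in range(int(max_shrink * n_grid)):
        max_val = org_max * (1 - i / n_grid)      # candidate clipping threshold
        q_w = fake_quantize(w.clamp(-max_val, max_val))
        err = ((x * q_w).sum(dim=-1) - org_out).pow(2).mean(dim=1).view_as(best_err)
        better = err < best_err
        best_err[better] = err[better]
        best_max[better] = max_val[better]
    return best_max.squeeze(1)                    # [co, n_group, 1]


w = torch.randn(16, 256)
x = torch.randn(64, 256)
print(search_best_clip(w, x, group_size=128).shape)  # torch.Size([16, 2, 1])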
9 changes: 9 additions & 0 deletions src/llmcompressor/pytorch/utils/helpers.py
@@ -3,6 +3,7 @@
"""

import functools
import gc
import inspect
import os
import random
@@ -91,6 +92,7 @@
"pseudo_dequantize_linear",
"tensor_forward_with_input_args",
"sanitize_kwargs_for_module",
"clear_memory",
]


@@ -1298,3 +1300,10 @@ def pseudo_dequantize_linear(
w = w.weight.data * scales

return w


def clear_memory(value: Optional[Any] = None):
if value is not None:
del value
gc.collect()
torch.cuda.empty_cache()
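A short usage sketch for the new helper follows; the calling context is an assumption for illustration. Note that the del value inside clear_memory only removes the helper's local name, so the caller should also drop its own reference for the tensor to actually be reclaimed.

# Illustrative usage only; the surrounding calibration context is assumed.
import torch

from llmcompressor.pytorch.utils import clear_memory

device = "cuda" if torch.cuda.is_available() else "cpu"
activations = torch.randn(4096, 4096, device=device)

# ... use activations during calibration ...

# Drop the caller-side reference first, then collect garbage and release
# cached CUDA blocks (empty_cache is a no-op when CUDA is not initialized).
del activations
clear_memory()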
