From 849a240ecab8ddae615af68f91fcdd50ea84e282 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Tue, 17 Sep 2024 15:13:41 +0000
Subject: [PATCH] Add: weight clipping to AWQ

---
 src/llmcompressor/modifiers/awq/base.py    | 122 ++++++++++++++++++++-
 src/llmcompressor/pytorch/utils/helpers.py |   9 ++
 2 files changed, 130 insertions(+), 1 deletion(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index a1a9ef2ab..71767b0b4 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -10,11 +10,13 @@
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 from llmcompressor.pytorch.utils import (
+    clear_memory,
     pseudo_quantize_tensor,
     tensor_forward_with_input_args,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.pytorch.module import (
+    get_layer,
     get_layers,
     get_matching_layer,
     get_parent_by_name,
@@ -48,6 +50,7 @@ class AWQMapping:
     :param smooth_layer: PyTorch module storing the activation layer
     :param balance_layers: list of PyTorch modules that smooth_layer feeds into, must
         be balanced to offset the smoothing of smooth_layer
+    :param balance_names: optional list of names of the balance_layers
     :param parent: parent module of the balance_layers
     :param parent_name: name of the parent module
     """
@@ -55,6 +58,7 @@ class AWQMapping:
     smooth_name: str
     smooth_layer: Module
     balance_layers: List[Module]
+    balance_names: Optional[List[str]] = None
     parent: Optional[Module] = None
     parent_name: Optional[str] = None
 
@@ -107,6 +111,7 @@ class AWQModifier(Modifier):
     :param symmetric: whether to use symmetric quantization
     :param duo_scaling: whether to use duo scaling, which uses both input activations
         and weights to determine the scaling factor
+    :param apply_clip: whether to apply clipping to the weights after scaling
     """
 
     mappings: List[Tuple] = DEFAULT_AWQ_MAPPINGS
@@ -118,6 +123,7 @@ class AWQModifier(Modifier):
     bits: int = 4
     symmetric: bool = True
     duo_scaling: bool = True
+    apply_clip: bool = True
 
     hooks_: Optional[List] = None
     resolved_mappings_: Optional[List] = None
@@ -202,7 +208,7 @@ def _resolve_mappings(self, model: Module) -> List:
             to_smooth_layers = get_layers(to_smooth, model)
             for layer_name, smooth_layer in to_smooth_layers.items():
                 if layer_name not in self.ignore:
-                    balance_layers = []
+                    balance_layers, balance_names = [], []
                     for balance_suffix in to_balance:
                         # find the submodule that matches the activation layer
                         balance_name, balance_layer = get_matching_layer(
@@ -210,6 +216,8 @@ def _resolve_mappings(self, model: Module) -> List:
                         )
                         if balance_layer:
                             balance_layers.append(balance_layer)
+                            balance_names.append(balance_name)
+
 
                     # each mapping can contain multiple layers to balance, but only
                     # one layer to smooth
@@ -226,6 +234,7 @@ def _resolve_mappings(self, model: Module) -> List:
                         layer_name,
                         smooth_layer,
                         balance_layers,
+                        balance_names=balance_names,
                         parent=parent,
                         parent_name=parent_name,
                     )
@@ -314,6 +323,7 @@ def _apply_smoothing(self, model: Module):
         for mapping in tqdm(self.resolved_mappings_):
             smooth_layer = mapping.smooth_layer
             balance_layers = mapping.balance_layers
+            balance_names = mapping.balance_names
 
             activations = self.scales_[mapping.smooth_name].inps
 
@@ -392,6 +402,15 @@ def smooth(module):
                 smooth(layer)
             smooth(smooth_layer)
 
+            if self.apply_clip:
+                clip_list = self._search_best_clip(
+                    balance_layers=balance_layers,
+                    balance_names=balance_names,
+                    input_feat=inp,
+                )
+
+                _apply_clip(model, clip_list)
+
             # clear out allocated smoothing scales
             torch.cuda.empty_cache()
 
@@ -598,3 +617,104 @@ def _forward_input_with_kwargs(
             inputs=inputs,
             input_kwargs=kwargs,
         )[0]
+
+    @torch.no_grad()
+    def _search_best_clip(self, balance_layers, balance_names, input_feat):
+        clip_list = []
+        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
+
+        for name, layer in zip(balance_names, balance_layers):
+            # due to qk bmm, it is hard to clip precisely
+            if any([_ in name for _ in avoid_clipping]):
+                continue
+
+            max_val = self._compute_best_clip(layer.weight, input_feat)
+            clip_list.append((name, max_val))
+
+        return clip_list
+
+    @torch.no_grad()
+    def _compute_best_clip(
+        self,
+        w: torch.Tensor,
+        input_feat: torch.Tensor,
+        n_grid=20,
+        max_shrink=0.5,
+        n_sample_token=512,
+    ):
+        assert w.dim() == 2
+        org_w_shape = w.shape
+        # w [co, ci] -> [co, 1, n_group, group size]
+        # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
+        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
+        input_feat = input_feat.view(-1, input_feat.shape[-1])
+        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
+
+        # Compute input feature step size (minimum 1)
+        step_size = max(1, input_feat.shape[1] // n_sample_token)
+        input_feat = input_feat[:, ::step_size]
+
+        w = w.reshape(org_w_shape[0], 1, -1, group_size)
+
+        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
+        assert org_w_shape[0] % oc_batch_size == 0
+        w_all = w
+        best_max_val_all = []
+
+        for i_b in range(org_w_shape[0] // oc_batch_size):
+            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]
+
+            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
+
+            best_max_val = org_max_val.clone()
+            min_errs = torch.ones_like(org_max_val) * 1e9
+            input_feat = input_feat.to(w.device)
+            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group
+
+            for i_s in range(int(max_shrink * n_grid)):
+                max_val = org_max_val * (1 - i_s / n_grid)
+                min_val = -max_val
+                cur_w = torch.clamp(w, min_val, max_val)
+                q_w = pseudo_quantize_tensor(
+                    w=cur_w,
+                    symmetric=self.symmetric,
+                    group_size=group_size,
+                    bit_width=self.bits,
+                )[0]
+                cur_out = (input_feat * q_w).sum(dim=-1)
+
+                # co, 1, n_group, 1
+                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
+                del cur_w
+                del cur_out
+                cur_best_idx = err < min_errs
+                min_errs[cur_best_idx] = err[cur_best_idx]
+                best_max_val[cur_best_idx] = max_val[cur_best_idx]
+            best_max_val_all.append(best_max_val)
+
+        best_max_val = torch.cat(best_max_val_all, dim=0)
+
+        clear_memory(input_feat)
+        clear_memory(org_out)
+
+        return best_max_val.squeeze(1)
+
+
+@torch.no_grad()
+def _apply_clip(module, clip_list: List[Tuple[str, torch.Tensor]]):
+    """
+    Apply clipping to the weights of the given module
+
+    :post-condition: the weights of the module are clipped to the given maximum values
+    :param module: module to apply clipping to
+    :param clip_list: list of tuples containing the name of the layer and the maximum
+        value to clip the weights to
+    """
+    for name, max_val in clip_list:
+        _, layer = get_layer(target=name, module=module)
+        assert isinstance(layer, torch.nn.Linear)
+        max_val = max_val.to(layer.weight.device)
+        org_shape = layer.weight.shape
+        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
+        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
+        layer.weight.data = layer.weight.data.reshape(org_shape)
diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py
index 252dcc2cc..dc06343f5 100644
--- a/src/llmcompressor/pytorch/utils/helpers.py
+++ b/src/llmcompressor/pytorch/utils/helpers.py
@@ -3,6 +3,7 @@
 """
 
 import functools
+import gc
 import inspect
 import os
 import random
@@ -91,6 +92,7 @@
     "pseudo_dequantize_linear",
     "tensor_forward_with_input_args",
     "sanitize_kwargs_for_module",
+    "clear_memory",
 ]
 
 
@@ -1298,3 +1300,10 @@ def pseudo_dequantize_linear(
         w = w.weight.data * scales
 
     return w
+
+
+def clear_memory(value: Optional[Any] = None):
+    if value is not None:
+        del value
+    gc.collect()
+    torch.cuda.empty_cache()
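
A minimal usage sketch (illustrative, not part of the diff), assuming the remaining AWQModifier fields keep their defaults: the first half constructs the modifier with the new apply_clip flag using field names visible in the class above; the second half mirrors, on a toy weight, the per-group clamp that _apply_clip performs, then calls the new clear_memory helper. Sizes and values are placeholders.

    import torch

    from llmcompressor.modifiers.awq.base import AWQModifier
    from llmcompressor.pytorch.utils import clear_memory

    # apply_clip defaults to True; set it to False to reproduce pre-patch behavior.
    modifier = AWQModifier(bits=4, symmetric=True, duo_scaling=True, apply_clip=True)

    # Toy view of the clamp done by _apply_clip: the weight is reshaped to
    # (out_channels, n_group, group_size) and clamped per group to +/- max_val.
    co, ci, group_size = 8, 64, 32  # placeholder sizes; n_group = ci // group_size
    weight = torch.randn(co, ci)
    max_val = weight.reshape(co, -1, group_size).abs().amax(dim=-1, keepdim=True) * 0.9
    clipped = weight.reshape(co, -1, group_size).clamp(-max_val, max_val).reshape(co, ci)

    clear_memory(weight)  # gc.collect() plus torch.cuda.empty_cache()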