diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index 0c3c96d34..bfe1f9715 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -306,25 +306,22 @@ def compress_module(
             CompressionLogger(module) as comp_logger,
         ):
             loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
-                weight=module.weight.data,
+                module=module,
                 quant_args=quant_args,
-                hessian=self._hessians[module],
+                hessians_dict=self._hessians,
                 blocksize=self.block_size,
                 percdamp=self.dampening_frac,
-                module_class=type(module),
             )
+            comp_logger.set_loss(loss)

-            module.weight += quantized_weight - module.weight  # Future: FSDP
-            update_offload_parameter(module, "weight", module.weight.data)
-            update_offload_parameter(module, "weight_scale", scale)
-            update_offload_parameter(module, "weight_zero_point", zero_point)
-            if g_idx is not None:
-                update_offload_parameter(module, "weight_g_idx", g_idx)
-
-            del self._hessians[module]
-            del self._num_samples[module]
+            update_offload_parameter(module, "weight", quantized_weight)
+            update_offload_parameter(module, "weight_scale", scale)
+            update_offload_parameter(module, "weight_zero_point", zero_point)
+            if g_idx is not None:
+                update_offload_parameter(module, "weight_g_idx", g_idx)

-            comp_logger.set_loss(loss)
+            # self._hessians[module] already deleted by quantize_weight
+            del self._num_samples[module]

     @contextlib.contextmanager
     def _maybe_onload_hessian(self, module: torch.nn.Module):
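
Note on the hunk above: quantize_weight now receives the whole hessians dict and frees the module's entry itself, which is why compress_module only deletes self._num_samples[module]. A minimal sketch of that ownership hand-off, using stand-in names rather than llm-compressor internals:

    import torch

    # stand-ins for illustration; the real code keys hessians by module object
    linear = torch.nn.Linear(8, 8)
    hessians = {linear: torch.eye(8, dtype=torch.float32)}

    def take_hessian(hessians_dict, module):
        # equivalent to the patch's `H = d[module]; del d[module]` pair:
        # pop() hands back the only remaining reference, so the tensor is
        # freed as soon as the caller drops it
        return hessians_dict.pop(module)

    H = take_hessian(hessians, linear)
    assert linear not in hessians  # the dict no longer pins H in memory

dict.pop is interchangeable with the index-then-del pair used in the patch; both leave exactly one live reference to the tensor.
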
diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py
index fc4b56edc..d35fb9748 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py
@@ -1,6 +1,6 @@
 import math
 from copy import copy
-from typing import Optional, Tuple, Type, Union
+from typing import Dict, Optional, Tuple, Type, Union

 import torch
 import transformers
@@ -58,31 +58,30 @@ def accumulate_hessian(


 def quantize_weight(
-    weight: torch.Tensor,
+    module: torch.nn.Module,
     quant_args: QuantizationArgs,
-    hessian: Optional[torch.Tensor] = None,
-    inp: Optional[torch.Tensor] = None,
+    hessians_dict: Dict[torch.nn.Module, torch.Tensor],
     blocksize: int = 128,
     percdamp: float = 0.01,
-    module_class: Type[torch.nn.Module] = torch.nn.Linear,
 ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
     """
     Quantize a module weight according to the GPTQ algorithm

-    :param weight: weight being quantized
+    :param module: module with weight being quantized
     :param quant_args: quantization arguments used to find quantization parameters
-    :param hessian: preaccumulated hessian for quantization
-    :param inp: module inputs used to calculate hessian. Incompatible with `hessian` arg
+    :param hessians_dict: dictionary containing preaccumulated hessians for quantization
     :param blocksize: chunk size of quantization updates
     :param percdamp: dampening factor on hessian diagonal
-    :param module_class: class of module, likely torch.nn.Linear
     :return: loss, quantized_weight, scale, zero_point, g_idx
     """
     strategy = quant_args.strategy
     actorder = quant_args.actorder
-    final_shape = weight.shape
-    final_dtype = weight.dtype
-    W = weight.data.clone()
+    final_shape = module.weight.shape
+    final_dtype = module.weight.dtype
+    module_class = type(module)
+    W = module.weight.clone()
+    H = hessians_dict[module]  # unfortunately python does not have a `move` keyword
+    del hessians_dict[module]  # so we have to delete the original reference manually

     # create observer for calculating quantization parameters
     observer = Observer.load_from_registry(
@@ -100,16 +99,6 @@
     num_rows = W.shape[0]
     num_columns = W.shape[1]

-    # compute hessian
-    if inp is not None:
-        if hessian is not None:
-            raise ValueError("Must pass either inp or hessian, but not both")
-        H = _compute_hessian(inp, module_class, device=weight.device)
-    elif hessian is not None:
-        H = hessian
-    else:
-        raise ValueError("Must pass either inp or hessian")
-
     if strategy == QuantizationStrategy.GROUP:
         # mapping from column index to group index
         g_idx = (
@@ -146,7 +135,7 @@
         else None
     )

-    losses = torch.zeros(num_rows, device=weight.device)
+    losses = torch.zeros(num_rows, device=module.weight.device)

     # mask dead hessian values
     dead = torch.diag(H) == 0
@@ -154,7 +143,20 @@
     W[:, dead] = 0

     # compute inverse hessian in place to save memory
-    Hinv = _invert_hessian(H, percdamp)
+    try:
+        damp = percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(H.shape[0], device=H.device)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+    except torch._C._LinAlgError:
+        raise ValueError(
+            "Failed to invert hessian due to numerical instability. Consider "
+            "increasing GPTQModifier.dampening_frac, increasing the number "
+            "of calibration samples, or shuffling the calibration dataset"
+        )

     # See section 3.4 of https://arxiv.org/abs/2203.07259
     for i1 in range(0, num_columns, blocksize):
@@ -265,50 +267,6 @@
     )


-def _invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor:
-    """
-    Performs in-place inversion of the hessian in order to save memory
-
-    :param H: hessian being inverted
-    :param percdamp: dampening factor on hessian diagonal
-    :return: inverted hessian
-    """
-    damp = percdamp * torch.mean(torch.diag(H))
-    diag = torch.arange(H.shape[0], device=H.device)
-    H[diag, diag] += damp
-    H = torch.linalg.cholesky(H)
-    H = torch.cholesky_inverse(H)
-    H = torch.linalg.cholesky(H, upper=True)
-    return H
-
-
-def _compute_hessian(
-    inp: torch.Tensor, module_class: Type[torch.nn.Module], device
-) -> torch.Tensor:
-    """
-    Calculate the hessian with respect to the module inputs
-
-    :param inp: module inputs
-    :param module_class: class of module, likely torch.nn.Linear
-    :return: hessian w.r.t. module inputs
-    """
-    inp = inp.to(device=device)
-    if len(inp.shape) == 2:
-        inp = inp.unsqueeze(0)
-
-    nsamples = inp.shape[0]  # note this is the number of dataset samples, not
-    # multiplied by the sequence length
-
-    if module_class in (torch.nn.Linear, transformers.Conv1D):
-        if len(inp.shape) == 3:
-            inp = inp.reshape((-1, inp.shape[-1]))
-        inp = inp.t()
-
-    inp = inp.to(dtype=GPTQ_PRECISION)
-    inp = math.sqrt(2 / nsamples) * inp
-    return inp.matmul(inp.t())
-
-
 def _apply_activation_ordering(
     W: torch.Tensor, H: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
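
The try/except above inlines the former _invert_hessian helper: dampen the hessian diagonal, then chain cholesky -> cholesky_inverse -> cholesky(upper=True) to obtain the upper Cholesky factor of the inverse hessian. A standalone sketch of the same sequence on a synthetic positive-definite matrix (shapes and values are arbitrary examples, not taken from the patch):

    import torch

    def invert_hessian(H: torch.Tensor, percdamp: float = 0.01) -> torch.Tensor:
        # dampen the diagonal to keep the factorization numerically stable
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(H.shape[0], device=H.device)
        H[diag, diag] += damp
        # H -> lower factor of H -> H^-1 -> upper factor of H^-1
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        return torch.linalg.cholesky(H, upper=True)

    X = torch.randn(128, 16)
    H = X.t() @ X  # gram matrix: symmetric positive semi-definite
    Hinv = invert_hessian(H)

GPTQ's block updates only consume the upper factor of the inverse, which is why the sequence ends with a second factorization rather than returning H^-1 directly.
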
diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py
index 7c839c5a7..ee764b9f8 100644
--- a/src/llmcompressor/transformers/compression/helpers.py
+++ b/src/llmcompressor/transformers/compression/helpers.py
@@ -137,7 +137,8 @@ def hessian_memory_requirements(model: torch.nn.Module) -> int:
     max_total_hessian_elems = max(total_hessian_elems.values())
     overall_max_column_size = max(max_column_size.values())
     bytes_per_weight = 32 // 8  # hessians are float32
-    inverse_reserved = overall_max_column_size * overall_max_column_size
+    # allocate enough space for out-of-place operations
+    inverse_reserved = overall_max_column_size * overall_max_column_size * 2
     return (max_total_hessian_elems + inverse_reserved) * bytes_per_weight


@@ -236,7 +237,7 @@

     reserved_memory = 0
     if reserve_for_hessians:
-        reserved_memory = hessian_memory_requirements(dummy_model)
+        reserved_memory = hessian_memory_requirements(dummy_model) * 2
     reserved_memory += quantization_memory_requirement(dummy_model)

     memory_limits = {
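
For a sense of scale, the doubled inverse buffer works out as follows (the 4096 column size is a made-up example, not a value from the patch):

    # illustrative numbers only
    overall_max_column_size = 4096
    max_total_hessian_elems = overall_max_column_size**2  # one live hessian
    bytes_per_weight = 32 // 8  # hessians are float32

    # doubled so the out-of-place cholesky steps have a scratch copy
    inverse_reserved = overall_max_column_size**2 * 2
    reserved = (max_total_hessian_elems + inverse_reserved) * bytes_per_weight
    print(f"{reserved / 1024**2:.0f} MiB")  # 192 MiB for this example

On top of this, calculate_offload_device_map doubles the entire hessian_memory_requirements result when building the device map.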