bump memory requirements
kylesayrs committed Dec 11, 2024
1 parent ccb007f commit e1055b0
Showing 3 changed files with 39 additions and 83 deletions.
23 changes: 10 additions & 13 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -306,25 +306,22 @@ def compress_module(
             CompressionLogger(module) as comp_logger,
         ):
             loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
-                weight=module.weight.data,
+                module=module,
                 quant_args=quant_args,
-                hessian=self._hessians[module],
+                hessians_dict=self._hessians,
                 blocksize=self.block_size,
                 percdamp=self.dampening_frac,
-                module_class=type(module),
             )
-            comp_logger.set_loss(loss)
-
-            module.weight += quantized_weight - module.weight  # Future: FSDP
-            update_offload_parameter(module, "weight", module.weight.data)
-            update_offload_parameter(module, "weight_scale", scale)
-            update_offload_parameter(module, "weight_zero_point", zero_point)
-            if g_idx is not None:
-                update_offload_parameter(module, "weight_g_idx", g_idx)
-
-            del self._hessians[module]
-            del self._num_samples[module]
+            update_offload_parameter(module, "weight", quantized_weight)
+            update_offload_parameter(module, "weight_scale", scale)
+            update_offload_parameter(module, "weight_zero_point", zero_point)
+            if g_idx is not None:
+                update_offload_parameter(module, "weight_g_idx", g_idx)
+
+            comp_logger.set_loss(loss)
+            # self._hessians[module] already deleted by quantize_weight
+            del self._num_samples[module]
 
     @contextlib.contextmanager
     def _maybe_onload_hessian(self, module: torch.nn.Module):
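Note: for context, the hand-off pattern this hunk adopts can be sketched in isolation (a minimal sketch with hypothetical names, not code from this repository). The caller passes the whole hessians dict instead of a single hessian; the callee removes the module's entry before doing the heavy work, so only one live reference to the large hessian remains while it is dampened and inverted, and the memory can be reclaimed as soon as the callee is done.

    from typing import Dict

    import torch


    def consume_and_invert(
        module: torch.nn.Module, hessians: Dict[torch.nn.Module, torch.Tensor]
    ) -> torch.Tensor:
        # Take the hessian out of the shared dict ("move" semantics): after the
        # del, this local variable holds the only reference to the tensor.
        H = hessians[module]
        del hessians[module]

        # Dampen the diagonal, then invert via Cholesky, mirroring the diff above.
        damp = 0.01 * torch.mean(torch.diag(H))
        diag = torch.arange(H.shape[0])
        H[diag, diag] += damp
        return torch.cholesky_inverse(torch.linalg.cholesky(H))


    layer = torch.nn.Linear(16, 16)
    hessians = {layer: torch.eye(16)}  # stand-in for an accumulated hessian
    Hinv = consume_and_invert(layer, hessians)
    assert layer not in hessians  # the caller no longer needs `del hessians[layer]`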
@@ -1,6 +1,6 @@
 import math
 from copy import copy
-from typing import Optional, Tuple, Type, Union
+from typing import Dict, Optional, Tuple, Type, Union
 
 import torch
 import transformers
@@ -58,31 +58,30 @@ def quantize_weight(
 
 
 def quantize_weight(
-    weight: torch.Tensor,
+    module: torch.nn.Module,
     quant_args: QuantizationArgs,
-    hessian: Optional[torch.Tensor] = None,
-    inp: Optional[torch.Tensor] = None,
+    hessians_dict: Dict[torch.nn.Module, torch.Tensor],
     blocksize: int = 128,
     percdamp: float = 0.01,
-    module_class: Type[torch.nn.Module] = torch.nn.Linear,
 ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
     """
     Quantize a module weight according to the GPTQ algorithm
 
-    :param weight: weight being quantized
+    :param module: module with weight being quantized
     :param quant_args: quantization arguments used to find quantization parameters
-    :param hessian: preaccumulated hessian for quantization
-    :param inp: module inputs used to calculate hessian. Incompatible with `hessian` arg
+    :param hessian_dict: dictionary containing preaccumulated hessian for quantization
     :param blocksize: chunk size of quantization updates
     :param percdamp: dampening factor on hessian diagonal
-    :param module_class: class of module, likely torch.nn.Linear
     :return: loss, quantized_weight, scale, zero_point, g_idx
     """
     strategy = quant_args.strategy
     actorder = quant_args.actorder
-    final_shape = weight.shape
-    final_dtype = weight.dtype
-    W = weight.data.clone()
+    final_shape = module.weight.shape
+    final_dtype = module.weight.dtype
+    module_class = type(module)
+    W = module.weight.clone()
+    H = hessians_dict[module]  # unfortunately python does not have a `move` keyword
+    del hessians_dict[module]  # so we have to delete the original reference manually
 
     # create observer for calculating quantization parameters
     observer = Observer.load_from_registry(
@@ -100,16 +99,6 @@ def quantize_weight(
     num_rows = W.shape[0]
     num_columns = W.shape[1]
 
-    # compute hessian
-    if inp is not None:
-        if hessian is not None:
-            raise ValueError("Must pass either inp or hessian, but not both")
-        H = _compute_hessian(inp, module_class, device=weight.device)
-    elif hessian is not None:
-        H = hessian
-    else:
-        raise ValueError("Must pass either inp or hessian")
-
     if strategy == QuantizationStrategy.GROUP:
         # mapping from column index to group index
         g_idx = (
@@ -146,15 +135,28 @@ def quantize_weight(
         else None
     )
 
-    losses = torch.zeros(num_rows, device=weight.device)
+    losses = torch.zeros(num_rows, device=module.weight.device)
 
     # mask dead hessian values
     dead = torch.diag(H) == 0
     H[dead, dead] = 1
     W[:, dead] = 0
 
     # compute inverse hessian in place to save memory
-    Hinv = _invert_hessian(H, percdamp)
+    try:
+        damp = percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(H.shape[0], device=H.device)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+    except torch._C._LinAlgError:
+        raise ValueError(
+            "Failed to invert hessian due to numerical instability. Consider "
+            "increasing GPTQModifier.dampening_frac, increasing the number "
+            "of calibration samples, or shuffling the calibration dataset"
+        )
 
     # See section 3.4 of https://arxiv.org/abs/2203.07259
     for i1 in range(0, num_columns, blocksize):
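Note: a standalone sanity check on the inlined inversion above (hypothetical shapes, not repository code). After damping, the chain cholesky, cholesky_inverse, cholesky(upper=True) leaves Hinv as an upper-triangular factor U satisfying U^T U = (H + damp*I)^-1, which is the form the block update loop below consumes.

    import torch

    torch.manual_seed(0)

    # Build a symmetric positive definite "hessian" the way GPTQ accumulates it.
    X = torch.randn(256, 16, dtype=torch.float32)
    H = X.T @ X / X.shape[0]

    # Dampen the diagonal exactly as in the hunk above.
    percdamp = 0.01
    damp = percdamp * torch.mean(torch.diag(H))
    diag = torch.arange(H.shape[0])
    H_damped = H.clone()
    H_damped[diag, diag] += damp

    # cholesky -> cholesky_inverse -> cholesky(upper=True)
    L = torch.linalg.cholesky(H_damped)
    Hinv_full = torch.cholesky_inverse(L)              # (H + damp*I)^-1
    U = torch.linalg.cholesky(Hinv_full, upper=True)

    # U is upper triangular and U^T @ U reproduces the inverse.
    assert torch.allclose(U, torch.triu(U))
    assert torch.allclose(U.T @ U, Hinv_full, atol=1e-4)
    assert torch.allclose(U.T @ U @ H_damped, torch.eye(16), atol=1e-3)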
@@ -265,50 +267,6 @@ def quantize_weight(
     )
 
 
-def _invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor:
-    """
-    Performs in-place inversion of the hessian in order to save memory
-
-    :param H: hessian being inverted
-    :param percdamp: dampening factor on hessian diagonal
-    :return: inverted hessian
-    """
-    damp = percdamp * torch.mean(torch.diag(H))
-    diag = torch.arange(H.shape[0], device=H.device)
-    H[diag, diag] += damp
-    H = torch.linalg.cholesky(H)
-    H = torch.cholesky_inverse(H)
-    H = torch.linalg.cholesky(H, upper=True)
-    return H
-
-
-def _compute_hessian(
-    inp: torch.Tensor, module_class: Type[torch.nn.Module], device
-) -> torch.Tensor:
-    """
-    Calculate the hessian with respect to the module inputs
-
-    :param inp: module inputs
-    :param module_class: class of module, likely torch.nn.Linear
-    :return: hessian w.r.t. module inputs
-    """
-    inp = inp.to(device=device)
-    if len(inp.shape) == 2:
-        inp = inp.unsqueeze(0)
-
-    nsamples = inp.shape[0]  # note this is the number of dataset samples, not
-    # multiplied by the sequence length
-
-    if module_class in (torch.nn.Linear, transformers.Conv1D):
-        if len(inp.shape) == 3:
-            inp = inp.reshape((-1, inp.shape[-1]))
-        inp = inp.t()
-
-    inp = inp.to(dtype=GPTQ_PRECISION)
-    inp = math.sqrt(2 / nsamples) * inp
-    return inp.matmul(inp.t())
-
-
 def _apply_activation_ordering(
     W: torch.Tensor, H: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
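Note: the removed _compute_hessian built H = (2 / nsamples) * X^T X in one shot from a saved input batch; after this commit the hessian is instead pre-accumulated (see accumulate_hessian in the hunk header above) and handed over via hessians_dict. A small standalone check (hypothetical shapes, not repository code, and not necessarily the exact accumulation used in the library) that a running, sample-weighted accumulation reproduces the one-shot formula:

    import math

    import torch

    torch.manual_seed(0)
    inp = torch.randn(4, 32, 64)  # (batch, seq_len, hidden), made-up shapes

    # One-shot form, as in the removed _compute_hessian: H = 2/nsamples * X^T X
    nsamples = inp.shape[0]
    X = inp.reshape(-1, inp.shape[-1]).t().to(torch.float32)
    H_oneshot = (math.sqrt(2 / nsamples) * X) @ (math.sqrt(2 / nsamples) * X).t()

    # Streaming form: rescale the running hessian as each sample arrives.
    H_running = torch.zeros(64, 64)
    seen = 0
    for sample in inp:                        # one calibration sample at a time
        x = sample.t().to(torch.float32)      # (hidden, seq_len)
        H_running *= seen / (seen + 1)        # downweight what was accumulated so far
        seen += 1
        H_running += (2 / seen) * (x @ x.t())

    assert torch.allclose(H_oneshot, H_running, rtol=1e-3, atol=1e-3)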
5 changes: 3 additions & 2 deletions src/llmcompressor/transformers/compression/helpers.py
@@ -137,7 +137,8 @@ def hessian_memory_requirements(model: torch.nn.Module) -> int:
     max_total_hessian_elems = max(total_hessian_elems.values())
     overall_max_column_size = max(max_column_size.values())
     bytes_per_weight = 32 // 8  # hessians are float32
-    inverse_reserved = overall_max_column_size * overall_max_column_size
+    # allocate enough space for out of place operations
+    inverse_reserved = overall_max_column_size * overall_max_column_size * 2
     return (max_total_hessian_elems + inverse_reserved) * bytes_per_weight
 
 
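Note: a back-of-the-envelope illustration of the doubled reservation, using the formula from the hunk above with made-up dimensions (every number below is hypothetical). For a maximum column size of 4096, the scratch space reserved for the out-of-place inversion doubles from about 64 MiB to about 128 MiB, on top of the hessians themselves.

    # Reservation math from hessian_memory_requirements, hypothetical inputs.
    bytes_per_weight = 32 // 8                    # hessians are float32
    overall_max_column_size = 4096                # e.g. a 4096-wide linear layer
    max_total_hessian_elems = 4096 * 4096 * 3     # pretend three such hessians are live

    inverse_reserved = overall_max_column_size * overall_max_column_size * 2  # new: * 2
    total_bytes = (max_total_hessian_elems + inverse_reserved) * bytes_per_weight

    print(f"reserved: {total_bytes / 2**30:.2f} GiB")  # ~0.31 GiB for these numbers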
@@ -236,7 +237,7 @@ def calculate_offload_device_map(
 
     reserved_memory = 0
     if reserve_for_hessians:
-        reserved_memory = hessian_memory_requirements(dummy_model)
+        reserved_memory = hessian_memory_requirements(dummy_model) * 2
     reserved_memory += quantization_memory_requirement(dummy_model)
 
     memory_limits = {
