[Bugfix] Support model offloading SparseGPTQ #918

Draft. Wants to merge 8 commits into base: main.
src/llmcompressor/modifiers/obcq/utils/sgpt_wrapper.py (22 changes: 20 additions & 2 deletions)

```diff
@@ -1,8 +1,14 @@
 import time
 
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
+from compressed_tensors.utils import (
+    get_offloaded_device,
+    is_module_offloaded,
+    update_prefix_dict,
+)
 
 from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
+from llmcompressor.pytorch.utils.helpers import tensor_sparsity
 from llmcompressor.utils import getattr_chain
 
 try:
@@ -87,9 +93,14 @@ def compress(
             diagonal norm
         :param preserve_sparsity_mask: Extend or ignore the base sparsity mask
         """
+        if is_module_offloaded(self.layer):
+            self.layer._hf_hook.pre_forward(self.layer)
+
         final_shape = self.layer.weight.shape
         final_dtype = self.layer.weight.dtype
         W = self.layer.weight.data.clone()
+
+        # ensure weight has been properly quantized (if applicable) before sparsifying
         args_loc = "quantization_scheme.weights"
         weight_quant_args = getattr_chain(self.layer, args_loc, None)
         if weight_quant_args is not None:
@@ -204,8 +215,8 @@ def compress(
             else:
                 W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
-        logger.info("time %.2f" % (time.time() - tick))
-        logger.info("error %.2f" % torch.sum(Losses).item())
+        logger.info(f"time {time.time() - tick:.2f}")
+        logger.info(f"error {torch.sum(Losses).item():.2f}")
 
         if isinstance(self.layer, transformers.Conv1D):
             W = W.t()
@@ -218,6 +229,13 @@ def compress(
         self.layer.weight -= self.layer.weight
         self.layer.weight += W
 
+        logger.info(f"sparsity {tensor_sparsity(W):.2f}")
+
+        if is_module_offloaded(self.layer):
+            device = get_offloaded_device(self.layer)
+            update_prefix_dict(self.layer, "weight", self.layer.weight.to(device))
+            self.layer._hf_hook.post_forward(self.layer, None)
+
     def free(self):
         """
         Free the Hessian memory after the layer is complete
```
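The change above follows an onload, mutate, write-back pattern: the accelerate hook onloads the layer, the weight is modified on the execution device, and the result is written back to the offload store. A minimal sketch of that pattern using the same compressed_tensors helpers the diff imports; the helper name `update_offloaded_weight` and the `copy_`-based update are illustrative, not part of this PR, which instead updates via in-place subtract-and-add:

```python
import torch
from compressed_tensors.utils import (
    get_offloaded_device,
    is_module_offloaded,
    update_prefix_dict,
)


def update_offloaded_weight(layer: torch.nn.Module, new_weight: torch.Tensor) -> None:
    """Mutate a layer's weight whether or not accelerate has offloaded it."""
    # onload: the accelerate hook moves parameters to the execution device
    if is_module_offloaded(layer):
        layer._hf_hook.pre_forward(layer)

    # update in place so existing references to the parameter stay valid
    with torch.no_grad():
        layer.weight.copy_(new_weight)

    # write the updated tensor back to the offload device, then let the hook
    # release the on-device copy
    if is_module_offloaded(layer):
        device = get_offloaded_device(layer)
        update_prefix_dict(layer, "weight", layer.weight.to(device))
        layer._hf_hook.post_forward(layer, None)
```

The `update_prefix_dict` write-back appears to be the crux of the fix: without it, `post_forward` would restore the stale offloaded copy and the sparsified weights would be silently discarded.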
src/llmcompressor/modifiers/utils/pytorch_helpers.py (17 changes: 13 additions & 4 deletions)

```diff
@@ -2,6 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
+from compressed_tensors import is_module_offloaded
 from torch.nn import Module
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -70,11 +71,19 @@ def run_calibration_forward(
         calibration_function if calibration_function else tensors_module_forward
     )
 
-    # move model to optional specified device if it is not already there
-    model_device = next(model.parameters()).device
-    if device is not None and model_device != device:
+    # move to specified device if specified
+    if device is not None:
         model.to(device)
-        model_device = next(model.parameters()).device
+        model_device = device
+
+    # start on the cpu if the model is offloaded
+    elif any((m for m in model.modules() if is_module_offloaded(m))):
+        model_device = torch.device("cpu")
+
+    # copy model device if not offloaded
+    else:
+        model_device = model.device
+
     _dataloader = (
         calibration_dataloader
         if num_calibration_steps is None
```
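The rewritten branch decides where calibration inputs should live before the forward pass: an explicit device argument wins, a model with any accelerate-offloaded submodule keeps its inputs on CPU (the hooks onload each layer and move its inputs on demand), and otherwise inputs follow the model's own device. A standalone sketch of the same decision; the function name `pick_input_device` is hypothetical, and reading `model.device` assumes an HF-style PreTrainedModel (a bare nn.Module has no `.device` attribute):

```python
import torch
from compressed_tensors import is_module_offloaded


def pick_input_device(model: torch.nn.Module, device=None) -> torch.device:
    """Choose the device calibration batches should start on."""
    # an explicit request wins, mirroring run_calibration_forward above
    if device is not None:
        model.to(device)
        return torch.device(device)

    # offloaded models take cpu inputs; accelerate onloads layers on demand
    if any(is_module_offloaded(m) for m in model.modules()):
        return torch.device("cpu")

    # otherwise follow the model (assumes an HF-style `.device` attribute)
    return model.device
```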