Skip to content

Commit

Permalink
fix offload
Browse files: browse the repository at this point in the history
Signed-off-by: Dipika <[email protected]>
  • Loading branch information
dsikka committed Dec 15, 2024
1 parent 540d4b2 commit d1a07fc
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/llmcompressor/modifiers/smoothquant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, Dict, List, Optional, Tuple

import torch
from compressed_tensors.utils.offload import is_module_offloaded
from loguru import logger
from torch.nn import Module

Expand All @@ -12,6 +13,7 @@
handle_mapping_resolution_errors,
)
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.timer_utils import log_time
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.pytorch.module import get_layers, get_matching_layer

Expand Down Expand Up @@ -102,6 +104,7 @@ class SmoothQuantModifier(Modifier):
resolved_mappings_: Optional[List] = None
scales_: Optional[Dict] = None

@log_time
def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run SmoothQuant on the given state
Expand Down Expand Up @@ -318,8 +321,13 @@ def _calculate_smoothing_scales(
# get the channel-wise dynamic range for each layer to be balanced
weight_scales = []
for layer in balance_layers:
offloaded = is_module_offloaded(layer)
if offloaded:
layer._hf_hook.pre_forward(layer)

scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)

weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

# calculate the amount of smoothing to apply
Expand All @@ -329,4 +337,9 @@ def _calculate_smoothing_scales(
1 - self.smoothing_strength
)
scales = torch.where(weight_scales > 0.0, scales, activation_scales)

for layer in balance_layers:
offloaded = is_module_offloaded(layer)
if offloaded:
layer._hf_hook.post_forward(layer, None)
return scales

0 comments on commit d1a07fc

Please sign in to comment.