Commit

add fake_sequential

kylesayrs committed Dec 14, 2024
1 parent a71352a commit 2d249a2
Showing 3 changed files with 162 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/llmcompressor/pipelines/fake_sequential/__init__.py
@@ -0,0 +1,2 @@
# flake8: noqa
from .pipeline import run_pipeline
99 changes: 99 additions & 0 deletions src/llmcompressor/pipelines/fake_sequential/helpers.py
@@ -0,0 +1,99 @@
import contextlib
import inspect
from dataclasses import dataclass
from typing import Any, Dict, List, Set, Tuple

import torch
import tqdm
from compressed_tensors.quantization import find_name_or_class_matches
from compressed_tensors.utils import get_execution_device
from torch.nn import Module
from torch.utils.data.dataloader import DataLoader

from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.pytorch.utils.helpers import tensors_to_device
from llmcompressor.utils.helpers import calibration_forward_context

__all__ = ["match_modules", "compute_first_layer_intermediates", "to_next_layer_kwargs"]


def match_modules(model: Module, target_names: List[str]) -> List[Module]:
    # `named_modules` yields modules in registration order, which matches
    # execution order for transformer decoder stacks. Sorting names
    # lexicographically would misorder numbered layers ("layers.10" sorts
    # before "layers.2"), so the traversal order is kept as-is
    return [
        module
        for name, module in model.named_modules()
        if find_name_or_class_matches(name, module, target_names)
    ]


def compute_first_layer_intermediates(
    model: Module,
    layers: List[Module],
    dataloader: DataLoader,
    mask_padding: bool = True,
) -> IntermediatesCache:
    model_device = get_execution_device(model)
    intermediates = IntermediatesCache.empty(len(dataloader), torch.device("cpu"))
    first_layer = layers[0]
    signature = inspect.signature(first_layer.forward)

    with calibration_forward_context(model), early_stop_hook(first_layer):
        desc = "Preparing intermediates cache"
        for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)):
            batch = apply_pad_mask_to_batch(batch) if mask_padding else batch
            batch = tensors_to_device(batch, model_device)

            try:
                model(**batch)
            except EarlyStopException as exception:
                # map the captured positional args onto the layer's parameter
                # names, then merge in the captured keyword args
                layer_args = args_to_kwargs(exception._args, signature)
                assert not set(layer_args.keys()) & set(exception._kwargs.keys())
                layer_args.update(exception._kwargs)

                intermediates.update(batch_index, layer_args)
            else:
                raise ValueError(
                    "Attempted to capture first layer intermediates, but "
                    "EarlyStopException was not raised"
                )

    return intermediates


def to_next_layer_kwargs(args: Tuple[Any, ...], next_layer: Module) -> Dict[str, Any]:
    signature = inspect.signature(next_layer.forward)
    return args_to_kwargs(args, signature)


def args_to_kwargs(
    args: Tuple[Any, ...], signature: inspect.Signature
) -> Dict[str, Any]:
    return {name: arg for name, arg in zip(signature.parameters.keys(), args)}


@contextlib.contextmanager
def early_stop_hook(module: Module):
    def trigger_early_stop_fn(module, args, kwargs):
        raise EarlyStopException(_args=args, _kwargs=kwargs)

    handle = module.register_forward_pre_hook(trigger_early_stop_fn, with_kwargs=True)

    # use try/finally so the hook is removed even if an exception propagates
    try:
        yield
    finally:
        handle.remove()


@dataclass
class EarlyStopException(Exception):
    """
    Note: this exception is different from the exception defined in
    llmcompressor.modifiers.utils.pytorch_helpers, which it will eventually
    replace. The attribute names `args` and `kwargs` are reserved by
    `dataclass`, hence the underscore prefixes
    """

    _args: Tuple[Any, ...]
    _kwargs: Dict[str, Any]
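
For context, a minimal, self-contained sketch of the early-stop capture pattern that helpers.py uses: a forward pre-hook raises an exception to halt the model at the first matched layer, and the caller recovers that layer's inputs from the exception. The toy model and names below are illustrative stand-ins, not part of this commit.

import contextlib

import torch
from torch.nn import Linear, Module, Sequential


class _Stop(Exception):
    # mirrors EarlyStopException: carries the intercepted layer inputs
    def __init__(self, args, kwargs):
        self.captured_args = args
        self.captured_kwargs = kwargs


@contextlib.contextmanager
def stop_at(module: Module):
    def pre_hook(module, args, kwargs):
        raise _Stop(args, kwargs)

    handle = module.register_forward_pre_hook(pre_hook, with_kwargs=True)
    try:
        yield
    finally:
        handle.remove()


model = Sequential(Linear(4, 8), Linear(8, 2))  # toy stand-in for an LLM
with stop_at(model[0]):
    try:
        model(torch.randn(1, 4))
    except _Stop as stop:
        # exactly the inputs the first layer would have received
        print(stop.captured_args[0].shape)  # torch.Size([1, 4])

Capturing inputs this way means only the model prefix up to the first target layer executes per calibration batch, rather than the full model.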
61 changes: 61 additions & 0 deletions src/llmcompressor/pipelines/fake_sequential/pipeline.py
@@ -0,0 +1,61 @@
from contextlib import nullcontext
from typing import List

import torch
import torch.utils.data.dataloader
import tqdm
from compressed_tensors.utils import get_execution_device

from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.pipelines.fake_sequential.helpers import (
compute_first_layer_intermediates,
match_modules,
to_next_layer_kwargs,
)
from llmcompressor.utils.helpers import calibration_forward_context

__all__ = ["run_pipeline"]


def run_pipeline(
    model: torch.nn.Module,
    sequential_targets: List[str],  # FUTURE: replace with recipe inference
    dataloader: torch.utils.data.DataLoader,
    propagate_error: bool,
):
    """
    Calibrate each matched layer in sequence, caching its outputs as the
    inputs to the next layer. If propagate_error is True, the cached inputs
    reflect each layer's post-calibration (error-affected) outputs
    """
    # find layers
    layers = match_modules(model, sequential_targets)

    # FUTURE: apply recipe to model
    # initialize(recipe, model)

    with calibration_forward_context(model):
        intermediates = compute_first_layer_intermediates(model, layers, dataloader)

        num_layers = len(layers)
        for layer_index, layer in enumerate(layers):
            # prepare tqdm description texts
            calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating"
            prop_desc = f"({layer_index + 1}/{num_layers}): Propagate"

            if propagate_error:
                # do a preliminary pass to trigger modifier hooks
                for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                    inputs = intermediates.fetch(batch_index)
                    layer(**inputs)

            # if using propagate_error, then this pass does not trigger modifier
            # hooks and is only used for capturing intermediates
            # otherwise, this pass triggers modifier hooks and captures intermediates
            with HooksMixin.disable_hooks() if propagate_error else nullcontext():
                desc = prop_desc if propagate_error else calib_desc
                for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc):
                    inputs = intermediates.fetch(batch_index)
                    output = layer(**inputs)

                    if layer_index < num_layers - 1:
                        # only look ahead when a next layer exists; indexing
                        # layers[layer_index + 1] unconditionally would raise
                        # an IndexError on the final layer
                        output = to_next_layer_kwargs(output, layers[layer_index + 1])
                        intermediates.delete(batch_index)
                        intermediates.update(batch_index, output)

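A usage sketch for the pipeline above. The model choice, the "LlamaDecoderLayer" target, and the pre-tokenized dataloader construction are assumptions for illustration, not part of this commit; class names are one form of target that find_name_or_class_matches accepts.

# illustrative only: model, target name, and batch format are assumptions
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.pipelines.fake_sequential import run_pipeline

model_id = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

samples = ["The quick brown fox", "jumps over the lazy dog"]
batches = [tokenizer(text, return_tensors="pt") for text in samples]
dataloader = DataLoader(batches, batch_size=None)  # one pre-tokenized batch per step

# each decoder block becomes one sequential step; propagate_error=True feeds
# each layer's post-calibration outputs forward as the next layer's inputs
run_pipeline(
    model,
    sequential_targets=["LlamaDecoderLayer"],
    dataloader=dataloader,
    propagate_error=True,
)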