Showing 3 changed files with 162 additions and 0 deletions.
llmcompressor/pipelines/fake_sequential/__init__.py
@@ -0,0 +1,2 @@
# flake8: noqa
from .pipeline import run_pipeline
llmcompressor/pipelines/fake_sequential/helpers.py
@@ -0,0 +1,99 @@
import contextlib
import inspect
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import torch
import tqdm
from compressed_tensors.quantization import find_name_or_class_matches
from compressed_tensors.utils import get_execution_device
from torch.nn import Module
from torch.utils.data.dataloader import DataLoader

from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.pytorch.utils.helpers import tensors_to_device
from llmcompressor.utils.helpers import calibration_forward_context

__all__ = ["match_modules", "compute_first_layer_intermediates"]

def match_modules(model: Module, target_names: List[str]) -> List[Module]:
    # collect all submodules whose name or class matches the target names
    names_layers = [
        (name, module)
        for name, module in model.named_modules()
        if find_name_or_class_matches(name, module, target_names)
    ]

    # sort by module name for a deterministic ordering
    names_layers = sorted(names_layers, key=lambda name_layer: name_layer[0])
    return [layer for _name, layer in names_layers]
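For illustration, a hedged usage sketch; the target class name and the `model` variable below are assumptions about the model being compressed, not part of this commit:

# hypothetical usage: collect the decoder layers of a Llama-style model
decoder_layers = match_modules(model, ["LlamaDecoderLayer"])
assert all(type(layer).__name__ == "LlamaDecoderLayer" for layer in decoder_layers)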

def compute_first_layer_intermediates(
    model: Module,
    layers: List[Module],
    dataloader: DataLoader,
    mask_padding: bool = True,
) -> IntermediatesCache:
    model_device = get_execution_device(model)
    intermediates = IntermediatesCache.empty(len(dataloader), torch.device("cpu"))
    first_layer = layers[0]
    signature = inspect.signature(first_layer.forward)

    with calibration_forward_context(model), early_stop_hook(first_layer):
        desc = "Preparing intermediates cache"
        for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)):
            batch = apply_pad_mask_to_batch(batch) if mask_padding else batch
            batch = tensors_to_device(batch, model_device)

            try:
                model(**batch)
            except EarlyStopException as exception:
                # merge the positional and keyword arguments captured by the hook
                layer_args = args_to_kwargs(exception._args, signature)
                assert not set(layer_args.keys()) & set(exception._kwargs.keys())
                layer_args.update(exception._kwargs)

                intermediates.update(batch_index, layer_args)
            else:
                raise ValueError(
                    "Attempted to capture first layer intermediates, but "
                    "EarlyStopException was not raised"
                )

    return intermediates
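The returned cache is keyed by batch index and holds each batch's first-layer keyword arguments offloaded to CPU. A short sketch of reading one entry back; the kwarg names in the comment are assumptions about a typical decoder layer:

intermediates = compute_first_layer_intermediates(model, layers, dataloader)
first_inputs = intermediates.fetch(0)  # e.g. {"hidden_states": ..., "attention_mask": ...}
_ = layers[0](**first_inputs)          # replay the first layer on its cached inputs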

def to_next_layer_kwargs(args: Tuple[Any, ...], next_layer: Module) -> Dict[str, Any]:
    signature = inspect.signature(next_layer.forward)
    return args_to_kwargs(args, signature)


def args_to_kwargs(
    args: Tuple[Any, ...], signature: inspect.Signature
) -> Dict[str, Any]:
    return {name: arg for name, arg in zip(signature.parameters.keys(), args)}
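A toy illustration of the positional-to-keyword mapping; `toy_forward` is made up for the example:

def toy_forward(hidden_states, attention_mask=None, position_ids=None):
    ...

sig = inspect.signature(toy_forward)
kwargs = args_to_kwargs((torch.zeros(1), None), sig)
# -> {"hidden_states": tensor([0.]), "attention_mask": None}
# zip() stops at the shorter sequence, so trailing parameters are simply omitted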

@contextlib.contextmanager
def early_stop_hook(module: Module):
    def trigger_early_stop_fn(module, args, kwargs):
        raise EarlyStopException(_args=args, _kwargs=kwargs)

    handle = module.register_forward_pre_hook(trigger_early_stop_fn, with_kwargs=True)

    try:
        yield
    finally:
        # remove the hook even if the wrapped block raises
        handle.remove()


@dataclass
class EarlyStopException(Exception):
    """
    Note: this exception is different from the one defined in
    llmcompressor.modifiers.utils.pytorch_helpers and will eventually replace it.

    Attribute names `args` and `kwargs` are reserved for `dataclass`, so the
    captured values are stored as `_args` and `_kwargs`
    """

    _args: Tuple[Any, ...]
    _kwargs: Dict[str, Any]
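A self-contained sketch of the early-stop pattern on a toy model, showing how the pre-hook interrupts the forward pass and surfaces the submodule's inputs; the toy model is made up for illustration and assumes the two helpers above are in scope:

import torch
from torch.nn import Linear, Module


class ToyModel(Module):
    def __init__(self):
        super().__init__()
        self.first = Linear(4, 4)
        self.second = Linear(4, 4)

    def forward(self, x):
        return self.second(self.first(x))


model = ToyModel()
with early_stop_hook(model.second):
    try:
        model(torch.randn(1, 4))
    except EarlyStopException as exc:
        # exc._args[0] is the tensor that would have entered model.second
        captured = exc._args[0]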
llmcompressor/pipelines/fake_sequential/pipeline.py
@@ -0,0 +1,61 @@
from contextlib import nullcontext
from typing import List

import torch
import torch.utils.data.dataloader
import tqdm
from compressed_tensors.utils import get_execution_device

from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.pipelines.fake_sequential.helpers import (
    compute_first_layer_intermediates,
    match_modules,
    to_next_layer_kwargs,
)
from llmcompressor.utils.helpers import calibration_forward_context

__all__ = ["run_pipeline"]

def run_pipeline(
    model: torch.nn.Module,
    sequential_targets: List[str],  # FUTURE: replace with recipe inference
    dataloader: torch.utils.data.DataLoader,
    propagate_error: bool,
):
    """
    Run a layer-wise calibration pipeline: capture the inputs to the first
    matched layer, then calibrate each layer in turn, feeding its outputs
    forward as the inputs of the next layer
    """
    # find layers
    layers = match_modules(model, sequential_targets)

    # FUTURE: apply recipe to model
    # initialize(recipe, model)

    with calibration_forward_context(model):
        intermediates = compute_first_layer_intermediates(model, layers, dataloader)

        num_layers = len(layers)
        for layer_index, layer in enumerate(layers):
            # prepare tqdm description texts
            calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating"
            prop_desc = f"({layer_index + 1}/{num_layers}): Propagate"

            if propagate_error:
                # run a preliminary pass to trigger modifier hooks
                for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                    inputs = intermediates.fetch(batch_index)
                    layer(**inputs)

            # if using propagate_error, then this pass does not trigger modifier hooks
            # and is only used for capturing intermediates
            # otherwise, this pass triggers modifier hooks and captures intermediates
            with HooksMixin.disable_hooks() if propagate_error else nullcontext():
                desc = prop_desc if propagate_error else calib_desc
                for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc):
                    inputs = intermediates.fetch(batch_index)
                    output = layer(**inputs)

                    # the last layer has no successor, so skip caching its outputs
                    if layer_index < num_layers - 1:
                        output = to_next_layer_kwargs(output, layers[layer_index + 1])
                        intermediates.delete(batch_index)
                        intermediates.update(batch_index, output)
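A hedged end-to-end sketch of invoking the pipeline; the model instance, target class name, and calibration dataloader are assumptions, not part of this commit:

# hypothetical usage; assumes a HF-style causal LM and a tokenized calibration set
from llmcompressor.pipelines.fake_sequential import run_pipeline

run_pipeline(
    model,                                     # e.g. a LlamaForCausalLM instance
    sequential_targets=["LlamaDecoderLayer"],  # assumed target class name
    dataloader=calibration_dataloader,
    propagate_error=True,                      # calibrate on true inputs, then propagate
)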