diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index 4fd4db4f..b35f309e 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -3,7 +3,6 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Set
 
-import torch
 from compressed_tensors import has_offloaded_params
 from compressed_tensors.quantization import find_name_or_class_matches
 from torch.fx import Graph, GraphModule, Node
@@ -11,15 +10,19 @@
 from transformers.utils.fx import HFTracer
 
 from llmcompressor.modifiers.utils.hooks import HooksMixin
-from llmcompressor.utils.helpers import calibration_forward_context, getattr_chain
+from llmcompressor.utils.helpers import calibration_forward_context
 
 
 @dataclass
 class Subgraph:
     graph: Graph
-    input_names: List[str]
-    consumed_names: List[str]
-    input_device: torch.device
+    input_names: Set[str]
+    consumed_names: Set[str]
+
+    def compile_forward(self):
+        code = self.graph.python_code("self")
+        exec(code.src, code.globals)
+        return code.globals.get("forward")
 
 
 __all__ = ["infer_sequential_targets", "trace_subgraphs"]
@@ -213,34 +216,14 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap
             }
             graph.output(output_dict)
 
-        # find input device for subgraph
-        # note: find_nodes is topologically sorted
-        modules = [
-            getattr_chain(model, node.target)
-            for node in graph.find_nodes(op="call_module")
-        ]
-        if len(modules) > 0:
-            first_offloaded = next(
-                (m for m in modules if has_offloaded_params(m)), None
-            )
-            input_device = (
-                torch.device(first_offloaded._hf_hook.execution_device)
-                if first_offloaded is not None
-                else next(modules[0].parameters()).device
-            )
-
-        else:
-            input_device = model.device
-
         # save the subgraph for this partition
         graph.lint()
-        input_names = [node.name for node in graph.nodes if node.op == "placeholder"]
+        input_names = set(node.name for node in graph.nodes if node.op == "placeholder")
         subgraphs.append(
             Subgraph(
                 graph=graph,
                 input_names=input_names,
-                consumed_names=[],  # populated later
-                input_device=input_device,
+                consumed_names=set(),  # populated later
             )
         )
 
@@ -256,7 +239,7 @@ def trace_consumed_names(subgraphs: List[Dict[str, Any]]):
     for input_name in all_input_names:
         for subgraph in reversed(subgraphs):
             if input_name in subgraph.input_names:
-                subgraph.consumed_names.append(input_name)
+                subgraph.consumed_names.add(input_name)
                 break
         else:
             assert False
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index 1325712d..ec34b227 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -4,14 +4,14 @@
 import torch
 import torch.utils.data.dataloader
 import tqdm
+from compressed_tensors.utils import get_execution_device
 
 from llmcompressor.modifiers.utils.hooks import HooksMixin
-from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
+from llmcompressor.pipelines.sequential.cache import IntermediatesCache
 from llmcompressor.pipelines.sequential.helpers import (
     infer_sequential_targets,
     trace_subgraphs,
 )
-from llmcompressor.pytorch.utils.helpers import tensors_to_device
 from llmcompressor.utils.helpers import calibration_forward_context
 
 __all__ = ["run_pipeline"]
@@ -33,7 +33,7 @@ def run_pipeline(
 
     In order to reduce memory requirements
     1. Data is passed through each subgraph with batch size 1
-    2. The intermediate activations between each subgraph are offloaded onto the CPU
+    2. Intermediate activations between each subgraph are offloaded onto the CPU
 
     This pipeline requires that the model be tracable with respect to data from the
     data loader. This may be an issue for vision language models with vision datasets,
@@ -50,55 +50,38 @@ def run_pipeline(
 
     with calibration_forward_context(model):
         # prepare intermediates cache
-        desc = "Preparing intermediates cache"
-        batch_intermediates = [
-            apply_pad_mask_to_batch(batch) if "attention_mask" in batch else batch
-            for batch in tqdm.tqdm(dataloader, desc=desc)
-        ]
-        batch_outputs = [None for _ in range(len(dataloader))]
+        model_device = get_execution_device(model)
+        intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)
+        model_outputs = [dict() for _ in range(len(dataloader))]
 
         num_subgraphs = len(subgraphs)
-        for index, subgraph in enumerate(subgraphs):
+        for subgraph_index, subgraph in enumerate(subgraphs):
             # prepare tqdm description texts
-            uncomp_desc = f"({index + 1}/{num_subgraphs}): Calibrating"
-            comp_desc = f"({index + 1}/{num_subgraphs}): Propagate"
+            calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
+            prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagate"
 
             # compile subgraph forward function
-            code = subgraph.graph.python_code("self")
-            exec(code.src, code.globals)
-            forward_function = code.globals.get("forward")
+            forward_function = subgraph.compile_forward()
 
             if propagate_error:
                 # do an preliminary pass to trigger modifier hooks
-                for batch_index in tqdm.tqdm(range(len(dataloader)), desc=uncomp_desc):
-                    intermediates = batch_intermediates[batch_index]
-                    inputs = {
-                        input_name: intermediates[input_name]
-                        for input_name in subgraph.input_names
-                    }
-                    inputs = tensors_to_device(inputs, subgraph.input_device)
+                for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
+                    inputs = intermediates.fetch(batch_index, subgraph.input_names)
                     forward_function(model, **inputs)
+                    del inputs
 
             # if using propagate_error, then this pass does not trigger modifier hooks
             # and is only used for capturing intermediates
             # otherwise, this pass triggers modifier hooks and captures intermediates
             with HooksMixin.disable_hooks() if propagate_error else nullcontext():
-                desc = comp_desc if propagate_error else uncomp_desc
+                desc = prop_desc if propagate_error else calib_desc
                 for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc):
-                    intermediates = batch_intermediates[batch_index]
-
-                    inputs = {
-                        input_name: intermediates[input_name]
-                        for input_name in subgraph.input_names
-                    }
-                    inputs = tensors_to_device(inputs, subgraph.input_device)
+                    inputs = intermediates.fetch(batch_index, subgraph.input_names)
                     subgraph_output = forward_function(model, **inputs)
-                    subgraph_output = tensors_to_device(subgraph_output, "cpu")
-
-                    for consumed_name in subgraph.consumed_names:
-                        del intermediates[consumed_name]
+                    del inputs
 
-                    if index < len(subgraphs) - 1:
-                        intermediates.update(subgraph_output)
+                    if subgraph_index < len(subgraphs) - 1:
+                        intermediates.update(batch_index, subgraph_output)
+                        intermediates.delete(batch_index, subgraph.consumed_names)
                     else:
-                        batch_outputs[batch_index] = subgraph_output
+                        model_outputs[batch_index] = subgraph_output
diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py
index ee764b9f..8dd6a0cb 100644
--- a/src/llmcompressor/transformers/compression/helpers.py
+++ b/src/llmcompressor/transformers/compression/helpers.py
@@ -137,8 +137,7 @@ def hessian_memory_requirements(model: torch.nn.Module) -> int:
     max_total_hessian_elems = max(total_hessian_elems.values())
     overall_max_column_size = max(max_column_size.values())
     bytes_per_weight = 32 // 8  # hessians are float32
-    # allocate enough space for out of place operations
-    inverse_reserved = overall_max_column_size * overall_max_column_size * 2
+    inverse_reserved = overall_max_column_size * overall_max_column_size
    return (max_total_hessian_elems + inverse_reserved) * bytes_per_weight
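Below is a minimal, self-contained sketch (not part of the patch) of the compile pattern that the new `Subgraph.compile_forward` helper relies on: `torch.fx` renders a graph to Python source via `Graph.python_code`, and exec-ing that source yields a standalone `forward` function that is called with the owning module as its first argument. The `ToyModel` module here is purely illustrative.

```python
import torch
from torch.fx import symbolic_trace


class ToyModel(torch.nn.Module):
    """Illustrative module standing in for a traced subgraph's owning model."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


traced = symbolic_trace(ToyModel())

# Graph.python_code("self") returns the generated source (.src) along with the
# globals it references (.globals); exec-ing the source defines `forward`.
code = traced.graph.python_code("self")
exec(code.src, code.globals)
forward_function = code.globals.get("forward")

# The compiled function expects the module as its first argument, mirroring
# how the pipeline calls forward_function(model, **inputs).
out = forward_function(traced, torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])
```

Compiling each subgraph to a plain function this way lets the pipeline run one partition at a time against the cached intermediates, rather than invoking the full model forward pass.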