
Commit

fix memory and offloading issues
kylesayrs committed Dec 12, 2024
1 parent e1055b0 commit 70421ed
Showing 3 changed files with 32 additions and 67 deletions.
39 changes: 11 additions & 28 deletions src/llmcompressor/pipelines/sequential/helpers.py
@@ -3,23 +3,26 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Set

- import torch
from compressed_tensors import has_offloaded_params
from compressed_tensors.quantization import find_name_or_class_matches
from torch.fx import Graph, GraphModule, Node
from torch.nn import Module
from transformers.utils.fx import HFTracer

from llmcompressor.modifiers.utils.hooks import HooksMixin
- from llmcompressor.utils.helpers import calibration_forward_context, getattr_chain
+ from llmcompressor.utils.helpers import calibration_forward_context


@dataclass
class Subgraph:
graph: Graph
- input_names: List[str]
- consumed_names: List[str]
- input_device: torch.device
+ input_names: Set[str]
+ consumed_names: Set[str]

+ def compile_forward(self):
+ code = self.graph.python_code("self")
+ exec(code.src, code.globals)
+ return code.globals.get("forward")


__all__ = ["infer_sequential_targets", "trace_subgraphs"]
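Note: compile_forward centralizes the exec-based compilation that pipeline.py previously performed inline (see the pipeline.py diff below). A minimal usage sketch, where subgraph is a traced Subgraph and inputs is a placeholder dict keyed by subgraph.input_names:

# compile the partition's FX graph into a callable, then run it with the
# root model followed by the recorded placeholder inputs as keyword args
forward_function = subgraph.compile_forward()
subgraph_output = forward_function(model, **inputs)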
@@ -213,34 +216,14 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgraph]:
}
graph.output(output_dict)

- # find input device for subgraph
- # note: find_nodes is topologically sorted
- modules = [
- getattr_chain(model, node.target)
- for node in graph.find_nodes(op="call_module")
- ]
- if len(modules) > 0:
- first_offloaded = next(
- (m for m in modules if has_offloaded_params(m)), None
- )
- input_device = (
- torch.device(first_offloaded._hf_hook.execution_device)
- if first_offloaded is not None
- else next(modules[0].parameters()).device
- )
-
- else:
- input_device = model.device

# save the subgraph for this partition
graph.lint()
- input_names = [node.name for node in graph.nodes if node.op == "placeholder"]
+ input_names = set(node.name for node in graph.nodes if node.op == "placeholder")
subgraphs.append(
Subgraph(
graph=graph,
input_names=input_names,
- consumed_names=[], # populated later
- input_device=input_device,
+ consumed_names=set(), # populated later
)
)
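Note: the removed block resolved an input device per subgraph by inspecting the first offloaded submodule's accelerate hook. After this commit the pipeline instead asks compressed-tensors for a single execution device for the whole model (get_execution_device in the pipeline.py diff below). A rough sketch of equivalent resolution logic, written only for illustration and not taken from compressed-tensors:

import torch
from torch.nn import Module

def resolve_execution_device(module: Module) -> torch.device:
    # hypothetical stand-in: prefer the accelerate hook's execution device
    # when the module is offloaded, otherwise fall back to the device of
    # the first parameter
    hook = getattr(module, "_hf_hook", None)
    if hook is not None and getattr(hook, "execution_device", None) is not None:
        return torch.device(hook.execution_device)
    return next(module.parameters()).device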

@@ -256,7 +239,7 @@ def trace_consumed_names(subgraphs: List[Dict[str, Any]]):
for input_name in all_input_names:
for subgraph in reversed(subgraphs):
if input_name in subgraph.input_names:
- subgraph.consumed_names.append(input_name)
+ subgraph.consumed_names.add(input_name)
break
else:
assert False
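Note: consumed_names records, for each cached tensor name, the last subgraph that reads it, so the cache can free the tensor immediately after that subgraph runs. A self-contained toy example of the same scan, with made-up names:

from dataclasses import dataclass, field
from typing import Set

@dataclass
class _Sub:  # simplified stand-in for Subgraph, illustration only
    input_names: Set[str]
    consumed_names: Set[str] = field(default_factory=set)

subs = [_Sub({"input_ids"}), _Sub({"hidden_states"}), _Sub({"hidden_states", "mask"})]
for name in set().union(*(s.input_names for s in subs)):
    for sub in reversed(subs):
        if name in sub.input_names:
            sub.consumed_names.add(name)  # only the last reader frees it
            break

# subs[2].consumed_names == {"hidden_states", "mask"}; subs[1].consumed_names stays
# empty, so "hidden_states" is kept cached until subgraph 2 has consumed it.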
57 changes: 20 additions & 37 deletions src/llmcompressor/pipelines/sequential/pipeline.py
@@ -4,14 +4,14 @@
import torch
import torch.utils.data.dataloader
import tqdm
+ from compressed_tensors.utils import get_execution_device

from llmcompressor.modifiers.utils.hooks import HooksMixin
- from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
+ from llmcompressor.pipelines.sequential.cache import IntermediatesCache
from llmcompressor.pipelines.sequential.helpers import (
infer_sequential_targets,
trace_subgraphs,
)
- from llmcompressor.pytorch.utils.helpers import tensors_to_device
from llmcompressor.utils.helpers import calibration_forward_context

__all__ = ["run_pipeline"]
@@ -33,7 +33,7 @@ def run_pipeline(
In order to reduce memory requirements
1. Data is passed through each subgraph with batch size 1
- 2. The intermediate activations between each subgraph are offloaded onto the CPU
+ 2. Intermediate activations between each subgraph are offloaded onto the CPU
This pipeline requires that the model be tracable with respect to data from the
data loader. This may be an issue for vision language models with vision datasets,
@@ -50,55 +50,38 @@

with calibration_forward_context(model):
# prepare intermediates cache
desc = "Preparing intermediates cache"
batch_intermediates = [
apply_pad_mask_to_batch(batch) if "attention_mask" in batch else batch
for batch in tqdm.tqdm(dataloader, desc=desc)
]
batch_outputs = [None for _ in range(len(dataloader))]
model_device = get_execution_device(model)
intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)
model_outputs = [dict() for _ in range(len(dataloader))]

num_subgraphs = len(subgraphs)
- for index, subgraph in enumerate(subgraphs):
+ for subgraph_index, subgraph in enumerate(subgraphs):
# prepare tqdm description texts
- uncomp_desc = f"({index + 1}/{num_subgraphs}): Calibrating"
- comp_desc = f"({index + 1}/{num_subgraphs}): Propagate"
+ calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
+ prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagate"

# compile subgraph forward function
- code = subgraph.graph.python_code("self")
- exec(code.src, code.globals)
- forward_function = code.globals.get("forward")
+ forward_function = subgraph.compile_forward()

if propagate_error:
# do an preliminary pass to trigger modifier hooks
- for batch_index in tqdm.tqdm(range(len(dataloader)), desc=uncomp_desc):
- intermediates = batch_intermediates[batch_index]
- inputs = {
- input_name: intermediates[input_name]
- for input_name in subgraph.input_names
- }
- inputs = tensors_to_device(inputs, subgraph.input_device)
+ for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
+ inputs = intermediates.fetch(batch_index, subgraph.input_names)
forward_function(model, **inputs)
del inputs

# if using propagate_error, then this pass does not trigger modifier hooks
# and is only used for capturing intermediates
# otherwise, this pass triggers modifier hooks and captures intermediates
with HooksMixin.disable_hooks() if propagate_error else nullcontext():
- desc = comp_desc if propagate_error else uncomp_desc
+ desc = prop_desc if propagate_error else calib_desc
for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc):
- intermediates = batch_intermediates[batch_index]

- inputs = {
- input_name: intermediates[input_name]
- for input_name in subgraph.input_names
- }
- inputs = tensors_to_device(inputs, subgraph.input_device)
+ inputs = intermediates.fetch(batch_index, subgraph.input_names)
subgraph_output = forward_function(model, **inputs)
- subgraph_output = tensors_to_device(subgraph_output, "cpu")

- for consumed_name in subgraph.consumed_names:
- del intermediates[consumed_name]
- del inputs

- if index < len(subgraphs) - 1:
- intermediates.update(subgraph_output)
+ if subgraph_index < len(subgraphs) - 1:
+ intermediates.update(batch_index, subgraph_output)
+ intermediates.delete(batch_index, subgraph.consumed_names)
else:
- batch_outputs[batch_index] = subgraph_output
+ model_outputs[batch_index] = subgraph_output
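Note: the rewritten loop delegates all host/device movement to IntermediatesCache (defined in src/llmcompressor/pipelines/sequential/cache.py, which is not part of this diff); the pad-mask handling removed above presumably moves behind that interface as well. A minimal sketch of a cache with the same fetch/update/delete surface, assuming it behaves roughly like the following (the real implementation may differ):

import torch
from typing import Any, Dict, List, Set

class SimpleIntermediatesCache:
    """Illustrative stand-in: keep per-batch values on CPU and onload them to
    the model's execution device only when a subgraph fetches them."""

    def __init__(self, batches: List[Dict[str, Any]], device: torch.device):
        self.batches = batches
        self.device = device

    @classmethod
    def from_dataloader(cls, dataloader, device: torch.device):
        return cls([dict(batch) for batch in dataloader], device)

    def fetch(self, batch_index: int, names: Set[str]) -> Dict[str, Any]:
        batch = self.batches[batch_index]
        return {
            name: val.to(self.device) if isinstance(val, torch.Tensor) else val
            for name, val in batch.items()
            if name in names
        }

    def update(self, batch_index: int, outputs: Dict[str, Any]):
        # store new intermediates offloaded to CPU
        self.batches[batch_index].update({
            name: val.cpu() if isinstance(val, torch.Tensor) else val
            for name, val in outputs.items()
        })

    def delete(self, batch_index: int, names: Set[str]):
        # drop values once their last consumer has run
        for name in names:
            self.batches[batch_index].pop(name, None)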
3 changes: 1 addition & 2 deletions src/llmcompressor/transformers/compression/helpers.py
@@ -137,8 +137,7 @@ def hessian_memory_requirements(model: torch.nn.Module) -> int:
max_total_hessian_elems = max(total_hessian_elems.values())
overall_max_column_size = max(max_column_size.values())
bytes_per_weight = 32 // 8 # hessians are float32
- # allocate enough space for out of place operations
- inverse_reserved = overall_max_column_size * overall_max_column_size * 2
+ inverse_reserved = overall_max_column_size * overall_max_column_size
return (max_total_hessian_elems + inverse_reserved) * bytes_per_weight
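Note: the change halves the scratch space reserved for the out-of-place Hessian inverse (one column-squared buffer instead of two). A worked example with assumed sizes, not taken from the repository:

# hypothetical 4096-wide layers
max_total_hessian_elems = 4096 * 4096
overall_max_column_size = 4096
bytes_per_weight = 4  # float32

new_estimate = (max_total_hessian_elems + overall_max_column_size**2) * bytes_per_weight
old_estimate = (max_total_hessian_elems + overall_max_column_size**2 * 2) * bytes_per_weight
print(new_estimate, old_estimate)  # 134217728 (~134 MB) vs 201326592 (~201 MB)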


