wrap ignore but do not treat as sequential target
kylesayrs committed Dec 13, 2024
1 parent 1bf683e commit b75fe15
Showing 5 changed files with 33 additions and 44 deletions.
examples/multimodal_vision/mllama.py (1 change: 0 additions & 1 deletion)
@@ -42,7 +42,6 @@ def data_collator(batch):
         targets="Linear",
         scheme="W8A8",
         ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"],
-        dampening_frac=100.0,
     ),
 ]

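Dropping the dampening_frac=100.0 override lets GPTQModifier fall back to its default percent damping (commonly around 0.01). As a rough sketch of what this knob controls, assuming a per-layer Hessian H (illustrative only, not the library's implementation):

import torch

def dampen_hessian(H: torch.Tensor, dampening_frac: float = 0.01) -> torch.Tensor:
    # GPTQ-style stabilization: add a fraction of the mean Hessian diagonal
    # to the diagonal before inverting, so the solve stays well-conditioned;
    # larger fractions trade quantization accuracy for numerical stability
    damp = dampening_frac * torch.mean(torch.diag(H))
    return H + damp * torch.eye(H.shape[0], device=H.device, dtype=H.dtype)
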
examples/multimodal_vision/pixtral.py (6 changes: 3 additions & 3 deletions)
@@ -11,7 +11,7 @@
 # Load model.
 model_id = "mgoin/pixtral-12b"
 model = TracableLlavaForConditionalGeneration.from_pretrained(
-    model_id, device_map="balanced", torch_dtype="auto"
+    model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

@@ -57,8 +57,8 @@ def data_collator(batch):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
     output_dir=save_path,
-    # data_collator=data_collator,
-    data_collator=DataCollator(),
+    data_collator=data_collator,
+    # data_collator=DataCollator(),
 )

 model.save_pretrained(save_path)
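
For reference, the function-style data_collator restored above is defined earlier in pixtral.py. A minimal sketch of what such a collator typically does for single-sample vision calibration (an assumption; the real definition may differ):

import torch

def data_collator(batch):
    # calibration runs one example at a time, so just tensorize the
    # single processed sample produced by the processor
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}
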
src/llmcompressor/modifiers/quantization/gptq/base.py (10 changes: 3 additions & 7 deletions)
@@ -29,11 +29,7 @@
 from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
 from llmcompressor.transformers import tracing
 from llmcompressor.utils.metric_logging import CompressionLogger
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_no_split_params,
-    qat_active,
-)
+from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active

 __all__ = ["GPTQModifier"]

@@ -213,8 +209,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
         # infer sequential targets
         if self.sequential_targets is None:
             self.sequential_targets = get_no_split_params(state.model)
-        elif isinstance(self.sequential_targets, str):
-            self.sequential_targets = get_layers(self.sequential_targets, self.model)
+        if isinstance(self.sequential_targets, str):
+            self.sequential_targets = [self.sequential_targets]

         # infer update size
         if self._update_size is None:
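
With this change, a bare string passed as sequential_targets is wrapped into a one-element list of names rather than eagerly resolved to modules with get_layers; name-to-module matching now happens inside the tracing helpers. A hypothetical usage showing the two equivalent spellings:

from llmcompressor.modifiers.quantization import GPTQModifier

# "LlamaDecoderLayer" is a hypothetical sequential target; both
# spellings below reach trace_subgraphs as a list of target names
recipe = GPTQModifier(
    targets="Linear",
    scheme="W8A8",
    sequential_targets="LlamaDecoderLayer",
    # sequential_targets=["LlamaDecoderLayer"],  # equivalent
)
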
src/llmcompressor/pipelines/sequential/helpers.py (52 changes: 25 additions & 27 deletions)
@@ -7,11 +7,14 @@
 from compressed_tensors.quantization import find_name_or_class_matches
 from torch.fx import Graph, GraphModule, Node
 from torch.nn import Module
+from transformers import PreTrainedModel
 from transformers.utils.fx import HFTracer

 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.helpers import calibration_forward_context

+__all__ = ["trace_subgraphs", "Subgraph"]
+

 @dataclass
 class Subgraph:
@@ -25,34 +28,18 @@ def compile_forward(self):
         return code.globals.get("forward")


-__all__ = ["infer_sequential_targets", "trace_subgraphs"]
-
-
-def infer_sequential_targets(
-    model: Module, sequential_targets: List[str], ignore: List[str]
-) -> Set[Module]:
-    """
-    Future: infer from recipe
-    List of modules which are guaranteed to be split into different partitions and
-    whose inner operations will not be traced
-    """
-    targets_names = sequential_targets + ignore
-
-    sequential_targets = set(
-        module
-        for name, module in model.named_modules()
-        if find_name_or_class_matches(name, module, targets_names)
-    )
-
-    return sequential_targets
-
-
 def trace_subgraphs(
-    model: Module, sample_input: Dict[str, Any], targets: Set[Module]
+    model: PreTrainedModel,
+    sample_input: Dict[str, Any],
+    sequential_targets: List[str],
+    ignore: List[str],
 ) -> List[Subgraph]:
+    # find modules
+    sequential_targets = match_modules(model, sequential_targets)
+    ignore = match_modules(model, ignore)
+
     # initialize arguments
-    tracer = get_tracer(model, targets)
+    tracer = get_tracer(model, sequential_targets, ignore)
     concrete_args = populate_concrete_args(model, sample_input)

     # trace
@@ -78,14 +65,16 @@ def trace_subgraphs(
     graph.device = model.device

     # perform subgraph partition
-    partitions = topological_partition(graph, targets)
+    partitions = topological_partition(graph, sequential_targets)
     subgraphs = partition_graph(model, partitions)
     trace_consumed_names(subgraphs)

     return subgraphs


-def get_tracer(model: Module, sequential_targets: List[Module]) -> HFTracer:
+def get_tracer(
+    model: Module, sequential_targets: Set[Module], ignore: Set[Module]
+) -> HFTracer:
     offloaded_modules = set(
         module for module in model.modules() if has_offloaded_params(module)
     )
@@ -96,6 +85,7 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
             return (
                 module in sequential_targets
                 or module in offloaded_modules
+                or module in ignore
                 or super().is_leaf_module(module, module_qualified_name)
             )
@@ -261,3 +251,11 @@ def check_assumption(graph: Graph) -> bool:
             return False

     return True
+
+
+def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
+    return set(
+        module
+        for name, module in model.named_modules()
+        if find_name_or_class_matches(name, module, target_names)
+    )
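
Taken together, the helpers.py changes implement the behavior in the commit title: ignored modules and sequential targets are both wrapped as untraced leaves, but only sequential targets delimit partitions. A simplified sketch of that distinction, with assumed names (the actual partitioning lives in topological_partition):

from typing import Set
from torch.nn import Module

def is_untraced_leaf(
    module: Module,
    sequential_targets: Set[Module],
    ignore: Set[Module],
    offloaded: Set[Module],
) -> bool:
    # the tracer steps over all three kinds of modules, keeping their
    # inner operations out of the traced graph
    return module in sequential_targets or module in ignore or module in offloaded

def starts_new_partition(module: Module, sequential_targets: Set[Module]) -> bool:
    # ignored modules never appear here, so they remain inside the
    # surrounding partition instead of forcing a sequential split
    return module in sequential_targets
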
src/llmcompressor/pipelines/sequential/pipeline.py (8 changes: 2 additions & 6 deletions)
@@ -8,10 +8,7 @@
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.sequential.cache import IntermediatesCache
-from llmcompressor.pipelines.sequential.helpers import (
-    infer_sequential_targets,
-    trace_subgraphs,
-)
+from llmcompressor.pipelines.sequential.helpers import trace_subgraphs
 from llmcompressor.utils.helpers import calibration_forward_context

 __all__ = ["run_pipeline"]
@@ -42,8 +39,7 @@ def run_pipeline(
     """
     # trace subgraphs
     sample_input = next(iter(dataloader))
-    targets = infer_sequential_targets(model, sequential_targets, ignore)
-    subgraphs = trace_subgraphs(model, sample_input, targets)
+    subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)

     # FUTURE: apply recipe to model
     # initialize(recipe, model)
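
After this commit, run_pipeline forwards the raw name lists and trace_subgraphs performs the matching itself. The call now looks like this, with hypothetical target and ignore values:

sample_input = next(iter(dataloader))
subgraphs = trace_subgraphs(
    model,                                     # a PreTrainedModel
    sample_input,                              # first batch from the calibration dataloader
    sequential_targets=["LlamaDecoderLayer"],  # hypothetical layer class name
    ignore=["re:vision_tower.*"],              # hypothetical regex ignore pattern
)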
