From b75fe15f18089d455709c081bb278ea1070ff4ee Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 12 Dec 2024 21:51:47 -0500
Subject: [PATCH] wrap ignore but do not treat as sequential target

---
 examples/multimodal_vision/mllama.py        |  1 -
 examples/multimodal_vision/pixtral.py       |  6 +--
 .../modifiers/quantization/gptq/base.py     | 10 ++--
 .../pipelines/sequential/helpers.py         | 52 +++++++++----------
 .../pipelines/sequential/pipeline.py        |  8 +--
 5 files changed, 33 insertions(+), 44 deletions(-)

diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py
index 87dc21d5..44df76e7 100644
--- a/examples/multimodal_vision/mllama.py
+++ b/examples/multimodal_vision/mllama.py
@@ -42,7 +42,6 @@ def data_collator(batch):
         targets="Linear",
         scheme="W8A8",
         ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"],
-        dampening_frac=100.0,
     ),
 ]
 
diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py
index 600a146c..cf3ae5cb 100644
--- a/examples/multimodal_vision/pixtral.py
+++ b/examples/multimodal_vision/pixtral.py
@@ -11,7 +11,7 @@
 # Load model.
 model_id = "mgoin/pixtral-12b"
 model = TracableLlavaForConditionalGeneration.from_pretrained(
-    model_id, device_map="balanced", torch_dtype="auto"
+    model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
 
@@ -57,8 +57,8 @@ def data_collator(batch):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
     output_dir=save_path,
-    # data_collator=data_collator,
-    data_collator=DataCollator(),
+    data_collator=data_collator,
+    # data_collator=DataCollator(),
 )
 
 model.save_pretrained(save_path)
diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index 1cfef2fe..63d657a5 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -29,11 +29,7 @@
 from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
 from llmcompressor.transformers import tracing
 from llmcompressor.utils.metric_logging import CompressionLogger
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_no_split_params,
-    qat_active,
-)
+from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active
 
 __all__ = ["GPTQModifier"]
 
@@ -213,8 +209,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
         # infer sequential targets
         if self.sequential_targets is None:
             self.sequential_targets = get_no_split_params(state.model)
-        elif isinstance(self.sequential_targets, str):
-            self.sequential_targets = get_layers(self.sequential_targets, self.model)
+        if isinstance(self.sequential_targets, str):
+            self.sequential_targets = [self.sequential_targets]
 
         # infer update size
         if self._update_size is None:
diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index b0136a53..29e45480 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -7,11 +7,14 @@
 from compressed_tensors.quantization import find_name_or_class_matches
 from torch.fx import Graph, GraphModule, Node
 from torch.nn import Module
+from transformers import PreTrainedModel
 from transformers.utils.fx import HFTracer
 
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.helpers import calibration_forward_context
 
+__all__ = ["trace_subgraphs", "Subgraph"]
+
 
 @dataclass
 class Subgraph:
@@ -25,34 +28,18 @@ def compile_forward(self):
         return code.globals.get("forward")
 
 
-__all__ = ["infer_sequential_targets", "trace_subgraphs"]
-
-
-def infer_sequential_targets(
-    model: Module, sequential_targets: List[str], ignore: List[str]
-) -> Set[Module]:
-    """
-    Future: infer from recipe
-
-    List of modules which are guaranteed to be split into different partitions and
-    whose inner operations will not be traced
-    """
-    targets_names = sequential_targets + ignore
-
-    sequential_targets = set(
-        module
-        for name, module in model.named_modules()
-        if find_name_or_class_matches(name, module, targets_names)
-    )
-
-    return sequential_targets
-
-
 def trace_subgraphs(
-    model: Module, sample_input: Dict[str, Any], targets: Set[Module]
+    model: PreTrainedModel,
+    sample_input: Dict[str, Any],
+    sequential_targets: List[str],
+    ignore: List[str],
 ) -> List[Subgraph]:
+    # find modules
+    sequential_targets = match_modules(model, sequential_targets)
+    ignore = match_modules(model, ignore)
+
     # initialize arguments
-    tracer = get_tracer(model, targets)
+    tracer = get_tracer(model, sequential_targets, ignore)
     concrete_args = populate_concrete_args(model, sample_input)
 
     # trace
@@ -78,14 +65,16 @@
     graph.device = model.device
 
     # perform subgraph partition
-    partitions = topological_partition(graph, targets)
+    partitions = topological_partition(graph, sequential_targets)
     subgraphs = partition_graph(model, partitions)
     trace_consumed_names(subgraphs)
 
     return subgraphs
 
 
-def get_tracer(model: Module, sequential_targets: List[Module]) -> HFTracer:
+def get_tracer(
+    model: Module, sequential_targets: Set[Module], ignore: Set[Module]
+) -> HFTracer:
     offloaded_modules = set(
         module for module in model.modules() if has_offloaded_params(module)
     )
@@ -96,6 +85,7 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
             return (
                 module in sequential_targets
                 or module in offloaded_modules
+                or module in ignore
                 or super().is_leaf_module(module, module_qualified_name)
             )
 
@@ -261,3 +251,11 @@ def check_assumption(graph: Graph) -> bool:
                 return False
 
     return True
+
+
+def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
+    return set(
+        module
+        for name, module in model.named_modules()
+        if find_name_or_class_matches(name, module, target_names)
+    )
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index 10a40cfc..45d10e4f 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -8,10 +8,7 @@
 
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.sequential.cache import IntermediatesCache
-from llmcompressor.pipelines.sequential.helpers import (
-    infer_sequential_targets,
-    trace_subgraphs,
-)
+from llmcompressor.pipelines.sequential.helpers import trace_subgraphs
 from llmcompressor.utils.helpers import calibration_forward_context
 
 __all__ = ["run_pipeline"]
@@ -42,8 +39,7 @@ def run_pipeline(
     """
    # trace subgraphs
     sample_input = next(iter(dataloader))
-    targets = infer_sequential_targets(model, sequential_targets, ignore)
-    subgraphs = trace_subgraphs(model, sample_input, targets)
+    subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
 
     # FUTURE: apply recipe to model
     # initialize(recipe, model)