From 19e4f971a81ecb8ad40dc2d7eb13b1e3ca61daee Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 23:39:49 -0500 Subject: [PATCH] revert dampening frac --- examples/multimodal_vision/mllama.py | 3 ++- examples/multimodal_vision/pixtral.py | 4 ++-- src/llmcompressor/transformers/tracing/__init__.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 2fac52e6..3f61532e 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -4,7 +4,8 @@ from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.pytorch.data_collator import DataCollator + +# from llmcompressor.pytorch.data_collator import DataCollator from llmcompressor.transformers import oneshot from llmcompressor.transformers.tracing import TracableMllamaForConditionalGeneration diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index a554f66b..f311582a 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -4,7 +4,8 @@ from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.pytorch.data_collator import DataCollator + +# from llmcompressor.pytorch.data_collator import DataCollator from llmcompressor.transformers import oneshot from llmcompressor.transformers.tracing import TracableLlavaForConditionalGeneration @@ -40,7 +41,6 @@ def data_collator(batch): scheme="W8A8", ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], sequential_targets=["MistralDecoderLayer"], - dampening_frac=100.0, ), ] diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index bc2356a4..88be0fe8 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -1,6 +1,6 @@ from .llava import TracableLlavaForConditionalGeneration -from .mllama import TracableMllamaForConditionalGeneration from .mistral import TracableMistralForCausalLM +from .mllama import TracableMllamaForConditionalGeneration __all__ = [ "TracableLlavaForConditionalGeneration",