From 92fbddc1ab5b848085809dd794541a867cf07736 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 9 Dec 2024 12:46:46 -0500
Subject: [PATCH 1/3] add quantization then finetune -- run_compressed=False

---
 .../quantization_w8a8_fp8/llama3_example.py   | 26 ++++++++++++++-----
 .../pytorch/utils/sparsification.py           | 19 ++++++++------
 2 files changed, 30 insertions(+), 15 deletions(-)

diff --git a/examples/quantization_w8a8_fp8/llama3_example.py b/examples/quantization_w8a8_fp8/llama3_example.py
index 6dc870b32..a91c35974 100644
--- a/examples/quantization_w8a8_fp8/llama3_example.py
+++ b/examples/quantization_w8a8_fp8/llama3_example.py
@@ -3,7 +3,8 @@
 from llmcompressor.modifiers.quantization import QuantizationModifier
 from llmcompressor.transformers import oneshot
 
-MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
 # Load model.
 model = AutoModelForCausalLM.from_pretrained(
@@ -22,14 +23,25 @@
 # Apply quantization.
 oneshot(model=model, recipe=recipe)
 
-# Confirm generations of the quantized model look sane.
-print("========== SAMPLE GENERATION ==============")
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=20)
-print(tokenizer.decode(output[0]))
-print("==========================================")
+# # Confirm generations of the quantized model look sane.
+# print("========== SAMPLE GENERATION ==============")
+# input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+# output = model.generate(input_ids, max_new_tokens=20)
+# print(tokenizer.decode(output[0]))
+# print("==========================================")
 
 # Save to disk in compressed-tensors format.
 SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
 model.save_pretrained(SAVE_DIR)
 tokenizer.save_pretrained(SAVE_DIR)
+
+from transformers import AutoModelForCausalLM
+from transformers.utils.quantization_config import CompressedTensorsConfig
+
+quantization_config = CompressedTensorsConfig(run_compressed=False)
+
+breakpoint()
+
+model = AutoModelForCausalLM.from_pretrained(
+    SAVE_DIR, quantization_config=quantization_config, device_map="auto"
+)
diff --git a/src/llmcompressor/pytorch/utils/sparsification.py b/src/llmcompressor/pytorch/utils/sparsification.py
index fa9cebfd1..8abc7fb5a 100644
--- a/src/llmcompressor/pytorch/utils/sparsification.py
+++ b/src/llmcompressor/pytorch/utils/sparsification.py
@@ -105,15 +105,18 @@ def params_quantized(self) -> int:
         """
         :return: number of parameters across quantized layers
         """
-        return sum(
-            torch.numel(self.trainable_params[f"{name}.weight"])
-            + (
-                torch.numel(self.trainable_params[f"{name}.bias"])
-                if hasattr(layer, "bias") and layer.bias is not None
-                else 0
-            )
-            for (name, layer) in get_quantized_layers(self.module)
-        )
+        num_params = 0
+        for name, layer in get_quantized_layers(self.module):
+            # only count layers whose weights are tracked in trainable_params
+            weight = self.trainable_params.get(f"{name}.weight")
+            if weight is None:
+                logger.warning(f"{name} is not recognized in trainable_params")
+                continue
+            num_params += torch.numel(weight)
+            if hasattr(layer, "bias") and layer.bias is not None:
+                num_params += torch.numel(layer.bias)
+
+        return num_params
 
     @property
     def params_quantized_percent(self) -> float:
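
Note on the flow above: the block appended to llama3_example.py (including the
stray breakpoint(), which the next patch removes together with the rest of the
example changes) is the point of this series — an FP8 model saved in
compressed-tensors format can be reloaded decompressed and kept trainable.
A minimal consolidated sketch of that round trip, using only calls that appear
in this series; the TinyLlama checkpoint and the FP8_DYNAMIC recipe from the
test below are carried-over assumptions, not a fixed API surface:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils.quantization_config import CompressedTensorsConfig

    from llmcompressor.modifiers.quantization import QuantizationModifier
    from llmcompressor.transformers import oneshot

    MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # Quantize all Linear layers to FP8 with dynamic activation scales,
    # leaving lm_head in full precision (the recipe used in this series).
    recipe = QuantizationModifier(
        targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
    )
    oneshot(model=model, recipe=recipe)

    # Save to disk in compressed-tensors format.
    SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)

    # run_compressed=False asks transformers to decompress the checkpoint at
    # load time, so the reloaded model can be finetuned rather than only served.
    quantization_config = CompressedTensorsConfig(run_compressed=False)
    model = AutoModelForCausalLM.from_pretrained(
        SAVE_DIR, quantization_config=quantization_config, device_map="auto"
    )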

From 299eed36a316ff8fd65174531de18bae632290f0 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 9 Dec 2024 12:48:52 -0500
Subject: [PATCH 2/3] add test

---
 .../quantization_w8a8_fp8/llama3_example.py   | 26 ++-----
 .../finetune/test_oneshot_then_finetune.py    | 72 +++++++++++++++++++
 2 files changed, 79 insertions(+), 19 deletions(-)

diff --git a/examples/quantization_w8a8_fp8/llama3_example.py b/examples/quantization_w8a8_fp8/llama3_example.py
index a91c35974..6dc870b32 100644
--- a/examples/quantization_w8a8_fp8/llama3_example.py
+++ b/examples/quantization_w8a8_fp8/llama3_example.py
@@ -3,8 +3,7 @@
 from llmcompressor.modifiers.quantization import QuantizationModifier
 from llmcompressor.transformers import oneshot
 
-# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
-MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 # Load model.
 model = AutoModelForCausalLM.from_pretrained(
@@ -23,25 +22,14 @@
 # Apply quantization.
 oneshot(model=model, recipe=recipe)
 
-# # Confirm generations of the quantized model look sane.
-# print("========== SAMPLE GENERATION ==============")
-# input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-# output = model.generate(input_ids, max_new_tokens=20)
-# print(tokenizer.decode(output[0]))
-# print("==========================================")
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=20)
+print(tokenizer.decode(output[0]))
+print("==========================================")
 
 # Save to disk in compressed-tensors format.
 SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
 model.save_pretrained(SAVE_DIR)
 tokenizer.save_pretrained(SAVE_DIR)
-
-from transformers import AutoModelForCausalLM
-from transformers.utils.quantization_config import CompressedTensorsConfig
-
-quantization_config = CompressedTensorsConfig(run_compressed=False)
-
-breakpoint()
-
-model = AutoModelForCausalLM.from_pretrained(
-    SAVE_DIR, quantization_config=quantization_config, device_map="auto"
-)
diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py
index db5950188..82ca4609d 100644
--- a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py
+++ b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py
@@ -5,6 +5,7 @@
 
 import pytest
 
+from llmcompressor.modifiers.quantization import QuantizationModifier
 from tests.testing_utils import requires_torch
 
 
@@ -91,5 +92,76 @@ def test_oneshot_then_finetune(self):
                 resume_from_checkpoint=True,  # use last checkpoint
             )
 
+    def test_quantization_then_finetune(self):
+        from transformers import AutoModelForCausalLM
+
+        from llmcompressor.core import create_session
+        from llmcompressor.transformers import oneshot, train
+
+        recipe = QuantizationModifier(
+            targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+            device_map="auto",
+        )
+        dataset = "open_platypus"
+        concatenate_data = False
+        num_calibration_samples = 64
+        output_dir = self.output / "oneshot_out"
+        splits = {"calibration": "train[:10%]"}
+
+        with create_session():
+            oneshot(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+            )
+
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir,
+            device_map="auto",
+            quantization_config=quantization_config,
+        )
+        dataset = "open_platypus"
+        concatenate_data = False
+        output_dir = self.output / "finetune_out"
+        splits = {"calibration": "train[:10%]", "train": "train[:10%]"}
+
+        with create_session():
+            train(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+            )
+
+        # test reloading checkpoint and final model
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map="auto", quantization_config=quantization_config
+        )
+        with create_session():
+            train(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+                resume_from_checkpoint=True,  # use last checkpoint
+            )
+
     def tearDown(self):
         shutil.rmtree(self.output)

From 9ea94ed33b76d1cf565b5652a59c6c579b68736f Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 9 Dec 2024 13:10:01 -0500
Subject: [PATCH 3/3] clean up

---
 .../finetune/test_oneshot_then_finetune.py    | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py
index 1c1371bf3..73bf7f66b 100644
--- a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py
+++ b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py
@@ -4,7 +4,12 @@
 from pathlib import Path
 
 import pytest
+from transformers import AutoModelForCausalLM
+
+from llmcompressor.core import create_session
 from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.transformers import oneshot, train
+
 
 @pytest.mark.unit
 @pytest.mark.skipif(
@@ -18,11 +23,6 @@ def setUp(self):
         self.output = Path("./finetune_output")
 
     def test_oneshot_then_finetune(self):
-        from transformers import AutoModelForCausalLM
-
-        from llmcompressor.core import create_session
-        from llmcompressor.transformers import oneshot, train
-
         recipe_str = "tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml"
         model = AutoModelForCausalLM.from_pretrained(
             "Xenova/llama2.c-stories15M", device_map="auto"
         )
@@ -89,11 +89,6 @@ def test_oneshot_then_finetune(self):
         )
 
     def test_quantization_then_finetune(self):
-        from transformers import AutoModelForCausalLM
-
-        from llmcompressor.core import create_session
-        from llmcompressor.transformers import oneshot, train
-
         recipe = QuantizationModifier(
             targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
         )
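
Taken together, the series boils down to the quantize-then-finetune flow that
the new test pins down: quantize with a calibration split, reload decompressed,
then finetune. A sketch of that flow under the test's own assumptions — the
TinyLlama checkpoint, the open_platypus dataset known to llmcompressor, and
output directory names chosen here purely for illustration:

    from transformers import AutoModelForCausalLM
    from transformers.utils.quantization_config import CompressedTensorsConfig

    from llmcompressor.core import create_session
    from llmcompressor.modifiers.quantization import QuantizationModifier
    from llmcompressor.transformers import oneshot, train

    recipe = QuantizationModifier(
        targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
    )

    # Step 1: one-shot FP8 quantization, calibrated on 64 samples.
    model = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto"
    )
    with create_session():
        oneshot(
            model=model,
            dataset="open_platypus",
            output_dir="oneshot_out",
            num_calibration_samples=64,
            recipe=recipe,
            splits={"calibration": "train[:10%]"},
        )

    # Step 2: reload the compressed checkpoint decompressed (run_compressed=False)
    # and finetune it; passing resume_from_checkpoint=True to a later train()
    # call picks up the last written checkpoint, as the test exercises.
    model = AutoModelForCausalLM.from_pretrained(
        "oneshot_out",
        device_map="auto",
        quantization_config=CompressedTensorsConfig(run_compressed=False),
    )
    with create_session():
        train(
            model=model,
            dataset="open_platypus",
            output_dir="finetune_out",
            num_calibration_samples=64,
            recipe=recipe,
            splits={"calibration": "train[:10%]", "train": "train[:10%]"},
        )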