Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Quantization then finetune tests #964

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/llmcompressor/pytorch/utils/sparsification.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,18 @@ def params_quantized(self) -> int:
"""
:return: number of parameters across quantized layers
"""
return sum(
torch.numel(self.trainable_params[f"{name}.weight"])
+ (
torch.numel(self.trainable_params[f"{name}.bias"])
if hasattr(layer, "bias") and layer.bias is not None
else 0
num_params = 0
for name, layer in get_quantized_layers(self.module):
num_param = torch.numel(
self.trainable_params.get(f"{name}.weight", torch.tensor([]))
)
for (name, layer) in get_quantized_layers(self.module)
)
if num_param is None:
logger.warning(f"{name} is not recognized in trainable_params")
continue
if hasattr(layer, "bias") and layer.bias is not None:
num_params += layer.bias

return num_params

@property
def params_quantized_percent(self) -> float:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from pathlib import Path

import pytest
from transformers import AutoModelForCausalLM

from llmcompressor.core import create_session
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot, train


@pytest.mark.unit
Expand All @@ -18,11 +23,6 @@ def setUp(self):
self.output = Path("./finetune_output")

def test_oneshot_then_finetune(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_sparsification_then_finetune

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or move the recipe to an argument

parameterize or make functions, test_sparsification_then_finetune, test_quantization_then_finetune, test_oneshot_then_finetune, the latter is called by the first two

from transformers import AutoModelForCausalLM

from llmcompressor.core import create_session
from llmcompressor.transformers import oneshot, train

recipe_str = "tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml"
model = AutoModelForCausalLM.from_pretrained(
"Xenova/llama2.c-stories15M", device_map="auto"
Expand Down Expand Up @@ -88,5 +88,71 @@ def test_oneshot_then_finetune(self):
resume_from_checkpoint=True, # use last checkpoint
)

def test_quantization_then_finetune(self):
recipe = QuantizationModifier(
targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
)

model = AutoModelForCausalLM.from_pretrained(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
device_map="auto",
)
dataset = "open_platypus"
concatenate_data = False
num_calibration_samples = 64
output_dir = self.output / "oneshot_out"
splits = {"calibration": "train[:10%]"}

with create_session():
oneshot(
model=model,
dataset=dataset,
output_dir=output_dir,
num_calibration_samples=num_calibration_samples,
recipe=recipe,
concatenate_data=concatenate_data,
splits=splits,
)

from transformers.utils.quantization_config import CompressedTensorsConfig

quantization_config = CompressedTensorsConfig(run_compressed=False)
model = AutoModelForCausalLM.from_pretrained(
output_dir,
device_map="auto",
quantization_config=quantization_config,
)
dataset = "open_platypus"
concatenate_data = False
output_dir = self.output / "finetune_out"
splits = {"calibration": "train[:10%]", "train": "train[:10%]"}

with create_session():
train(
model=model,
dataset=dataset,
output_dir=output_dir,
num_calibration_samples=num_calibration_samples,
recipe=recipe,
concatenate_data=concatenate_data,
splits=splits,
)

# test reloading checkpoint and final model
model = AutoModelForCausalLM.from_pretrained(
output_dir, device_map="auto", quantization_config=quantization_config
)
with create_session():
train(
model=model,
dataset=dataset,
output_dir=output_dir,
num_calibration_samples=num_calibration_samples,
recipe=recipe,
concatenate_data=concatenate_data,
splits=splits,
resume_from_checkpoint=True, # use last checkpoint
)

def tearDown(self):
shutil.rmtree(self.output)
Loading