Fix CI by tweaking torchao tests (huggingface#34832)
SunMarc authored Nov 20, 2024
1 parent bf42c3b commit 3cb8676
Showing 2 changed files with 13 additions and 7 deletions.
9 changes: 7 additions & 2 deletions src/transformers/utils/quantization_config.py
@@ -1264,8 +1264,13 @@ def post_init(self):
         r"""
         Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
         """
-        if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
-            raise ValueError("Requires torchao 0.4.0 version and above")
+        if is_torchao_available():
+            if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
+                raise ValueError("Requires torchao 0.4.0 version and above")
+        else:
+            raise ValueError(
+                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
+            )
 
         _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
         if self.quant_type not in _STR_TO_METHOD.keys():
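
With this guard in place, constructing a TorchAoConfig fails fast with an actionable ValueError when torchao is missing, rather than letting importlib.metadata raise a PackageNotFoundError from inside post_init. A minimal sketch of the resulting behavior; the model name and quantization arguments are borrowed from the test file below, and the snippet is illustrative only, not part of the commit:

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig

# post_init now raises ValueError("TorchAoConfig requires torchao to be
# installed, ...") when torchao is absent, and ValueError("Requires torchao
# 0.4.0 version and above") when an older torchao is installed.
quant_config = TorchAoConfig("int4_weight_only", group_size=32)

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
    device_map="cuda:0",
)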
11 changes: 6 additions & 5 deletions tests/quantization/torchao_integration/test_torchao.py
@@ -246,12 +246,13 @@ class TorchAoSerializationTest(unittest.TestCase):
     # TODO: investigate why we don't have the same output as the original model for this test
     SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
     device = "cuda:0"
 
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
+        cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs)
         cls.quantized_model = AutoModelForCausalLM.from_pretrained(
             cls.model_name,
             torch_dtype=torch.bfloat16,
@@ -290,21 +291,21 @@ def test_serialization_expected_output(self):
 
 
 class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
-    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+    quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
     ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
     device = "cuda:0"
 
 
 class TorchAoSerializationW8Test(TorchAoSerializationTest):
-    quant_config = TorchAoConfig("int8_weight_only")
+    quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
     ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
     device = "cuda:0"
 
 
 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
-    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+    quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
     ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
     device = "cpu"
@@ -318,7 +319,7 @@ def test_serialization_expected_output_cuda(self):
 
 
 class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
-    quant_config = TorchAoConfig("int8_weight_only")
+    quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
     ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
     device = "cpu"
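
The test-side change replaces eager, class-level TorchAoConfig construction (which runs at import time, so with the new guard test collection itself would fail wherever torchao is absent) with lazy construction in setUpClass; subclasses now override only the scheme name and kwargs. A trimmed sketch of the pattern, condensed from the diff above:

import unittest

from transformers import TorchAoConfig

class TorchAoSerializationTest(unittest.TestCase):
    # Plain class attributes are safe at import time; no torchao call happens yet.
    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}

    @classmethod
    def setUpClass(cls):
        # Built once per class, and only when the tests actually run.
        cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs)

class TorchAoSerializationW8Test(TorchAoSerializationTest):
    # Each variant overrides just the scheme; setUpClass is inherited.
    quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}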
