Skip to content

Commit

Permalink
Torchao weights only + prequantized compability (#34355)
Browse files Browse the repository at this point in the history
* weights only compability

* better tests from code review

* ping torch version

* add weights_only check
  • Loading branch information
SunMarc authored Nov 20, 2024
1 parent f297af5 commit 67890de
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3602,7 +3602,11 @@ def from_pretrained(

if hf_quantizer is not None:
hf_quantizer.validate_environment(
torch_dtype=torch_dtype, from_tf=from_tf, from_flax=from_flax, device_map=device_map
torch_dtype=torch_dtype,
from_tf=from_tf,
from_flax=from_flax,
device_map=device_map,
weights_only=weights_only,
)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
device_map = hf_quantizer.update_device_map(device_map)
Expand Down
19 changes: 19 additions & 0 deletions src/transformers/quantizers/quantizer_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ def validate_environment(self, *args, **kwargs):
)
else:
self.offload = True
if self.pre_quantized:
weights_only = kwargs.get("weights_only", None)
if weights_only:
torch_version = version.parse(importlib.metadata.version("torch"))
if torch_version < version.parse("2.5.0"):
raise RuntimeError(
f"In order to use torchao pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
f" You can also set with `weights_only=False` in `from_pretrained` if you don't want to update torch"
)

def update_torch_dtype(self, torch_dtype):
if self.quantization_config.quant_type == "int4_weight_only":
Expand All @@ -103,6 +112,10 @@ def update_torch_dtype(self, torch_dtype):
"Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
)
torch_dtype = torch.bfloat16
if self.quantization_config.quant_type == "int8_dynamic_activation_int8_weight":
if torch_dtype is None:
# we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
torch_dtype = torch.float32
return torch_dtype

def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
Expand Down Expand Up @@ -198,6 +211,12 @@ def is_serializable(self, safe_serialization=None):
)
if not _is_torchao_serializable:
logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
if self.offload and self.quantization_config.modules_to_not_convert is None:
logger.warning(
"The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
"If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
)
return False
return _is_torchao_serializable

@property
Expand Down
95 changes: 95 additions & 0 deletions tests/quantization/torchao_integration/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import gc
import tempfile
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
Expand Down Expand Up @@ -236,5 +237,99 @@ def test_int8_dynamic_activation_int8_weight_quant(self):
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)


@require_torch_gpu
@require_torchao
class TorchAoSerializationTest(unittest.TestCase):
input_text = "What are we having for dinner?"
max_new_tokens = 10
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
# TODO: investigate why we don't have the same output as the original model for this test
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
device = "cuda:0"

# called only once for all test in this class
@classmethod
def setUpClass(cls):
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
torch_dtype=torch.bfloat16,
device_map=cls.device,
quantization_config=cls.quant_config,
)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()

def test_original_model_expected_output(self):
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)

self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)

def check_serialization_expected_output(self, device, expected_output):
"""
Test if we can serialize and load/infer the model again on the same device
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16, device_map=self.device
)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)

output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)

def test_serialization_expected_output(self):
self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)


class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"


class TorchAoSerializationW8Test(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_weight_only")
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"


class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cpu"

def test_serialization_expected_output_cuda(self):
"""
Test if we can serialize on device (cpu) and load/infer the model on cuda
"""
new_device = "cuda:0"
self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)


class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_weight_only")
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cpu"

def test_serialization_expected_output_cuda(self):
"""
Test if we can serialize on device (cpu) and load/infer the model on cuda
"""
new_device = "cuda:0"
self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)


if __name__ == "__main__":
unittest.main()

0 comments on commit 67890de

Please sign in to comment.