diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index eb5f1337d1..205ba91290 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -13,38 +13,29 @@ from pathlib import Path import torch -from torch.ao.quantization.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, -) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - XNNPACKQuantizer, get_symmetric_quantization_config, + XNNPACKQuantizer, ) from torch.testing._internal import common_utils from torch.testing._internal.common_utils import TestCase from torchao import quantize_ -from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.model import prepare_inputs_for_model, Transformer from torchao._models.llama.tokenizer import get_tokenizer -from torchao.dtypes import ( - AffineQuantizedTensor, -) -from torchao.quantization import ( - LinearActivationQuantizedTensor, -) +from torchao.dtypes import AffineQuantizedTensor +from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization.quant_api import ( - Quantizer, - TwoStepQuantizer, _replace_with_custom_fn_if_matches_filter, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, int8_weight_only, + Quantizer, + TwoStepQuantizer, ) -from torchao.quantization.quant_primitives import ( - MappingType, -) +from torchao.quantization.quant_primitives import MappingType from torchao.quantization.subclass import ( Int4WeightOnlyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, @@ -59,7 +50,7 @@ def dynamic_quant(model, example_inputs): - m = torch.export.export(model, example_inputs).module() + m = torch.export.export(model, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_dynamic=True) ) @@ -69,7 +60,7 @@ def dynamic_quant(model, example_inputs): def capture_and_prepare(model, example_inputs): - m = torch.export.export(model, example_inputs) + m = torch.export.export(model, example_inputs, strict=True) quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_dynamic=True) ) @@ -666,7 +657,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): m_unwrapped = unwrap_tensor_subclass(m) - m = torch.export.export(m_unwrapped, example_inputs).module() + m = torch.export.export(m_unwrapped, example_inputs, strict=True).module() exported_model_res = m(*example_inputs) self.assertTrue(torch.equal(exported_model_res, ref))