From f3d1c3bca22668f80465671a23f48aece8fab544 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 3 Oct 2024 10:56:20 -0700 Subject: [PATCH] Gate explicit scale/zero_point dtypes in test_choose_qparams_token_asym on torch 2.6+ --- test/quantization/test_quant_primitives.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 6012a00c4..4e0663eb8 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -30,6 +30,7 @@ TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, is_fbcode, ) @@ -207,7 +208,10 @@ def test_choose_qparams_token_asym(self): mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (1, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float64, zero_point_dtype=torch.int64) + if TORCH_VERSION_AT_LEAST_2_6: + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float64, zero_point_dtype=torch.int64) + else: + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype) scale_ref = scale_ref.squeeze()