From 3f713509d2250947cbc57f83369b0ecb0b987997 Mon Sep 17 00:00:00 2001 From: "Yanan Cao (PyTorch)" Date: Wed, 18 Dec 2024 17:02:55 -0800 Subject: [PATCH] pytorch/ao/test/integration Reviewed By: avikchaudhuri Differential Revision: D67388002 --- test/integration/test_integration.py | 639 +++++++++++++++++---------- 1 file changed, 411 insertions(+), 228 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 273f60655..dffd52642 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -6,49 +6,52 @@ # mypy: ignore-errors import copy -import unittest import itertools +import logging +import os +import unittest import torch import torch.nn as nn -from torch._inductor.utils import run_and_get_code -from torch._dynamo import config import torchao +from parameterized import parameterized +from torch._dynamo import config +from torch._inductor.utils import run_and_get_code from torch.ao.quantization import MinMaxObserver, QConfigMapping - -from torchao.quantization.dynamic_quant import ( - DynamicallyPerAxisQuantizedLinear, -) -from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout -from torchao.quantization.quant_api import ( - int4_weight_only, - int8_weight_only, - int8_dynamic_activation_int8_weight, - int8_dynamic_activation_int4_weight, - quantize_, - _replace_with_custom_fn_if_matches_filter, +from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx +from torchao.dtypes import Int4CPULayout, TensorCoreTiledLayout +from torchao.dtypes.utils import is_device +from torchao.quantization import safe_int_mm +from torchao.quantization.autoquant import ( + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, + AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, + AQFloat8WeightOnlyQuantizedLinearWeight, + AQInt8DynamicallyQuantizedLinearWeight, + AQInt8WeightOnlyQuantizedLinearWeight, + AQInt8WeightOnlyQuantizedLinearWeight2, + AQInt8WeightOnlyQuantizedLinearWeight3, + AutoQuantizableLinearWeight, ) + +from torchao.quantization.dynamic_quant import DynamicallyPerAxisQuantizedLinear + # APIs to be deprecated (used for torch 2.2.2 and 2.3) from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, + change_linear_weights_to_int4_woqtensors, change_linear_weights_to_int8_dqtensors, change_linear_weights_to_int8_woqtensors, - change_linear_weights_to_int4_woqtensors, -) -from torchao.quantization import ( - safe_int_mm, + int4_weight_only, + int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int8_weight, + int8_weight_only, + quantize_, ) from torchao.quantization.quant_primitives import ( choose_qparams_affine, - quantize_affine, dequantize_affine, MappingType, -) -from torchao.quantization.utils import ( - dequantize_per_channel, - dequantize_per_tensor, - dynamically_quantize_per_channel, - quant_int8_dynamic_per_token_linear, - quantize_activation_per_token_absmax, + quantize_affine, ) from torchao.quantization.smoothquant import ( @@ -58,43 +61,32 @@ swap_linear_with_smooth_fq_linear, ) from torchao.quantization.subclass import ( + Int4WeightOnlyQuantizedLinearWeight, Int8DynamicallyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, - Int4WeightOnlyQuantizedLinearWeight ) from torchao.quantization.utils import ( _apply_logging_hook, + _fqn_to_op_to_shape_to_count, compute_error, compute_error as SQNR, - _fqn_to_op_to_shape_to_count, + dequantize_per_channel, + dequantize_per_tensor, + dynamically_quantize_per_channel, LoggingTensorMode, + quant_int8_dynamic_per_token_linear, + quantize_activation_per_token_absmax, ) -from torchao.quantization.autoquant import ( - AQInt8DynamicallyQuantizedLinearWeight, - AQInt8WeightOnlyQuantizedLinearWeight, - AQInt8WeightOnlyQuantizedLinearWeight2, - AQInt8WeightOnlyQuantizedLinearWeight3, - AutoQuantizableLinearWeight, - AQFloat8WeightOnlyQuantizedLinearWeight, - AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, - AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, -) -from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx -import os -from parameterized import parameterized -import itertools -import logging from torchao.utils import ( + benchmark_model, + is_fbcode, + is_sm_at_least_90, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, unwrap_tensor_subclass, - is_fbcode, - benchmark_model, - is_sm_at_least_90, ) -from torchao.dtypes.utils import is_device try: import gemlite @@ -113,24 +105,23 @@ COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() + def _int8wo_api(mod): if TORCH_VERSION_AT_LEAST_2_4: quantize_(mod, int8_weight_only(), set_inductor_config=False) - if ( - not TORCH_VERSION_AT_LEAST_2_5 - or ( - not TORCH_VERSION_AT_LEAST_2_6 - and torch._inductor.config.freezing - ) + if not TORCH_VERSION_AT_LEAST_2_5 or ( + not TORCH_VERSION_AT_LEAST_2_6 and torch._inductor.config.freezing ): unwrap_tensor_subclass(mod) else: change_linear_weights_to_int8_woqtensors(mod) + def _int8wo_groupwise_api(mod): group_size = 32 quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False) + def _int8da_int8w_api(mod): if TORCH_VERSION_AT_LEAST_2_4: quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) @@ -139,9 +130,15 @@ def _int8da_int8w_api(mod): else: change_linear_weights_to_int8_dqtensors(mod) + def _int4wo_api(mod): - if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: - quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False) + if ( + is_device(next(mod.parameters()).device.type, "cpu") + and TORCH_VERSION_AT_LEAST_2_6 + ): + quantize_( + mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False + ) unwrap_tensor_subclass(mod) elif TORCH_VERSION_AT_LEAST_2_4: quantize_(mod, int4_weight_only(), set_inductor_config=False) @@ -150,6 +147,7 @@ def _int4wo_api(mod): else: change_linear_weights_to_int4_woqtensors(mod) + def _int8da_int4w_api(mod): quantize_(mod, int8_dynamic_activation_int4_weight(), set_inductor_config=False) if not TORCH_VERSION_AT_LEAST_2_5: @@ -163,6 +161,7 @@ def _int8da_int4w_api(mod): _int4wo_api, ] + def undo_recommended_configs(): torch._inductor.config.coordinate_descent_tuning = False torch._inductor.config.coordinate_descent_check_all_directions = False @@ -171,28 +170,40 @@ def undo_recommended_configs(): torch._inductor.config.triton.unique_kernel_names = False torch.set_float32_matmul_precision("highest") + def combine_parameters(a, b): new_tuples = [] - for (tuple1, tuple2) in itertools.product(a, b): + for tuple1, tuple2 in itertools.product(a, b): new_tuples.append(tuple1 + tuple2) return new_tuples + def run_supported_device_dtype(test_method): """Assumes that the 3rd arg (args[2]) of the decorated method is device and there is a `test_dtype` kwarg or the 4th arg (args[3]) that indicates the dtype for testing """ + def wrapper(*args, **kwargs): if len(args) < 3: - raise unittest.SkipTest(f"Not enough args. Expected more than or equal to 3, but got {len(args)}") + raise unittest.SkipTest( + f"Not enough args. Expected more than or equal to 3, but got {len(args)}" + ) device = args[2] dtype = kwargs["test_dtype"] if "test_dtype" in kwargs else args[3] if device == "cuda" and not torch.cuda.is_available(): raise unittest.SkipTest(f"Need CUDA available.") - if device == "cuda" and torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): + if ( + device == "cuda" + and torch.cuda.is_available() + and dtype == torch.bfloat16 + and torch.cuda.get_device_capability() < (8, 0) + ): raise unittest.SkipTest("Need CUDA and SM80+ available.") return test_method(*args, **kwargs) + return wrapper + class SmoothquantUnitTest(unittest.TestCase): # first, let's reproduce the graphic from the paper, Figure 4, to ensure # we are calculating the scales correctly @@ -498,7 +509,9 @@ def _test_dynamic_quant_per_channel_numerics_impl( assert torch.max(torch.abs(y_vals - y_ref.int_repr())) <= 1 # dequantize - x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point, out_dtype=float_dtype) + x_dq = dequantize_per_channel( + y_vals, y_scale, y_zero_point, out_dtype=float_dtype + ) x_ref_dq = y_ref.dequantize().to(float_dtype) # off-by-one for scale is okay torch.testing.assert_close( @@ -524,7 +537,9 @@ def _test_quantize_per_token_impl(self, device, dtype): x = torch.randn(3, 3, 3, device=device, dtype=dtype) xq, scales = quantize_activation_per_token_absmax(x) block_size = (1, 1, 3) - x_dq = dequantize_affine(xq, block_size, scales, None, torch.int8, output_dtype=x.dtype) + x_dq = dequantize_affine( + xq, block_size, scales, None, torch.int8, output_dtype=x.dtype + ) sqnr = compute_error(x, x_dq) self.assertTrue(sqnr >= 45.0) @@ -628,6 +643,7 @@ def wrap_torch_int_mm(x, w): torch.testing.assert_close(z_ref, z_eager, atol=0, rtol=0) torch.testing.assert_close(z_ref, z_torch_compile, atol=0, rtol=0) + class TestSubclass(unittest.TestCase): @run_supported_device_dtype def _test_dequantize_impl( @@ -647,19 +663,21 @@ def _test_dequantize_impl( self.assertGreater( SQNR(w, lin.weight.dequantize()), min_sqnr, - f"{lin.weight.__class__.__name__} failed dtype={test_dtype}" - ) + f"{lin.weight.__class__.__name__} failed dtype={test_dtype}", + ) self.assertGreater( - SQNR(w.t(), - lin.weight.t().dequantize()), + SQNR(w.t(), lin.weight.t().dequantize()), min_sqnr, - f"{lin.weight.__class__.__name__} failed transpose on dtype={test_dtype}" + f"{lin.weight.__class__.__name__} failed transpose on dtype={test_dtype}", ) @parameterized.expand(COMMON_DEVICE_DTYPE) def test_dequantize_int8_dynamic_quant_subclass(self, device, dtype): self._test_dequantize_impl( - Int8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype, + Int8DynamicallyQuantizedLinearWeight.from_float, + device, + 35, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -676,9 +694,15 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest("Currently only supports bfloat16.") - for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 8)] if device=='cuda' else [])): + for test_shape in [(16, 1024, 16)] + ( + [(1, 1024, 8)] if device == "cuda" else [] + ): self._test_dequantize_impl( - Int4WeightOnlyQuantizedLinearWeight.from_float, device, 15, test_shape=test_shape, test_dtype=dtype + Int4WeightOnlyQuantizedLinearWeight.from_float, + device, + 15, + test_shape=test_shape, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -689,14 +713,16 @@ def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype) self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest("Currently only supports bfloat16.") - m_shapes = [16, 256] + ([1] if device=="cuda" else []) - n_shapes = [16] + ([8, 13] if device=="cuda" else []) + m_shapes = [16, 256] + ([1] if device == "cuda" else []) + n_shapes = [16] + ([8, 13] if device == "cuda" else []) for groupsize in [256, 128]: for inner_k_tiles in [8, 4, 2]: for m in m_shapes: for n in n_shapes: self._test_dequantize_impl( - lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles), + lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float( + w, groupsize, inner_k_tiles + ), device, 15, test_shape=[m, 256, n], @@ -727,21 +753,26 @@ def _test_lin_weight_subclass_impl( self.assertGreater( SQNR(ref_f, test), min_sqnr, - f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}" + f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}", ) - lin_comp = torch.compile(lin, mode='max-autotune') + lin_comp = torch.compile(lin, mode="max-autotune") test_comp = lin_comp(x) self.assertGreater( SQNR(ref_f, test_comp), min_sqnr, - f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}" + f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}", ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen") + @unittest.skipIf( + TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen" + ) def test_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - Int8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + Int8DynamicallyQuantizedLinearWeight.from_float, + device, + 35, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -752,47 +783,73 @@ def test_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) def test_aq_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + AQInt8DynamicallyQuantizedLinearWeight.from_float, + device, + 35, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) @unittest.skip( "This segfaults in CI cuda only, disable to unblock PR, we can investigate " "later if needed" ) def test_aq_int8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQInt8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + AQInt8WeightOnlyQuantizedLinearWeight.from_float, + device, + 35, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQInt8WeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype + AQInt8WeightOnlyQuantizedLinearWeight2.from_float, + device, + 35, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype + AQInt8WeightOnlyQuantizedLinearWeight3.from_float, + device, + 35, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype + AQFloat8WeightOnlyQuantizedLinearWeight.from_float, + device, + 30, + test_dtype=dtype, ) def test_autoquantizable_flatten_unflatten(self): from torchao.quantization import DEFAULT_AUTOQUANT_CLASS_LIST + weight = torch.randn(16, 32) qtensor_class_list = DEFAULT_AUTOQUANT_CLASS_LIST aqw = AutoQuantizableLinearWeight.from_float(weight, qtensor_class_list) @@ -800,29 +857,45 @@ def test_autoquantizable_flatten_unflatten(self): tensor_data_dict = {name: getattr(aqw, name) for name in tensor_data_name_dict} outer_size = aqw.size() outer_stride = aqw.stride() - reconstructed = type(aqw).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride) - + reconstructed = type(aqw).__tensor_unflatten__( + tensor_data_dict, tensor_attributes, outer_size, outer_stride + ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): if dtype != torch.bfloat16: - with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"): + with self.assertRaisesRegex( + AssertionError, "PerRow quantization only works for bfloat16 precision" + ): self._test_lin_weight_subclass_impl( - AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, + device, + 25, + test_dtype=dtype, ) else: self._test_lin_weight_subclass_impl( - AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, + device, + 25, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" + ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( - AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype + AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, + device, + 25, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -833,9 +906,15 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") - for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 8)] if device=='cuda' else [])): + for test_shape in [(16, 1024, 16)] + ( + [(1, 1024, 8)] if device == "cuda" else [] + ): self._test_lin_weight_subclass_impl( - Int4WeightOnlyQuantizedLinearWeight.from_float, device, 10, test_shape=test_shape, test_dtype=dtype + Int4WeightOnlyQuantizedLinearWeight.from_float, + device, + 10, + test_shape=test_shape, + test_dtype=dtype, ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -844,14 +923,16 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") - m_shapes = [16, 256] + ([1] if device=="cuda" else []) - n_shapes = [16] + ([8, 13] if device=="cuda" else []) + m_shapes = [16, 256] + ([1] if device == "cuda" else []) + n_shapes = [16] + ([8, 13] if device == "cuda" else []) for groupsize in [128, 64]: for inner_k_tiles in [8, 4, 2]: for m in m_shapes: for n in n_shapes: self._test_lin_weight_subclass_impl( - lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles), + lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float( + w, groupsize, inner_k_tiles + ), device, 10, test_shape=[m, 256, n], @@ -866,12 +947,14 @@ def _test_lin_weight_subclass_api_impl( test_device, min_sqnr=35, test_dtype=torch.bfloat16, - test_shape=(32, 64, 32) + test_shape=(32, 64, 32), ): m, k, n = test_shape x = torch.randn(m, k, device=test_device, dtype=test_dtype) mod = nn.Sequential( - nn.Linear(k, n, device=test_device), nn.ReLU(), nn.Linear(n, n, device=test_device) + nn.Linear(k, n, device=test_device), + nn.ReLU(), + nn.Linear(n, n, device=test_device), ).to(test_dtype) ref_f = mod(x) api(mod) @@ -883,17 +966,18 @@ def _test_lin_weight_subclass_api_impl( test = mod(x) self.assertGreater( SQNR(ref_f, test), - min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}" + min_sqnr, + f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}", ) mod_qc = torch.compile(mod, mode="max-autotune") test_comp = mod_qc(x) self.assertGreater( - SQNR(ref_f, test_comp), min_sqnr, - f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}" + SQNR(ref_f, test_comp), + min_sqnr, + f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}", ) - @parameterized.expand(COMMON_DEVICE_DTYPE) def test_int8_dynamic_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( @@ -916,7 +1000,9 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch._inductor.config.patch({"freezing": True}) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after.") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after." + ) def test_int8_weight_only_quant_with_freeze(self, device, dtype): torch._dynamo.reset() self._test_lin_weight_subclass_api_impl( @@ -931,13 +1017,11 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") - for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])): + for test_shape in [(16, 1024, 16)] + ( + [(1, 1024, 256)] if device == "cuda" else [] + ): self._test_lin_weight_subclass_api_impl( - _int4wo_api, - device, - 15, - test_shape=test_shape, - test_dtype=dtype + _int4wo_api, device, 15, test_shape=test_shape, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -982,12 +1066,14 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") layout_list = [] - if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6: + if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6: layout_list.append(Int4CPULayout()) else: for inner_k_tiles in [4, 2]: layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)) - for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])): + for test_shape in [(256, 256, 16)] + ( + [(256, 256, 8)] if device == "cuda" else [] + ): for groupsize in [64, 32]: for layout in layout_list: kwargs = {"groupsize": groupsize, "layout": layout} @@ -1046,7 +1132,9 @@ def test_weight_only_groupwise_quant(self): m = nn.Sequential(nn.Linear(512, 32)) y_ref = m(x) _int8wo_groupwise_api(m) - self.assertEqual(m[0].weight.tensor_impl.int_data.shape, torch.Size([32, 512])) + self.assertEqual( + m[0].weight.tensor_impl.int_data.shape, torch.Size([32, 512]) + ) self.assertEqual(m[0].weight.tensor_impl.scale.shape, torch.Size([32, 16])) y_wo = m(x) sqnr = compute_error(y_ref, y_wo) @@ -1057,7 +1145,11 @@ def test_weight_only_groupwise_embedding_quant(self): m = nn.Embedding(4096, 128) input = torch.randint(0, 4096, (1, 6)) - quantize_(m, int8_weight_only(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding)) + quantize_( + m, + int8_weight_only(group_size=group_size), + filter_fn=lambda x, *args: isinstance(x, nn.Embedding), + ) y_q = m(input) y_ref = m.weight.dequantize()[input] @@ -1065,23 +1157,31 @@ def test_weight_only_groupwise_embedding_quant(self): self.assertGreater(sqnr, 45.0) - @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_weight_only_quant_force_mixed_mm(self, device, dtype): undo_recommended_configs() if device != "cuda": - self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") + self.skipTest( + f"weight_only_quant_force_mixed_mm can't be constructed on {device}" + ) if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config - mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True) - with config.patch({ - "epilogue_fusion": True, - mixed_mm_key: mixed_mm_val, - }): + mixed_mm_key, mixed_mm_val = ( + ("mixed_mm_choice", "triton") + if TORCH_VERSION_AT_LEAST_2_5 + else ("force_mixed_mm", True) + ) + + with config.patch( + { + "epilogue_fusion": True, + mixed_mm_key: mixed_mm_val, + } + ): for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: torch._dynamo.reset() x = torch.randn(*x_shape).to(device).to(dtype) @@ -1101,17 +1201,26 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): def test_weight_only_quant_use_mixed_mm(self, device, dtype): undo_recommended_configs() if device != "cuda": - self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") + self.skipTest( + f"weight_only_quant_force_mixed_mm can't be constructed on {device}" + ) if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") torch.manual_seed(0) from torch._inductor import config - mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True) - with config.patch({ - "epilogue_fusion": False, - mixed_mm_key: mixed_mm_val, - }): + mixed_mm_key, mixed_mm_val = ( + ("mixed_mm_choice", "triton") + if TORCH_VERSION_AT_LEAST_2_5 + else ("force_mixed_mm", True) + ) + + with config.patch( + { + "epilogue_fusion": False, + mixed_mm_key: mixed_mm_val, + } + ): for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: torch._dynamo.reset() x = torch.randn(*x_shape).to(device).to(dtype) @@ -1128,11 +1237,7 @@ class TestSaveLoadMeta(unittest.TestCase): @torch.no_grad() @run_supported_device_dtype def _test_handle_save_load_meta_impl( - self, - api, - test_device, - min_sqnr=35, - test_dtype=torch.bfloat16 + self, api, test_device, min_sqnr=35, test_dtype=torch.bfloat16 ): logger.info(f"TestSaveLoad: {api}, {test_device}, {test_dtype}") m, k, n = 32, 64, 32 @@ -1167,10 +1272,12 @@ def forward(self, x): model_qc = torch.compile(model, mode="max-autotune") ref_q = model_qc(x).detach() - assert SQNR(ref_f, ref_q) > min_sqnr, f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}" + assert ( + SQNR(ref_f, ref_q) > min_sqnr + ), f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}" # load model structure - with torch.device('meta'): + with torch.device("meta"): model = test_model().to(dtype=test_dtype) api(model) @@ -1187,16 +1294,22 @@ def forward(self, x): model_qc = torch.compile(model, mode="max-autotune") test = model_qc(x).detach() - assert SQNR(ref_f, test) > min_sqnr, f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}" + assert ( + SQNR(ref_f, test) > min_sqnr + ), f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}" self.assertTrue(torch.equal(ref_q, test)) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(is_fbcode(), "'PlainAQTTensorImpl' object has no attribute 'int_data'") + @unittest.skipIf( + is_fbcode(), "'PlainAQTTensorImpl' object has no attribute 'int_data'" + ) @torch.no_grad() def test_save_load_dqtensors(self, device, dtype): if device == "cpu": self.skipTest(f"indcutor failed for cpu right now") - self._test_handle_save_load_meta_impl(_int8da_int8w_api, device, test_dtype=dtype) + self._test_handle_save_load_meta_impl( + _int8da_int8w_api, device, test_dtype=dtype + ) @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() @@ -1223,7 +1336,9 @@ def test_save_load_int4woqtensors(self, device, dtype): class TorchCompileUnitTest(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "fullgraph requires torch nightly.") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "fullgraph requires torch nightly." + ) def test_fullgraph(self): lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( @@ -1276,11 +1391,15 @@ class SmoothquantIntegrationTest(unittest.TestCase): def test_non_dynamically_quantizable_linear(self): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") - model = torch.nn.Sequential( - torch.nn.modules.linear.NonDynamicallyQuantizableLinear(32,32), - torch.nn.ReLU() - ).to("cuda").to(torch.bfloat16) - example_input = torch.randn(32,32, device="cuda", dtype=torch.bfloat16) + model = ( + torch.nn.Sequential( + torch.nn.modules.linear.NonDynamicallyQuantizableLinear(32, 32), + torch.nn.ReLU(), + ) + .to("cuda") + .to(torch.bfloat16) + ) + example_input = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16) ref = model(example_input) swap_linear_with_smooth_fq_linear(model) model(ref) @@ -1346,18 +1465,23 @@ def test_on_dummy_distilbert(self): print("sqnr_pt_quant", sqnr_pt_quant) self.assertTrue(sqnr_sq >= 8.0) + class TestAutoQuant(unittest.TestCase): - @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, - [ - (16, 128, 128), - (64, 128, 128), - # (2**15, 128, 128), TODO: Runs out of shared memory on T4 - (16, 128, 256), - # (64, 128, 256), # TODO: Runs out of shared memory on T4 - (16, 256, 128), - (64, 256, 128), - # (256, 256, 128), TODO: Runs out of shared memory on T4 - ])) + @parameterized.expand( + combine_parameters( + COMMON_DEVICE_DTYPE, + [ + (16, 128, 128), + (64, 128, 128), + # (2**15, 128, 128), TODO: Runs out of shared memory on T4 + (16, 128, 256), + # (64, 128, 256), # TODO: Runs out of shared memory on T4 + (16, 256, 128), + (64, 256, 128), + # (256, 256, 128), TODO: Runs out of shared memory on T4 + ], + ) + ) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_one_input(self, device, dtype, m, k, n): undo_recommended_configs() @@ -1375,23 +1499,31 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): torch._dynamo.config.automatic_dynamic_shapes = False example_input = torch.randn(m, k, device=device, dtype=dtype) - model = torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Linear(k,n), - torch.nn.ReLU(), - ).to(device).to(dtype) + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) + ) out = model(example_input) torchao.autoquant(model, set_inductor_config=False) out2 = model(example_input) sqnr = SQNR(out, out2) self.assertTrue(sqnr >= 30) - @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, - [ - (1, 1, 128, 128), - (1, 32, 128, 128), - (32, 32, 128, 128), - ])) + @parameterized.expand( + combine_parameters( + COMMON_DEVICE_DTYPE, + [ + (1, 1, 128, 128), + (1, 32, 128, 128), + (32, 32, 128, 128), + ], + ) + ) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): undo_recommended_configs() @@ -1405,16 +1537,22 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): # This test fails on v0.4.0 and torch 2.4, so skipping for now. if m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") - model = torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Linear(k,n), - torch.nn.ReLU(), - ).to(device).to(dtype) + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) + ) example_input = torch.randn(m1, k, device=device, dtype=dtype) example_input2 = torch.randn(m2, k, device=device, dtype=dtype) out = model(example_input) - mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False) + mod = torchao.autoquant( + torch.compile(model), manual=True, set_inductor_config=False + ) mod(example_input) mod(example_input2) mod.finalize_autoquant() @@ -1428,6 +1566,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): def test_autoquant_mha(self, device, dtype): if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") + class MHAModel(torch.nn.Module): def __init__(self): super().__init__() @@ -1439,17 +1578,16 @@ def forward(self, x): return self.lin(y) mod = MHAModel().to(device).to(dtype) - input = torch.randn(1,1,4096).to(device).to(dtype) - out=mod(*input) + input = torch.randn(1, 1, 4096).to(device).to(dtype) + out = mod(*input) torchao.autoquant(mod, set_inductor_config=False) assert not isinstance(mod.mha.out_proj.weight, AutoQuantizableLinearWeight) assert isinstance(mod.lin.weight, AutoQuantizableLinearWeight) mod(*input) from torchao.quantization.autoquant import AUTOQUANT_CACHE - assert len(AUTOQUANT_CACHE)>0 - + assert len(AUTOQUANT_CACHE) > 0 @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") @@ -1461,16 +1599,22 @@ def test_autoquant_manual(self, device, dtype): if dtype == torch.bfloat16: self.skipTest(f"bfloat16 requires sm80+") m1, m2, k, n = 32, 32, 128, 128 - model = torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Linear(k,n), - torch.nn.ReLU(), - ).to(device).to(dtype) + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) + ) example_input = torch.randn(m1, k, device=device, dtype=dtype) example_input2 = torch.randn(m2, k, device=device, dtype=dtype) out = model(example_input) - mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False) + mod = torchao.autoquant( + torch.compile(model), manual=True, set_inductor_config=False + ) mod(example_input) mod(example_input2) mod.finalize_autoquant() @@ -1486,13 +1630,16 @@ def test_autoquant_manual(self, device, dtype): sqnr2 = SQNR(out, out3) self.assertTrue(sqnr2 >= 30) - - @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, - [ - (1, 1, 128, 128), - (1, 32, 128, 128), - (32, 32, 128, 128), - ])) + @parameterized.expand( + combine_parameters( + COMMON_DEVICE_DTYPE, + [ + (1, 1, 128, 128), + (1, 32, 128, 128), + (32, 32, 128, 128), + ], + ) + ) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): undo_recommended_configs() @@ -1511,7 +1658,7 @@ class NeedsKwargs(torch.nn.Module): def __init__(self): super().__init__() self.rel = torch.nn.ReLU() - self.lin = torch.nn.Linear(k,n) + self.lin = torch.nn.Linear(k, n) def forward(self, x, y): x = self.rel(x) @@ -1532,10 +1679,14 @@ def forward(self, x, y): sqnr = SQNR(out, out2) self.assertTrue(sqnr >= 30) - @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, - [ - (16, 128, 128), - ])) + @parameterized.expand( + combine_parameters( + COMMON_DEVICE_DTYPE, + [ + (16, 128, 128), + ], + ) + ) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_double_access(self, device, dtype, m, k, n): undo_recommended_configs() @@ -1571,11 +1722,15 @@ def forward(self, x): def test_autoquant_min_sqnr(self, device, dtype): m, k, n = 128, 128, 128 example_input = torch.randn(m, k, device=device, dtype=dtype) - model = torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Linear(k,n), - torch.nn.ReLU(), - ).to(device).to(dtype) + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) + ) out = model(example_input) torchao.autoquant(model, min_sqnr=60) out2 = model(example_input) @@ -1585,7 +1740,6 @@ def test_autoquant_min_sqnr(self, device, dtype): self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") @unittest.skip("AOTI tests are failing right now") @@ -1595,12 +1749,18 @@ class TestAOTI(unittest.TestCase): ) def test_aoti(self, api, test_device, test_dtype): if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda": - self.skipTest(f"{api} in {test_device} is not support for aoti compilation yet") + self.skipTest( + f"{api} in {test_device} is not support for aoti compilation yet" + ) - if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): + if ( + test_device == "cuda" + and torch.cuda.is_available() + and test_dtype == torch.bfloat16 + and torch.cuda.get_device_capability() < (8, 0) + ): self.skipTest("Need CUDA and SM80+ available.") - logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}") m, k, n = 32, 64, 32 @@ -1634,17 +1794,30 @@ def forward(self, x): torch._inductor.config.mixed_mm_choice = "triton" example_inputs = (x,) - torch._inductor.aoti_compile_and_package(torch.export.export(model, example_inputs), example_inputs) + torch._inductor.aoti_compile_and_package( + torch.export.export(model, example_inputs, strict=True), example_inputs + ) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") class TestExport(unittest.TestCase): @parameterized.expand( - list(itertools.product(TENSOR_SUBCLASS_APIS + [_int8da_int4w_api], COMMON_DEVICES, COMMON_DTYPES)), + list( + itertools.product( + TENSOR_SUBCLASS_APIS + [_int8da_int4w_api], + COMMON_DEVICES, + COMMON_DTYPES, + ) + ), ) def test_export(self, api, test_device, test_dtype): - if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): + if ( + test_device == "cuda" + and torch.cuda.is_available() + and test_dtype == torch.bfloat16 + and torch.cuda.get_device_capability() < (8, 0) + ): self.skipTest("Need CUDA and SM80+ available.") logger.info(f"TestExport: {api}, {test_device}, {test_dtype}") @@ -1694,8 +1867,6 @@ def forward(self, x): self.assertFalse(torch.ops.aten.narrow.default in targets) - - class TestUtils(unittest.TestCase): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") @@ -1708,21 +1879,28 @@ def test_get_model_size_autoquant(self, device, dtype): if dtype == torch.bfloat16: self.skipTest(f"bfloat16 requires sm80+") m, k, n = 16, 128, 128 - model = torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Linear(k,n), - torch.nn.ReLU(), - ).to(device).to(dtype) + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) + ) example_input = torch.randn(m, k, device=device, dtype=dtype) size = torchao.utils.get_model_size_in_bytes(model) from torchao.quantization.autoquant import ( AQInt8WeightOnlyQuantizedLinearWeight2, ) - qtensor_class_list = ( - AQInt8WeightOnlyQuantizedLinearWeight2, + + qtensor_class_list = (AQInt8WeightOnlyQuantizedLinearWeight2,) + mod = torchao.autoquant( + torch.compile(model), + qtensor_class_list=qtensor_class_list, + set_inductor_config=False, ) - mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False) mod(example_input) size2 = torchao.utils.get_model_size_in_bytes(mod) self.assertTrue(size2 < size) @@ -1737,21 +1915,22 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype): if test_device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"{api} currently does not support {test_device}") k, n = 1024, 1024 - model = torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Linear(k,n), - torch.nn.ReLU(), - ).to(test_device).to(test_dtype) + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n), + torch.nn.ReLU(), + ) + .to(test_device) + .to(test_dtype) + ) size = torchao.utils.get_model_size_in_bytes(model) api(model) size2 = torchao.utils.get_model_size_in_bytes(model) self.assertTrue(size2 < size) - - class TestBenchmarkModel(unittest.TestCase): - class ToyLinearModel(torch.nn.Module): def __init__(self, m=64, n=32, k=64): super().__init__() @@ -1759,7 +1938,11 @@ def __init__(self, m=64, n=32, k=64): self.linear2 = torch.nn.Linear(n, k, bias=False) def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): - return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) def forward(self, x): x = self.linear1(x) @@ -1772,7 +1955,7 @@ def run_benchmark_model(self, device): m = self.ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to(device) m_bf16 = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=dtype, device=device) - m_bf16 = torch.compile(m_bf16, mode='max-autotune') + m_bf16 = torch.compile(m_bf16, mode="max-autotune") num_runs = 1 return benchmark_model(m_bf16, num_runs, example_inputs)