diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py index 72a6b76fa..eb6663c1e 100644 --- a/torchao/experimental/ops/mps/test/test_quantizer.py +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -4,18 +4,17 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional import copy import itertools import os import sys +import unittest +from typing import Optional import torch -import unittest from parameterized import parameterized -from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer -from torchao.experimental.quant_api import _quantize +from torchao.experimental.quant_api import _quantize, UIntxWeightOnlyLinearQuantizer libname = "libtorchao_ops_mps_aten.dylib" libpath = os.path.abspath( @@ -80,7 +79,7 @@ def test_export(self, nbit): activations = torch.randn(m, k0, dtype=torch.float32, device="mps") quantized_model = self._quantize_model(model, torch.float32, nbit, group_size) - exported = torch.export.export(quantized_model, (activations,)) + exported = torch.export.export(quantized_model, (activations,), strict=True) for node in exported.graph.nodes: if node.op == "call_function":