Skip to content

Commit

Permalink
test fixes
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
HDCharles committed Sep 4, 2024
1 parent 03d01ad commit 8376847
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@
)
from torchao.quantization.autoquant import (
AQInt8DynamicallyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight2,
AQWeightOnlyQuantizedLinearWeight3,
AQInt8WeightOnlyQuantizedLinearWeight,
AQInt8WeightOnlyQuantizedLinearWeight2,
AQInt8WeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,

)
Expand Down Expand Up @@ -727,21 +727,21 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
)
def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
AQInt8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
AQInt8WeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
Expand Down Expand Up @@ -1498,10 +1498,10 @@ def test_get_model_size_autoquant(self, device, dtype):
size = torchao.utils.get_model_size_in_bytes(model)

from torchao.quantization.autoquant import (
AQWeightOnlyQuantizedLinearWeight2,
AQInt8WeightOnlyQuantizedLinearWeight2,
)
qtensor_class_list = (
AQWeightOnlyQuantizedLinearWeight2,
AQInt8WeightOnlyQuantizedLinearWeight2,
)
mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False)
mod(example_input)
Expand Down

0 comments on commit 8376847

Please sign in to comment.