Skip to content

Commit

Permalink
pytorch/ao/test/quantization
Browse files (browse the repository at this point in the history)
Differential Revision: D67388025

Pull Request resolved: #1443
  • Loading branch information
gmagogsfm authored Dec 19, 2024
1 parent aea2356 commit 9f0dcdc
Showing 1 changed file with 11 additions and 20 deletions.
31 changes: 11 additions & 20 deletions test/quantization/test_quant_api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,38 +13,29 @@
from pathlib import Path

import torch
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase

from torchao import quantize_
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.model import prepare_inputs_for_model, Transformer
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.dtypes import (
AffineQuantizedTensor,
)
from torchao.quantization import (
LinearActivationQuantizedTensor,
)
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.quant_api import (
Quantizer,
TwoStepQuantizer,
_replace_with_custom_fn_if_matches_filter,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
Quantizer,
TwoStepQuantizer,
)
from torchao.quantization.quant_primitives import (
MappingType,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
Expand All @@ -59,7 +50,7 @@


def dynamic_quant(model, example_inputs):
m = torch.export.export(model, example_inputs).module()
m = torch.export.export(model, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_dynamic=True)
)
Expand All @@ -69,7 +60,7 @@ def dynamic_quant(model, example_inputs):


def capture_and_prepare(model, example_inputs):
m = torch.export.export(model, example_inputs)
m = torch.export.export(model, example_inputs, strict=True)
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_dynamic=True)
)
Expand Down Expand Up @@ -666,7 +657,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):

m_unwrapped = unwrap_tensor_subclass(m)

m = torch.export.export(m_unwrapped, example_inputs).module()
m = torch.export.export(m_unwrapped, example_inputs, strict=True).module()
exported_model_res = m(*example_inputs)

self.assertTrue(torch.equal(exported_model_res, ref))
Expand Down

0 comments on commit 9f0dcdc

Please sign in to comment.