pytorch/ao/torchao/experimental/tests #1441

Merged: 1 commit, Dec 19, 2024
torchao/experimental/tests/test_embedding_xbit_quantizer.py (17 changes: 11 additions & 6 deletions)
@@ -7,8 +7,8 @@
 import copy

 import glob
-import subprocess
 import os
+import subprocess

 import sys
 import tempfile
@@ -18,10 +18,11 @@

 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
 from torchao.experimental.quant_api import (
-    IntxWeightEmbeddingQuantizer,
     _IntxWeightQuantizedEmbeddingFallback,
+    IntxWeightEmbeddingQuantizer,
 )

+
 def cmake_build_torchao_ops(temp_build_dir):
     from distutils.sysconfig import get_python_lib
@@ -62,7 +63,9 @@ def test_accuracy(self):
         group_size = 128
         embedding_dim = 4096
         num_embeddings = 131
-        model = torch.nn.Sequential(*[torch.nn.Embedding(num_embeddings, embedding_dim)])
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
         indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)

         for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
@@ -88,10 +91,11 @@ def test_export_compile_aoti(self):
         group_size = 128
         embedding_dim = 4096
         num_embeddings = 131
-        model = torch.nn.Sequential(*[torch.nn.Embedding(num_embeddings, embedding_dim)])
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
         indices = torch.randint(0, num_embeddings, (42,), dtype=torch.int32)

-
         print("Quantizing model")
         quantizer = IntxWeightEmbeddingQuantizer(
             device="cpu",
@@ -102,7 +106,7 @@ def test_export_compile_aoti(self):
         quantized_model = quantizer.quantize(model)

         print("Exporting quantized model")
-        exported = torch.export.export(quantized_model, (indices,))
+        exported = torch.export.export(quantized_model, (indices,), strict=True)

         print("Compiling quantized model")
         quantized_model_compiled = torch.compile(quantized_model)
@@ -121,5 +125,6 @@ def test_export_compile_aoti(self):
             fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu")
             fn(indices)

+
 if __name__ == "__main__":
     unittest.main()
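Note: the substantive change in this file is the explicit strict=True passed to torch.export.export; the other edits are import sorting and line wrapping. A minimal, self-contained sketch of that export pattern, using a plain nn.Embedding as a stand-in for the quantized model (not the test's exact code):

import torch

# Stand-in model and inputs mirroring the shapes used in the test.
num_embeddings, embedding_dim = 131, 4096
model = torch.nn.Sequential(torch.nn.Embedding(num_embeddings, embedding_dim))
indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)

# Passing strict=True makes the TorchDynamo-based (strict) capture mode
# explicit instead of relying on torch.export's default.
exported = torch.export.export(model, (indices,), strict=True)

# The exported program can be re-materialized as a module and run eagerly.
out = exported.module()(indices)
print(out.shape)  # torch.Size([7, 4096])

Spelling out the capture mode keeps the test's behavior stable even if the torch.export default changes in a later release.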
(second test file; filename not captured in this view)

@@ -129,7 +129,7 @@ def test_export_compile_aoti(self):
         quantized_model = quantizer.quantize(model)

         print("Exporting quantized model")
-        exported = torch.export.export(quantized_model, (activations,))
+        exported = torch.export.export(quantized_model, (activations,), strict=True)

         print("Compiling quantized model")
         quantized_model_compiled = torch.compile(quantized_model)
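The same strict=True change appears here, and this test_export_compile_aoti variant also ends by ahead-of-time compiling the exported model and reloading it with torch._export.aot_load (visible in the previous file's hunk). A rough sketch of that AOTI round trip; the compile step is not shown in the diff, so the torch._export.aot_compile call and its output_path option below are assumptions rather than the test's actual code:

import tempfile

import torch

# Hypothetical stand-in model and inputs, not the quantized model under test.
model = torch.nn.Sequential(torch.nn.Embedding(131, 4096))
indices = torch.randint(0, 131, (42,), dtype=torch.int32)

with tempfile.TemporaryDirectory() as tmpdirname:
    # Assumed compile step: produce the model.so that aot_load reads back.
    # The exact call and options in the test may differ.
    so_path = torch._export.aot_compile(
        model,
        (indices,),
        options={"aot_inductor.output_path": f"{tmpdirname}/model.so"},
    )
    # This part matches the diff: load the compiled artifact on CPU and
    # run it on the same indices.
    fn = torch._export.aot_load(so_path, "cpu")
    fn(indices)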
(third test file; filename not captured in this view)

@@ -18,13 +18,14 @@

 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))

-from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
-from torchao.quantization.quant_api import quantize_
-
-from torchao.utils import unwrap_tensor_subclass
+from torchao.experimental.quant_api import (
+    _Int8DynActIntxWeightQuantizedLinearFallback,
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.quant_api import quantize_
+
+from torchao.utils import unwrap_tensor_subclass


 def cmake_build_torchao_ops(temp_build_dir):
     from distutils.sysconfig import get_python_lib
@@ -98,7 +99,7 @@ def test_accuracy(self):
         result = quantized_model(activations)
         expected_result = quantized_model_reference(activations)

-        #TODO: remove expected_result2 checks when we deprecate non-subclass API
+        # TODO: remove expected_result2 checks when we deprecate non-subclass API
         reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
         reference_impl.quantize_and_pack_weights(
             model[0].weight, nbit, group_size, has_weight_zeros
@@ -115,8 +116,12 @@
             self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
             if not torch.allclose(actual_val, expected_val):
                 num_mismatch_at_low_tol += 1

-            self.assertTrue(torch.allclose(expected_val, expected_val2, atol=1e-2, rtol=1e-1))
+            self.assertTrue(
+                torch.allclose(
+                    expected_val, expected_val2, atol=1e-2, rtol=1e-1
+                )
+            )
             if not torch.allclose(expected_val, expected_val2):
                 num_mismatch_at_low_tol2 += 1
@@ -156,8 +161,8 @@ def test_export_compile_aoti(self):
         unwrap_tensor_subclass(model)

         print("Exporting quantized model")
-        exported = torch.export.export(model, (activations,))
+        exported = torch.export.export(model, (activations,), strict=True)

         print("Compiling quantized model")
         compiled = torch.compile(unwrapped_model)
         with torch.no_grad():
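For context on the tolerances reformatted above: the subclass result is compared against the reference at a tight atol=1e-6, while the deprecated non-subclass path (expected_val2) only has to agree within atol=1e-2 / rtol=1e-1. A small illustration with made-up tensors, not values from the test:

import torch

# Hypothetical outputs: expected_val tracks actual_val very closely, while
# expected_val2 (the non-subclass fallback) is allowed to drift further.
actual_val = torch.randn(17)
expected_val = actual_val + 1e-7 * torch.randn(17)
expected_val2 = expected_val + 1e-3 * torch.randn(17)

# Tight check for the subclass path, looser check for the fallback path.
assert torch.allclose(actual_val, expected_val, atol=1e-6)
assert torch.allclose(expected_val, expected_val2, atol=1e-2, rtol=1e-1)

torch.allclose checks |a - b| <= atol + rtol * |b| elementwise, so the second bound tolerates a small absolute offset plus roughly ten percent relative error, which is why the test separately counts mismatches at the default (much tighter) tolerance.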