
Commit

Fix test
jerryzh168 committed May 15, 2024
1 parent 94a058c commit e9fc492
Showing 2 changed files with 16 additions and 15 deletions.
16 changes: 8 additions & 8 deletions test/quantization/test_quant_api.py
@@ -397,7 +397,7 @@ def test_eval_wrapper(self):
     def test_quantized_tensor_subclass_8da4w(self):
         from torchao.quantization.subclass import (
             AffineQuantizedTensor,
-            LinearActAffineQuantizedTensor,
+            LinearActQuantizedTensor,
         )
         from torchao.quantization.quant_primitives import MappingType
         import copy
@@ -428,15 +428,15 @@ def get_per_token_block_size(x):
         def dynamic_quant(linear):
             # note: order is important
             linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)

         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
         dynamic_quant(m.linear1)
         dynamic_quant(m.linear2)
-        assert isinstance(m.linear1.weight, LinearActAffineQuantizedTensor)
-        assert isinstance(m.linear2.weight, LinearActAffineQuantizedTensor)
+        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)

         # reference
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -538,7 +538,7 @@ def to_quantized(weight):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
         from torchao.quantization.subclass import AffineQuantizedTensor
-        from torchao.quantization.subclass import LinearActAffineQuantizedTensor
+        from torchao.quantization.subclass import LinearActQuantizedTensor
         from torchao.quantization.quant_primitives import MappingType
         from torchao.quantization.quant_primitives import ZeroPointDomain
         import copy
@@ -573,12 +573,12 @@ def get_per_token_block_size(x):
         def dynamic_quant(linear):
             # note: order is important
             linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)

         dynamic_quant(m.linear1)
         dynamic_quant(m.linear2)
-        assert isinstance(m.linear1.weight, LinearActAffineQuantizedTensor)
-        assert isinstance(m.linear2.weight, LinearActAffineQuantizedTensor)
+        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
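For context, the pattern these tests exercise composes the two subclasses touched by this commit: the weight is first quantized with AffineQuantizedTensor, then wrapped in LinearActQuantizedTensor so the activation is quantized on the fly when linear() runs. A minimal standalone sketch, assuming the import paths shown in the diff above; the concrete layer sizes and quantization parameters here are illustrative, not the tests' exact values:

import torch
from torchao.quantization.subclass import (
    AffineQuantizedTensor,
    LinearActQuantizedTensor,
)
from torchao.quantization.quant_primitives import MappingType

linear = torch.nn.Linear(64, 32)  # stand-in for ToyLinearModel's layers

def get_per_token_block_size(x):
    # one quantization group per token: every dim is 1 except the last
    return (1,) * (x.dim() - 1) + (x.shape[-1],)

def input_quant_func(x):
    # dynamic per-token activation quantization (illustrative parameters)
    return AffineQuantizedTensor.from_float(
        x, MappingType.ASYMMETRIC, get_per_token_block_size(x), torch.int8
    )

# note: order is important -- quantize the weight first, then wrap it
linear.weight = torch.nn.Parameter(
    AffineQuantizedTensor.from_float(
        linear.weight, MappingType.SYMMETRIC, (1, 32), torch.int8,
        -8, 7, torch.finfo(torch.float32).eps,  # quant_min, quant_max, eps
    ),
    requires_grad=False,
)
linear.weight = torch.nn.Parameter(
    LinearActQuantizedTensor.from_float(linear.weight, input_quant_func),
    requires_grad=False,
)

assert isinstance(linear.weight, LinearActQuantizedTensor)
assert isinstance(linear.weight.original_weight_tensor, AffineQuantizedTensor)

Reversing the two steps would hand an already-wrapped tensor to AffineQuantizedTensor.from_float, which is presumably why the test comments stress that the order matters.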
15 changes: 8 additions & 7 deletions torchao/quantization/subclass.py
@@ -830,7 +830,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                     input_tensor = input_tensor.dequantize()

                 # weight only quantization
-
                 # TODO: enable cpu and mps path as well
                 # TODO: make sure weight dimension matches the expectation of the int4mm kernel
                 # TODO: move this to TinygemmAffineQuantizedTensor
@@ -862,6 +861,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                     # TODO: enable mps path as well
                     # per channel int8 weight only quantized mm
                     return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+                else:
+                    weight_tensor = weight_qtensor.dequantize()
+                    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
             else:
                 if isinstance(input_tensor, AffineQuantizedTensor):
                     input_tensor = input_tensor.dequantize()
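The three added lines give the weight-only path a fallback: when no fused kernel applies, dequantize the weight and run the regular float linear, so unsupported layouts stay correct instead of erroring out. The shape of that dispatch, isolated as a sketch (the supports_fused_kernel predicate is a hypothetical stand-in for the chain of dtype/layout checks in the surrounding code):

import torch
import torch.nn.functional as F

def supports_fused_kernel(input_tensor, weight_qtensor):
    # hypothetical stand-in for the real device/dtype/layout checks
    return False

def quantized_linear(input_tensor, weight_qtensor, bias=None):
    # try a fused quantized kernel first, otherwise fall back to
    # dequantize + float linear (the pattern added by this commit)
    if supports_fused_kernel(input_tensor, weight_qtensor):
        return torch.ops.aten._weight_int8pack_mm(
            input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale
        )
    else:
        weight_tensor = weight_qtensor.dequantize()
        return F.linear(input_tensor, weight_tensor, bias)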
@@ -985,10 +987,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         )


-class LinearActAffineQuantizedTensor(torch.Tensor):
+class LinearActQuantizedTensor(torch.Tensor):
     """
-    Activation quantization with AffineQuantizedTensor
-    Applies activation affine quantization for linear operator
+    Applies activation quantization for linear operator
     """
     def __new__(
         cls,
@@ -1045,7 +1046,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args) > 2 else None,
             )
-            if isinstance(weight_tensor, LinearActAffineQuantizedTensor):
+            if isinstance(weight_tensor, LinearActQuantizedTensor):
                 input_quant_func = weight_tensor.input_quant_func
                 original_weight_tensor = weight_tensor.original_weight_tensor
                 aqt = input_quant_func(input_tensor)
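This branch is the core of the renamed wrapper: quantize the incoming activation with the stored input_quant_func, then re-issue the linear against the inner weight tensor. A simplified sketch of that control flow (not the verbatim implementation; the intercepted function is torch.nn.functional.linear):

import torch
import torch.nn.functional as F
from torchao.quantization.subclass import LinearActQuantizedTensor

def handle_linear(input_tensor, weight_tensor, bias=None):
    # simplified sketch of the LinearActQuantizedTensor __torch_function__ branch
    if isinstance(weight_tensor, LinearActQuantizedTensor):
        # 1. quantize the activation on the fly
        aqt = weight_tensor.input_quant_func(input_tensor)
        # 2. re-dispatch the linear with the inner weight (typically an
        #    AffineQuantizedTensor), which performs the quantized matmul
        return F.linear(aqt, weight_tensor.original_weight_tensor, bias)
    raise NotImplementedError("sketch only handles wrapped weights")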
@@ -1054,7 +1055,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
             with torch._C.DisableTorchFunctionSubclass():
                 return func(*args, **kwargs)
         except:
-            print(f"ERR: LinearActAffineQuantizedTensor subclass doesn't implement {func}")
+            print(f"ERR: LinearActQuantizedTensor subclass doesn't implement {func}")

     def _apply_fn_to_data(self, fn):
         return self.__class__(
@@ -1103,5 +1104,5 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             )

         raise NotImplementedError(
-            f"LinearActAffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
+            f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
         )
