From f6465196a2dcf379d1c66cc57a4505b9082f5121 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Wed, 25 Sep 2024 13:52:07 -0700
Subject: [PATCH] float8 dynamic autoquant

---
 torchao/quantization/autoquant.py | 2 +-
 torchao/quantization/utils.py     | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 19780088ee..bde7028f23 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -553,7 +553,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         INTERPOLATION_CONSTANT = mode[1]
         w_qtensor = cls.from_float(weight)
         x_vals_float8, x_scales = quantize_activation_per_token_absmax(
-            act_mat.reshape(-1, act_mat.shape[-1])
+            act_mat.reshape(-1, act_mat.shape[-1]), dtype=torch.float8_e4m3fn
         )
         quantized_matmul = (
             lambda x_vals_float8, x_scales, w_vals_float8:
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 0df6174d0f..a9632aa90f 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -139,13 +139,12 @@ def _get_per_token_block_size(x: torch.Tensor) -> List[int]:
 # taken from
 # https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26
 # and slightly modified
-def quantize_activation_per_token_absmax(t):
+def quantize_activation_per_token_absmax(t, dtype=torch.int8):
     # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
     mapping_type = MappingType.SYMMETRIC
     block_size = list(t.shape)
     for i in range(len(block_size) - 1):
         block_size[i] = 1
-    dtype = torch.int8
     eps = 1e-5
     # Note: the original smoothquant does not clamp to qmin/qmax here,
     # but some of the tests with bfloat16 ended up with a flipped sign
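
Below is a minimal usage sketch of the updated helper, assuming a torchao build that includes this patch. The import path follows the edited file (torchao/quantization/utils.py), the example tensor shape is arbitrary, and the two-tuple return (quantized values, per-token scales) mirrors the call site in autoquant.py above.

import torch
from torchao.quantization.utils import quantize_activation_per_token_absmax

# Arbitrary [B, N, K] activation; per-token quantization keeps the last dim whole.
act_mat = torch.randn(2, 16, 64, dtype=torch.bfloat16)

# Default behavior is unchanged: symmetric per-token int8 quantization.
x_vals_int8, int8_scales = quantize_activation_per_token_absmax(
    act_mat.reshape(-1, act_mat.shape[-1])
)

# New in this patch: request float8 activations, as the float8 dynamic
# autoquant path in autoquant.py now does.
x_vals_float8, x_scales = quantize_activation_per_token_absmax(
    act_mat.reshape(-1, act_mat.shape[-1]), dtype=torch.float8_e4m3fn
)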