float8 dynamic autoquant

jainapurva committed Sep 26, 2024 · 1 parent 58fe60d · commit f646519
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion torchao/quantization/autoquant.py
@@ -553,7 +553,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         INTERPOLATION_CONSTANT = mode[1]
         w_qtensor = cls.from_float(weight)
         x_vals_float8, x_scales = quantize_activation_per_token_absmax(
-            act_mat.reshape(-1, act_mat.shape[-1])
+            act_mat.reshape(-1, act_mat.shape[-1]), dtype=torch.float8_e4m3fn
         )
         quantized_matmul = (
             lambda x_vals_float8, x_scales, w_vals_float8:
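
For context, a minimal usage sketch of the changed call site follows. Only the function name, the dtype=torch.float8_e4m3fn argument, the reshape pattern, and the (values, scales) return pair come from the diff; the input shape, the bfloat16 dtype, and the dequantization step are illustrative assumptions, not part of the commit.

import torch
from torchao.quantization.utils import quantize_activation_per_token_absmax

# Hypothetical activation of shape [B, N, K] = [2, 16, 64]; shapes are illustrative.
act_mat = torch.randn(2, 16, 64, dtype=torch.bfloat16)

# Flatten to [B*N, K] as the autoquant call site does, then quantize each
# token row to float8_e4m3fn with one absmax-derived scale per row.
x_vals_float8, x_scales = quantize_activation_per_token_absmax(
    act_mat.reshape(-1, act_mat.shape[-1]), dtype=torch.float8_e4m3fn
)

# Assumed symmetric dequantization: upcast the float8 values, then apply the
# per-token scale to recover an approximation of the original activations.
x_approx = x_vals_float8.to(torch.bfloat16) * x_scales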
3 changes: 1 addition & 2 deletions torchao/quantization/utils.py
@@ -139,13 +139,12 @@ def _get_per_token_block_size(x: torch.Tensor) -> List[int]:
 # taken from
 # https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26
 # and slightly modified
-def quantize_activation_per_token_absmax(t):
+def quantize_activation_per_token_absmax(t, dtype=torch.int8):
     # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
     mapping_type = MappingType.SYMMETRIC
     block_size = list(t.shape)
     for i in range(len(block_size) - 1):
         block_size[i] = 1
-    dtype = torch.int8
     eps = 1e-5
     # Note: the original smoothquant does not clamp to qmin/qmax here,
     # but some of the tests with bfloat16 ended up with a flipped sign
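
To make the effect of the new dtype parameter concrete, here is a self-contained sketch of per-token absmax symmetric quantization generalized over the target dtype. This is an assumption-labeled illustration, not the torchao implementation (which routes through MappingType.SYMMETRIC and the affine quantization primitives); the qmin/qmax selection via finfo/iinfo and the rounding rule are my reading of the technique, and the function name is hypothetical.

import torch

def per_token_absmax_quant_sketch(t: torch.Tensor, dtype: torch.dtype = torch.int8):
    # Quantization range for the target dtype: finfo for float8 variants,
    # iinfo for integer types (e.g. 448.0 for float8_e4m3fn, 127 for int8).
    if dtype.is_floating_point:
        qmin, qmax = torch.finfo(dtype).min, torch.finfo(dtype).max
    else:
        qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max

    # One scale per token: reduce the absolute max over the last (feature)
    # dimension, matching the [1, ..., 1, K] block_size built in the diff above.
    amax = t.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    scale = (amax / qmax).clamp(min=1e-5)  # eps=1e-5 as in utils.py

    q = t.to(torch.float32) / scale
    if not dtype.is_floating_point:
        q = q.round()
    # Clamp to qmin/qmax before casting, mirroring the note in the diff about
    # flipped signs when values fall outside the representable range.
    return q.clamp(qmin, qmax).to(dtype), scale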
