diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 2e67f5a4a..bd61320be 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -127,10 +127,7 @@ def to_mx( # For now, calculate the scale in floating point. # TODO(future) audit if there is a need to bit shift exponents instead. - scale_fp = torch.pow( - torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device), - scale_e8m0_unbiased, - ) + scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32) # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the # float32 denormal range. For now, manually adjust the fp scale. This is @@ -177,13 +174,10 @@ def to_mx( def get_fp_scale(scale_e8m0): s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS - # TODO(later): it would be nice if there was a way to do the 2^x operation - # in PyTorch without creating a tensor of twos - two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device) - # pow(two, s_offset) can be out of range of floating point formats. + # TODO(later): handle this for float16 if we decide to support float16 # scales. - s_fp = torch.pow(two, s_offset) + s_fp = torch.exp2(s_offset) # If a block exponent was 255, set values of that block to NaN s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))