diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index d00b96d3bb..751c66f8ad 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -565,8 +565,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
             use_fast_accum=use_fast_accum,
         )
         out_emulated = torch.mm(
-            a_fp8._data.float() / a_fp8._scale,
-            b_fp8._data.float() / b_fp8._scale,
+            a_fp8._data.float() * a_fp8._scale,
+            b_fp8._data.float() * b_fp8._scale,
         ).to(output_dtype)

         if output_dtype != base_dtype:
@@ -841,13 +841,13 @@ def test_fp8_tensor_statistics(self):
         tensor_len = x1_hp.numel()

         # Overflow caused by a too large scaling factor
-        s_overflow = torch.tensor(1e9)
+        s_overflow = 1 / torch.tensor(1e9)
         fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
         (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
         self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))

         # Underflow caused by a too small scaling factor
-        s_underflow = torch.tensor(1e-9)
+        s_underflow = 1 / torch.tensor(1e-9)
         fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
         (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
         self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))
diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py
index 921d50e093..76059921d4 100644
--- a/torchao/float8/float8_ops.py
+++ b/torchao/float8/float8_ops.py
@@ -285,7 +285,7 @@ def float8_mm(aten_op, args, kwargs=None):
         b._linear_mm_config,
     )
     if scaled_mm_config.emulate:
-        return torch.mm(a._data.float() / a._scale, b._data.float() / b._scale).to(
+        return torch.mm(a._data.float() * a._scale, b._data.float() * b._scale).to(
             output_dtype
         )
     tensor_out = addmm_float8_unwrapped(
diff --git a/torchao/float8/float8_python_api.py b/torchao/float8/float8_python_api.py
index 6608dba958..2a5f6e8c1e 100644
--- a/torchao/float8/float8_python_api.py
+++ b/torchao/float8/float8_python_api.py
@@ -34,8 +34,10 @@ def addmm_float8_unwrapped(
     as inputs. This is used to standardize the logic between subclassed and non
     subclassed versions of the linear module.
     """
-    a_inverse_scale = a_scale.reciprocal()
-    b_inverse_scale = b_scale.reciprocal()
+    # a_inverse_scale = a_scale.reciprocal()
+    # b_inverse_scale = b_scale.reciprocal()
+    a_inverse_scale = a_scale
+    b_inverse_scale = b_scale

    if output_dtype == torch.float32 and bias is not None:
        # Bias is not supported by _scaled_mm when output is fp32
diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py
index 20f40330a8..656387a57b 100644
--- a/torchao/float8/float8_tensor.py
+++ b/torchao/float8/float8_tensor.py
@@ -151,7 +151,8 @@ def forward(
         # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
         # upcasted to `float32` to multiply with the scale
         # In order to match numerics between eager and compile, we upcast manually here.
-        tensor_scaled = tensor.to(torch.float32) * scale
+        # tensor_scaled = tensor.to(torch.float32) * scale
+        tensor_scaled = tensor.to(torch.float32) / scale
         bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)

         if isinstance(bits_fp8, DTensor):
@@ -203,7 +204,8 @@ class _FromFloat8ConstrFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, tensor):
-        return tensor._data.to(tensor._orig_dtype) / tensor._scale
+        # return tensor._data.to(tensor._orig_dtype) / tensor._scale
+        return tensor._data.to(tensor._orig_dtype) * tensor._scale

     @staticmethod
     def backward(ctx, g):
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 29319f3814..0a2f13f2c6 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -45,7 +45,8 @@ def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
     # upcast to float64 to ensure same numeric between compile and eager
     amax = amax.to(torch.float64)
     if float8_dtype in FP8_TYPES:
-        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        # res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        res = torch.clamp(amax, min=EPS) / torch.finfo(float8_dtype).max
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py
index 8c60995a86..648e16441e 100644
--- a/torchao/float8/fsdp_utils.py
+++ b/torchao/float8/fsdp_utils.py
@@ -64,14 +64,12 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
-    # keep consistent with float8_utils.amax_to_scale
     # torch.compile and eager show different numerics for 1.0 / float32,
     # upcast to float64 to ensure same numeric between compile and eager
-    origin_dtype = amax_tensor.dtype
     amax_tensor = amax_tensor.to(torch.float64)
-    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
-    if origin_dtype is torch.float16:
-        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    # scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    # TODO(future PR): make the e4m3 dtype customizeable here
+    scale_tensor = amax_tensor / torch.finfo(torch.float8_e4m3fn).max  # Replicate
     local_scale_tensor = scale_tensor.to_local().to(torch.float32)
     for i, float8_linear in enumerate(float8_linears):
         float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
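Context, not part of the patch: a minimal sketch of the scale convention the diff switches to, where the stored scale is amax / float8_max, quantization divides by the scale, and dequantization multiplies by it (previously the scale was float8_max / amax and the operations were reversed). The names below are illustrative only, and the sketch omits the saturation clamp that to_fp8_saturated applies in the real code.

    import torch

    def amax_to_scale_sketch(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
        # new convention: amax maps onto the float8 max, so the scale carries the tensor's units
        return amax.to(torch.float64) / torch.finfo(float8_dtype).max

    x = torch.randn(16, 16)
    scale = amax_to_scale_sketch(x.abs().max(), torch.float8_e4m3fn).to(torch.float32)
    data_fp8 = (x / scale).to(torch.float8_e4m3fn)  # quantize: divide by the scale
    x_roundtrip = data_fp8.float() * scale          # dequantize: multiply by the scale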