[wip] float8 training: invert the meaning of scale #1351

Open · wants to merge 1 commit into base: main
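To summarize what "invert the meaning of scale" means in the hunks below: today the stored scale is the multiplier applied when quantizing (fp8 = hp * scale, hp ≈ fp8 / scale); this PR flips it so the stored scale is the dequantization multiplier (fp8 = hp / scale, hp ≈ fp8 * scale), i.e. the reciprocal of what is stored today. A minimal standalone sketch of the two conventions (the variable names below are illustrative, not torchao APIs):

```python
import torch

# Standalone illustration of the convention change; nothing below is a torchao API.
FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

x = torch.randn(4, 4, dtype=torch.float32)
amax = x.abs().max()

# Old convention: the stored scale is the quantization multiplier.
scale_old = FP8_MAX / amax
x_fp8_old = (x * scale_old).to(FP8_DTYPE)
x_dq_old = x_fp8_old.float() / scale_old

# New convention (this PR): the stored scale is the dequantization multiplier,
# i.e. the reciprocal of the old scale.
scale_new = amax / FP8_MAX
x_fp8_new = (x / scale_new).to(FP8_DTYPE)
x_dq_new = x_fp8_new.float() * scale_new

# Same data, same reconstruction; only the stored metadata is inverted.
print((x_dq_old - x_dq_new).abs().max())  # ~0, up to fp8 rounding of borderline values
```

Both conventions describe the same fp8 data; only the stored metadata changes, which is why every `/ scale` in the diff becomes `* scale` and vice versa.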
8 changes: 4 additions & 4 deletions test/float8/test_base.py
@@ -565,8 +565,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
            use_fast_accum=use_fast_accum,
        )
        out_emulated = torch.mm(
-            a_fp8._data.float() / a_fp8._scale,
-            b_fp8._data.float() / b_fp8._scale,
+            a_fp8._data.float() * a_fp8._scale,
+            b_fp8._data.float() * b_fp8._scale,
        ).to(output_dtype)

        if output_dtype != base_dtype:
@@ -841,13 +841,13 @@ def test_fp8_tensor_statistics(self):
        tensor_len = x1_hp.numel()

        # Overflow caused by a too large scaling factor
-        s_overflow = torch.tensor(1e9)
+        s_overflow = 1 / torch.tensor(1e9)
        fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
        (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
        self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))

        # Underflow caused by a too small scaling factor
-        s_underflow = torch.tensor(1e-9)
+        s_underflow = 1 / torch.tensor(1e-9)
        fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
        (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
        self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))
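Why the test now builds these scales as reciprocals: under the inverted convention, quantization divides by the stored scale, so a very small stored scale (1 / 1e9) amplifies every element past the fp8 max and saturates it (the "overflow" case), while a very large stored scale (1 / 1e-9) flushes every element to zero (the "underflow" case). A self-contained sketch of that arithmetic, with a simplified saturating cast standing in for torchao's to_fp8_saturated:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def quantize_new_convention(hp: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # New convention: divide by the stored scale, then saturate to the fp8 range
    # (a simplified stand-in for torchao's to_fp8_saturated).
    scaled = hp.float() / scale
    return scaled.clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

x = torch.linspace(-2.0, 2.0, 16)  # deterministic values, none of them exactly zero

# Tiny stored scale -> every element is amplified past FP8_MAX -> saturates ("overflow").
fp8_overflow = quantize_new_convention(x, torch.tensor(1e-9))
print((fp8_overflow.float().abs() == FP8_MAX).all())  # tensor(True)

# Huge stored scale -> every element shrinks below the smallest subnormal -> zero ("underflow").
fp8_underflow = quantize_new_convention(x, torch.tensor(1e9))
print((fp8_underflow.float() == 0).all())  # tensor(True)
```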
2 changes: 1 addition & 1 deletion torchao/float8/float8_ops.py
@@ -285,7 +285,7 @@ def float8_mm(aten_op, args, kwargs=None):
        b._linear_mm_config,
    )
    if scaled_mm_config.emulate:
-        return torch.mm(a._data.float() / a._scale, b._data.float() / b._scale).to(
+        return torch.mm(a._data.float() * a._scale, b._data.float() * b._scale).to(
            output_dtype
        )
    tensor_out = addmm_float8_unwrapped(
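This is the same change as the emulated path in the test above: with the inverted convention, dequantizing an operand is a multiply by its stored scale, so the emulated matmul becomes mm(a_data * a_scale, b_data * b_scale). A standalone sketch of that emulation (quantize and emulated_scaled_mm below are illustrative helpers, not torchao functions):

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def quantize(hp: torch.Tensor):
    # New convention: the returned scale is the dequantization multiplier.
    scale = (hp.abs().max() / FP8_MAX).clamp(min=1e-12)
    data = (hp / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return data, scale

def emulated_scaled_mm(a_data, a_scale, b_data, b_scale, out_dtype=torch.float32):
    # Mirrors the emulate branch in the diff: dequantize with a multiply, then mm.
    return torch.mm(a_data.float() * a_scale, b_data.float() * b_scale).to(out_dtype)

a, b = torch.randn(8, 16), torch.randn(16, 4)
a_data, a_scale = quantize(a)
b_data, b_scale = quantize(b)

out = emulated_scaled_mm(a_data, a_scale, b_data, b_scale)
print((out - a @ b).abs().max())  # small, limited by fp8 precision
```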
6 changes: 4 additions & 2 deletions torchao/float8/float8_python_api.py
@@ -34,8 +34,10 @@ def addmm_float8_unwrapped(
    as inputs. This is used to standardize the logic between subclassed and non subclassed
    versions of the linear module.
    """
-    a_inverse_scale = a_scale.reciprocal()
-    b_inverse_scale = b_scale.reciprocal()
+    # a_inverse_scale = a_scale.reciprocal()
+    # b_inverse_scale = b_scale.reciprocal()
+    a_inverse_scale = a_scale
+    b_inverse_scale = b_scale

[comment by the PR author on the commented-out reciprocal lines] not having to do this is what reduces the # of kernels for delayed scaling

    if output_dtype == torch.float32 and bias is not None:
        # Bias is not supported by _scaled_mm when output is fp32
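The point of the author's comment: scaled-matmul entry points such as torch._scaled_mm take dequantization multipliers for scale_a / scale_b, so the old code had to launch a reciprocal kernel per operand before every matmul. With the inverted convention the stored scale already is that multiplier and can be forwarded untouched. A sketch of the difference, using a plain-PyTorch stand-in (scaled_mm_like is made up for this example) so it runs without fp8 hardware:

```python
import torch

def scaled_mm_like(a_data, b_data, scale_a, scale_b, out_dtype=torch.bfloat16):
    # Plain-PyTorch stand-in for a scaled matmul kernel: as with torch._scaled_mm,
    # scale_a / scale_b are dequantization multipliers applied to the operands.
    return torch.mm(a_data.float() * scale_a, b_data.float() * scale_b).to(out_dtype)

a_data = torch.randint(-64, 64, (8, 16)).float()
b_data = torch.randint(-64, 64, (16, 4)).float()

# Old convention: the stored scale was the quantization multiplier, so every matmul
# needed an extra .reciprocal() kernel per operand before the call.
a_scale_old, b_scale_old = torch.tensor(128.0), torch.tensor(64.0)
out_old = scaled_mm_like(a_data, b_data, a_scale_old.reciprocal(), b_scale_old.reciprocal())

# New convention (this PR): the stored scale already is the dequant multiplier,
# so it is forwarded unchanged and the reciprocal kernels disappear.
a_scale_new, b_scale_new = 1 / a_scale_old, 1 / b_scale_old
out_new = scaled_mm_like(a_data, b_data, a_scale_new, b_scale_new)

torch.testing.assert_close(out_old, out_new)
```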
6 changes: 4 additions & 2 deletions torchao/float8/float8_tensor.py
@@ -151,7 +151,8 @@ def forward(
        # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
        # upcasted to `float32` to multiply with the scale
        # In order to match numerics between eager and compile, we upcast manually here.
-        tensor_scaled = tensor.to(torch.float32) * scale
+        # tensor_scaled = tensor.to(torch.float32) * scale
+        tensor_scaled = tensor.to(torch.float32) / scale
        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)

        if isinstance(bits_fp8, DTensor):
@@ -203,7 +204,8 @@ class _FromFloat8ConstrFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor):
-        return tensor._data.to(tensor._orig_dtype) / tensor._scale
+        # return tensor._data.to(tensor._orig_dtype) / tensor._scale
+        return tensor._data.to(tensor._orig_dtype) * tensor._scale

    @staticmethod
    def backward(ctx, g):
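Taken together, the two hunks above make Float8Tensor construction divide by the scale and conversion back to the original dtype multiply by it. A standalone round-trip sketch under the new convention (TinyFloat8 and to_fp8_saturated_sketch are illustrative stand-ins, not torchao classes):

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def to_fp8_saturated_sketch(x: torch.Tensor, float8_dtype=FP8_DTYPE) -> torch.Tensor:
    # Simplified stand-in for torchao's to_fp8_saturated: clamp so out-of-range
    # values saturate instead of becoming inf/nan.
    return x.clamp(-FP8_MAX, FP8_MAX).to(float8_dtype)

class TinyFloat8:
    """Illustrative container mirroring the inverted-scale Float8Tensor convention."""

    def __init__(self, hp_tensor: torch.Tensor, scale: torch.Tensor):
        self._orig_dtype = hp_tensor.dtype
        self._scale = scale
        # _ToFloat8ConstrFunc.forward under the new convention: divide by the scale.
        self._data = to_fp8_saturated_sketch(hp_tensor.to(torch.float32) / scale)

    def to_original_precision(self) -> torch.Tensor:
        # _FromFloat8ConstrFunc.forward under the new convention: multiply by the scale.
        return self._data.to(self._orig_dtype) * self._scale

x = torch.randn(4, 4, dtype=torch.bfloat16)
scale = (x.float().abs().max() / FP8_MAX).clamp(min=1e-12)
x_fp8 = TinyFloat8(x, scale)
print((x_fp8.to_original_precision() - x).abs().max())  # small fp8 + bf16 rounding error
```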
3 changes: 2 additions & 1 deletion torchao/float8/float8_utils.py
@@ -45,7 +45,8 @@ def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
    # upcast to float64 to ensure same numeric between compile and eager
    amax = amax.to(torch.float64)
    if float8_dtype in FP8_TYPES:
-        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        # res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        res = torch.clamp(amax, min=EPS) / torch.finfo(float8_dtype).max
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

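A worked example of the new amax_to_scale: for e4m3fn (max 448.0) and amax = 3.0, the old code returned 448 / 3 ≈ 149.33 (a quantization multiplier) while the new code returns 3 / 448 ≈ 0.0067 (a dequantization multiplier); the two are reciprocals and map amax to the same fp8 value. The EPS value below is assumed to match torchao's 1e-12:

```python
import torch

EPS = 1e-12  # assumed to match torchao.float8.float8_utils.EPS
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

amax = torch.tensor(3.0, dtype=torch.float64)

# Old: scale is the quantization multiplier.
scale_old = FP8_MAX / torch.clamp(amax, min=EPS)   # ~149.33
# New (this PR): scale is the dequantization multiplier.
scale_new = torch.clamp(amax, min=EPS) / FP8_MAX   # ~0.006696

print(scale_old, scale_new, scale_old.reciprocal())  # scale_new == 1 / scale_old
print(amax * scale_old)  # 448.0: amax maps to the fp8 max under the old convention (hp * scale)
print(amax / scale_new)  # 448.0: and under the new convention (hp / scale)
```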
8 changes: 3 additions & 5 deletions torchao/float8/fsdp_utils.py
@@ -64,14 +64,12 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
    # clamp is dispatched through DTensor
    # it will issue a single all-reduce
    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
    # keep consistent with float8_utils.amax_to_scale
    # torch.compile and eager show different numerics for 1.0 / float32,
    # upcast to float64 to ensure same numeric between compile and eager
-    origin_dtype = amax_tensor.dtype
    amax_tensor = amax_tensor.to(torch.float64)
-    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
-    if origin_dtype is torch.float16:
-        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    # scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    # TODO(future PR): make the e4m3 dtype customizeable here
+    scale_tensor = amax_tensor / torch.finfo(torch.float8_e4m3fn).max  # Replicate
    local_scale_tensor = scale_tensor.to_local().to(torch.float32)
    for i, float8_linear in enumerate(float8_linears):
        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
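The fsdp_utils hunk applies the same inversion to the scales precomputed for FSDP all-gather, where the per-parameter amaxes live in one stacked DTensor. A single-device analogue of the new computation (no DTensor or all-reduce; the weight list below is made up for illustration):

```python
import torch

EPS = 1e-12  # assumed to match torchao's EPS
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

# Stand-ins for the float8 linear weights that precompute_float8_dynamic_scale_for_fsdp
# collects; in the real code these are DTensor shards.
weights = [torch.randn(64, 64), torch.randn(128, 64), torch.randn(64, 32)]

# One amax per parameter, stacked so the clamp/divide below is a single op
# (in the real code the stacked tensor is a DTensor and the clamp issues one all-reduce).
amax_tensor = torch.stack([w.abs().max() for w in weights])
amax_tensor = torch.clamp(amax_tensor, EPS)

# Upcast to float64 so eager and torch.compile agree numerically, as in the original code.
amax_tensor = amax_tensor.to(torch.float64)

# Inverted convention: scale = amax / fp8_max (the dequant multiplier), not fp8_max / amax.
scale_tensor = amax_tensor / FP8_MAX
precomputed_scales = scale_tensor.to(torch.float32)

for w, s in zip(weights, precomputed_scales):
    print(tuple(w.shape), float(s))
```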