Add a register_replacement to fix float8 delayed scaling kernel fusion issues in torchao/float8 (#1469)

Summary:

The original PR was in inductor: pytorch/pytorch#143464

As suggested in the comments there, we moved the patterns to the torchao repo.

Differential Revision: D67758184
y-sq authored and facebook-github-bot committed Jan 2, 2025
1 parent fe5f11b commit 4b00505
Showing 2 changed files with 129 additions and 0 deletions.
126 changes: 126 additions & 0 deletions torchao/float8/inductor_utils.py
@@ -0,0 +1,126 @@
import functools
import inspect
from collections import deque

import torch


E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

def amax_with_scaling_pattern(tensor_x_inp, scale_x, IS_E5M2):
    tensor_x = tensor_x_inp.to(torch.float32) * scale_x
    if IS_E5M2:
        tensor_x = tensor_x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
        tensor_x = tensor_x.to(torch.float8_e5m2)
    else:
        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_x = tensor_x.to(torch.float8_e4m3fn)
    amax = torch.max(torch.abs(tensor_x_inp))
    return (tensor_x, amax)


def amax_with_scaling_tiled_replacement(tensor_x_inp, scale_x, IS_E5M2):
    tensor_x = tensor_x_inp.to(torch.float32) * scale_x
    if IS_E5M2:
        tensor_x = tensor_x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
        tensor_x = tensor_x.to(torch.float8_e5m2)
    else:
        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_x = tensor_x.to(torch.float8_e4m3fn)
    amax_1 = torch.max(torch.abs(tensor_x_inp), dim=-1).values
    amax = torch.max(amax_1)
    return (tensor_x, amax)


# The amax_with_scaling_pattern would also match dynamic scaling cases, which we want to avoid.
# In delayed scaling, `scale_x` comes from the previous iteration rather than from `tensor_x_inp`,
# so we check that `scale_x` does not depend on `tensor_x_inp`.
def fp8_delayed_scaling_extra_check(match):
    scale_x_inputs = deque([match.kwargs["scale_x"]])
    max_num_node_to_check = 50  # Don't traverse too many nodes
    current_num_node = 0
    while len(scale_x_inputs) > 0 and current_num_node < max_num_node_to_check:
        current_node = scale_x_inputs.popleft()
        for n in current_node.all_input_nodes:
            if n == match.kwargs["tensor_x_inp"]:
                return False
            scale_x_inputs.append(n)
        current_num_node += 1
    return True


def partialize_and_update_signature(func, **kwargs):
    """
    Equivalent to functools.partial but also updates the signature on the returned function
    """
    original_sig = inspect.signature(func)
    parameters = original_sig.parameters

    new_parameters = {
        key: value for key, value in parameters.items() if key not in kwargs
    }
    new_sig = inspect.Signature(parameters=list(new_parameters.values()))

    partial_func = functools.partial(func, **kwargs)

    def wrapper(*args, **kwargs):
        return partial_func(*args, **kwargs)

    wrapper.__signature__ = new_sig  # type: ignore[attr-defined]
    wrapper.__name__ = func.__name__

    return wrapper


def register_fp8_delayed_scaling_patterns_inner():
    from torch._inductor.pattern_matcher import fwd_only, register_replacement
    from torch._inductor.fx_passes.post_grad import (
        pass_patterns as post_grad_patterns_all,
    )

    post_grad_patterns = post_grad_patterns_all[1]  # medium priority

    if torch.cuda.is_available():
        for IS_E5M2 in [True, False]:
            # torch.float16 has the same pattern as torch.bfloat16, because both need
            # `tensor_x_inp.to(torch.float32)`; registering it as well would cause errors
            # in `assert pattern_repr not in _seen_patterns`.
            for dtype in [torch.float32, torch.bfloat16]:
                device = "cuda"
                register_replacement(
                    partialize_and_update_signature(
                        amax_with_scaling_pattern, IS_E5M2=IS_E5M2
                    ),
                    partialize_and_update_signature(
                        amax_with_scaling_tiled_replacement, IS_E5M2=IS_E5M2
                    ),
                    [
                        torch.tensor((16, 16), device=device, dtype=dtype),
                        torch.tensor(2.0, device=device, dtype=torch.float32),
                    ],
                    fwd_only,
                    post_grad_patterns,
                    extra_check=fp8_delayed_scaling_extra_check,
                )


def run_once(f):
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.has_run = True
            wrapper.result = f(*args, **kwargs)
            return wrapper.result
        else:
            return wrapper.result

    wrapper.has_run = False
    wrapper.result = None
    return wrapper


@run_once
def register_fp8_delayed_scaling_patterns() -> bool:
    # The fp8 delayed scaling patterns require a fix PR in inductor, https://github.com/pytorch/pytorch/pull/139321
    # The try-except block skips the patterns if the current torch version doesn't include that fix.
    try:
        register_fp8_delayed_scaling_patterns_inner()
    except:
        return False
    return True
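
For reference, a minimal sketch (not part of the commit) of how the registered replacement is expected to be exercised. The function name delayed_scaling_cast, the shapes, and the claim that fusion fires are illustrative assumptions; the rewrite only takes effect on a torch build that includes the inductor fix referenced above.

import torch

from torchao.float8.inductor_utils import register_fp8_delayed_scaling_patterns

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max

# Returns True if the patterns were registered, False if the running torch
# build lacks the required inductor support.
registered = register_fp8_delayed_scaling_patterns()


def delayed_scaling_cast(x, scale):
    # Same shape of computation as amax_with_scaling_pattern above: cast with a
    # scale produced on a previous iteration, and compute the new amax from the
    # unscaled input.
    y = (x.to(torch.float32) * scale).clamp(min=-E4M3_MAX_POS, max=E4M3_MAX_POS)
    y = y.to(torch.float8_e4m3fn)
    amax = torch.max(torch.abs(x))
    return y, amax


if registered and torch.cuda.is_available():
    x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
    scale = torch.tensor(2.0, device="cuda", dtype=torch.float32)
    # After the post-grad pass, the single global max should be rewritten into the
    # tiled two-step max from amax_with_scaling_tiled_replacement, which is what
    # lets inductor fuse the cast and the amax reduction.
    y, amax = torch.compile(delayed_scaling_cast)(x, scale)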
3 changes: 3 additions & 0 deletions torchao/float8/stateful_float8_linear.py
@@ -36,6 +36,9 @@
    WeightWithDynamicFloat8CastTensor,
    WeightWithStaticFloat8CastTensor,
)
from torchao.float8.inductor_utils import register_fp8_delayed_scaling_patterns

register_fp8_delayed_scaling_patterns()


class StatefulFloat8Linear(Float8Linear):
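
As a usage note, a small sketch (illustrative, not part of the commit) of the import-time behavior added here: importing the stateful linear module registers the patterns once, and the @run_once decorator makes any later calls return the cached result instead of re-registering.

# Importing the module runs register_fp8_delayed_scaling_patterns() at import time.
from torchao.float8.stateful_float8_linear import StatefulFloat8Linear  # noqa: F401

from torchao.float8.inductor_utils import register_fp8_delayed_scaling_patterns

# Safe to call again: @run_once returns the cached result, with no duplicate registration.
registered_ok = register_fp8_delayed_scaling_patterns()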
