From 4b2dfafd05e1bfa4e8023b894b35f02aed6e5342 Mon Sep 17 00:00:00 2001
From: Shuqi Yang
Date: Thu, 2 Jan 2025 01:39:56 -0800
Subject: [PATCH] Add a register_replacement to fix float8 delayed scaling
 kernel fusion issues in torchao/float8 (#1469)

Summary:
The original PR was in Inductor: https://github.com/pytorch/pytorch/pull/143464
As suggested in the comments there, we moved the patterns to the torchao repo.

Differential Revision: D67758184
---
 torchao/float8/inductor_utils.py         | 126 +++++++++++++++++++++++
 torchao/float8/stateful_float8_linear.py |   3 +
 2 files changed, 129 insertions(+)
 create mode 100644 torchao/float8/inductor_utils.py

diff --git a/torchao/float8/inductor_utils.py b/torchao/float8/inductor_utils.py
new file mode 100644
index 000000000..2ade4ca47
--- /dev/null
+++ b/torchao/float8/inductor_utils.py
@@ -0,0 +1,126 @@
+import functools
+import inspect
+from collections import deque
+
+import torch
+
+
+E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+
+def amax_with_scaling_pattern(tensor_x_inp, scale_x, IS_E5M2):
+    tensor_x = tensor_x_inp.to(torch.float32) * scale_x
+    if IS_E5M2:
+        tensor_x = tensor_x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+        tensor_x = tensor_x.to(torch.float8_e5m2)
+    else:
+        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+        tensor_x = tensor_x.to(torch.float8_e4m3fn)
+    amax = torch.max(torch.abs(tensor_x_inp))
+    return (tensor_x, amax)
+
+
+def amax_with_scaling_tiled_replacement(tensor_x_inp, scale_x, IS_E5M2):
+    tensor_x = tensor_x_inp.to(torch.float32) * scale_x
+    if IS_E5M2:
+        tensor_x = tensor_x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+        tensor_x = tensor_x.to(torch.float8_e5m2)
+    else:
+        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+        tensor_x = tensor_x.to(torch.float8_e4m3fn)
+    amax_1 = torch.max(torch.abs(tensor_x_inp), dim=-1).values
+    amax = torch.max(amax_1)
+    return (tensor_x, amax)
+
+
+# The amax_with_scaling_pattern would also match dynamic scaling cases, which we want to avoid:
+# `scale_x` of delayed scaling comes from the previous iteration, not from `tensor_x_inp`.
+# We therefore check that `scale_x` is not a dependency of `tensor_x_inp`.
+def fp8_delayed_scaling_extra_check(match):
+    scale_x_inputs = deque([match.kwargs["scale_x"]])
+    max_num_node_to_check = 50  # Don't traverse too many nodes
+    current_num_node = 0
+    while len(scale_x_inputs) > 0 and current_num_node < max_num_node_to_check:
+        current_node = scale_x_inputs.popleft()
+        for n in current_node.all_input_nodes:
+            if n == match.kwargs["tensor_x_inp"]:
+                return False
+            scale_x_inputs.append(n)
+        current_num_node += 1
+    return True
+
+
+def partialize_and_update_signature(func, **kwargs):
+    """
+    Equivalent to functools.partial, but also updates the signature of the returned function
+    """
+    original_sig = inspect.signature(func)
+    parameters = original_sig.parameters
+
+    new_parameters = {
+        key: value for key, value in parameters.items() if key not in kwargs
+    }
+    new_sig = inspect.Signature(parameters=list(new_parameters.values()))
+
+    partial_func = functools.partial(func, **kwargs)
+
+    def wrapper(*args, **kwargs):
+        return partial_func(*args, **kwargs)
+
+    wrapper.__signature__ = new_sig  # type: ignore[attr-defined]
+    wrapper.__name__ = func.__name__
+
+    return wrapper
+
+
+def register_fp8_delayed_scaling_patterns_inner():
+    from torch._inductor.pattern_matcher import fwd_only, register_replacement
+    from torch._inductor.fx_passes.post_grad import pass_patterns as post_grad_patterns_all
+
+    post_grad_patterns = post_grad_patterns_all[1]  # medium priority
+
+    if torch.cuda.is_available():
+        for IS_E5M2 in [True, False]:
+            # torch.float16 produces the same pattern as torch.bfloat16, because both need `tensor_x_inp.to(torch.float32)`;
+            # registering float16 as well would fail `assert pattern_repr not in _seen_patterns`.
+            for dtype in [torch.float32, torch.bfloat16]:
+                device = "cuda"
+                register_replacement(
+                    partialize_and_update_signature(
+                        amax_with_scaling_pattern, IS_E5M2=IS_E5M2
+                    ),
+                    partialize_and_update_signature(
+                        amax_with_scaling_tiled_replacement, IS_E5M2=IS_E5M2
+                    ),
+                    [
+                        torch.tensor((16, 16), device=device, dtype=dtype),
+                        torch.tensor(2.0, device=device, dtype=torch.float32),
+                    ],
+                    fwd_only,
+                    post_grad_patterns,
+                    extra_check=fp8_delayed_scaling_extra_check,
+                )
+
+
+def run_once(f):
+    def wrapper(*args, **kwargs):
+        if not wrapper.has_run:
+            wrapper.has_run = True
+            wrapper.result = f(*args, **kwargs)
+            return wrapper.result
+        else:
+            return wrapper.result
+    wrapper.has_run = False
+    wrapper.result = None
+    return wrapper
+
+
+@run_once
+def register_fp8_delayed_scaling_patterns() -> bool:
+    # The fp8 delayed scaling patterns require an Inductor fix: https://github.com/pytorch/pytorch/pull/139321
+    # The try-except ignores the failed registration if the current torch version doesn't include that fix.
+    try:
+        register_fp8_delayed_scaling_patterns_inner()
+    except Exception:
+        return False
+    return True
diff --git a/torchao/float8/stateful_float8_linear.py b/torchao/float8/stateful_float8_linear.py
index 94851511b..62f13bdfc 100644
--- a/torchao/float8/stateful_float8_linear.py
+++ b/torchao/float8/stateful_float8_linear.py
@@ -36,6 +36,9 @@
     WeightWithDynamicFloat8CastTensor,
     WeightWithStaticFloat8CastTensor,
 )
+from torchao.float8.inductor_utils import register_fp8_delayed_scaling_patterns
+
+register_fp8_delayed_scaling_patterns()
 
 
 class StatefulFloat8Linear(Float8Linear):
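
Usage sketch (not part of the patch): a minimal example of how the registered replacement is expected to fire when a delayed-scaling float8 cast is compiled. It assumes a CUDA device and a torch build that already includes https://github.com/pytorch/pytorch/pull/139321; `cast_with_delayed_scale` is a hypothetical helper written here only to mirror the e4m3 branch of amax_with_scaling_pattern, not an API of torchao.

import torch

from torchao.float8.inductor_utils import register_fp8_delayed_scaling_patterns

# Returns False (and registers nothing) when the installed torch lacks the Inductor fix.
registered = register_fp8_delayed_scaling_patterns()


def cast_with_delayed_scale(tensor_x_inp, scale_x):
    # Hypothetical helper: mirrors the e4m3 branch of amax_with_scaling_pattern.
    max_pos = torch.finfo(torch.float8_e4m3fn).max
    tensor_x = tensor_x_inp.to(torch.float32) * scale_x
    tensor_x = tensor_x.clamp(min=-1 * max_pos, max=max_pos).to(torch.float8_e4m3fn)
    amax = torch.max(torch.abs(tensor_x_inp))
    return tensor_x, amax


if registered and torch.cuda.is_available():
    x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
    # Delayed scaling: the scale comes from a previous iteration, not from x,
    # so fp8_delayed_scaling_extra_check allows the replacement to apply.
    scale = torch.tensor(2.0, device="cuda", dtype=torch.float32)
    y, amax = torch.compile(cast_with_delayed_scale)(x, scale)
    y_ref, amax_ref = cast_with_delayed_scale(x, scale)
    torch.testing.assert_close(amax, amax_ref)

When the replacement applies, the global torch.max(torch.abs(x)) is rewritten into a per-row amax (torch.max(..., dim=-1)) followed by a final reduction, which lets the amax computation fuse with the float8 cast kernel instead of running as a separate full-tensor reduction.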