forked from pytorch/ao
Commit
Add a register_replacement to fix float8 delayed scaling kernel fusion issues in torchao/float8 (pytorch#1469)

Summary: The original PR was against inductor: pytorch/pytorch#143464. As suggested in the comments there, we moved the patterns to the torchao repo.

Differential Revision: D67758184
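For context, here is a minimal usage sketch of the function added below. The file added by this commit is not named in this extract, so the import path `torchao.float8.inductor_utils` is an assumption; everything else follows from the code in the diff.

import torch

# Assumed import path; the actual module name is not visible in this extract.
from torchao.float8.inductor_utils import register_fp8_delayed_scaling_patterns

# Register the delayed-scaling amax/cast replacement patterns once, before
# compiling. Returns False when the running torch build lacks the required
# inductor fix (pytorch/pytorch#139321), in which case the patterns are
# simply not installed.
registered = register_fp8_delayed_scaling_patterns()

# In practice `model` would be a float8 training model configured for delayed
# scaling via torchao.float8; a plain module is used here only to keep the
# sketch self-contained.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
compiled_model = torch.compile(model)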
1 parent fe5f11b · commit db11bad
Showing 2 changed files with 133 additions and 0 deletions.
@@ -0,0 +1,130 @@
import functools
import inspect
from collections import deque

import torch


E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max


def amax_with_scaling_pattern(tensor_x_inp, scale_x, IS_E5M2):
    tensor_x = tensor_x_inp.to(torch.float32) * scale_x
    if IS_E5M2:
        tensor_x = tensor_x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
        tensor_x = tensor_x.to(torch.float8_e5m2)
    else:
        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_x = tensor_x.to(torch.float8_e4m3fn)
    amax = torch.max(torch.abs(tensor_x_inp))
    return (tensor_x, amax)


# Same scale-and-cast as the pattern above, but the abs-max is computed as a
# two-stage (row-wise, then global) reduction, which addresses the kernel
# fusion issue this commit targets.
def amax_with_scaling_tiled_replacement(tensor_x_inp, scale_x, IS_E5M2):
    tensor_x = tensor_x_inp.to(torch.float32) * scale_x
    if IS_E5M2:
        tensor_x = tensor_x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
        tensor_x = tensor_x.to(torch.float8_e5m2)
    else:
        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_x = tensor_x.to(torch.float8_e4m3fn)
    amax_1 = torch.max(torch.abs(tensor_x_inp), dim=-1).values
    amax = torch.max(amax_1)
    return (tensor_x, amax)


# The amax_with_scaling_pattern would also match dynamic scaling cases, which
# we want to avoid. In delayed scaling, `scale_x` comes from the previous
# iteration rather than from `tensor_x_inp`, so we check that `scale_x` does
# not depend on `tensor_x_inp`.
def fp8_delayed_scaling_extra_check(match):
    scale_x_inputs = deque([match.kwargs["scale_x"]])
    max_num_node_to_check = 50  # Don't traverse too many nodes
    current_num_node = 0
    while len(scale_x_inputs) > 0 and current_num_node < max_num_node_to_check:
        current_node = scale_x_inputs.popleft()
        for n in current_node.all_input_nodes:
            if n == match.kwargs["tensor_x_inp"]:
                return False
            scale_x_inputs.append(n)
        current_num_node += 1
    return True


def partialize_and_update_signature(func, **kwargs):
    """
    Equivalent to functools.partial, but also updates the signature of the
    returned function.
    """
    original_sig = inspect.signature(func)
    parameters = original_sig.parameters

    new_parameters = {
        key: value for key, value in parameters.items() if key not in kwargs
    }
    new_sig = inspect.Signature(parameters=list(new_parameters.values()))

    partial_func = functools.partial(func, **kwargs)

    def wrapper(*args, **kwargs):
        return partial_func(*args, **kwargs)

    wrapper.__signature__ = new_sig  # type: ignore[attr-defined]
    wrapper.__name__ = func.__name__

    return wrapper


def register_fp8_delayed_scaling_patterns_inner():
    from torch._inductor.pattern_matcher import fwd_only, register_replacement
    from torch._inductor.fx_passes.post_grad import (
        pass_patterns as post_grad_patterns_all,
    )

    post_grad_patterns = post_grad_patterns_all[1]  # medium priority

    if torch.cuda.is_available():
        for IS_E5M2 in [True, False]:
            # torch.float16 produces the same pattern as torch.bfloat16, because
            # both need `tensor_x_inp.to(torch.float32)`; registering it as well
            # would trip `assert pattern_repr not in _seen_patterns`.
            for dtype in [torch.float32, torch.bfloat16]:
                device = "cuda"
                register_replacement(
                    partialize_and_update_signature(
                        amax_with_scaling_pattern, IS_E5M2=IS_E5M2
                    ),
                    partialize_and_update_signature(
                        amax_with_scaling_tiled_replacement, IS_E5M2=IS_E5M2
                    ),
                    [
                        torch.tensor((16, 16), device=device, dtype=dtype),
                        torch.tensor(2.0, device=device, dtype=torch.float32),
                    ],
                    fwd_only,
                    post_grad_patterns,
                    extra_check=fp8_delayed_scaling_extra_check,
                )


def run_once(f):
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.has_run = True
            wrapper.result = f(*args, **kwargs)
            return wrapper.result
        else:
            return wrapper.result

    wrapper.has_run = False
    wrapper.result = None
    return wrapper


@run_once
def register_fp8_delayed_scaling_patterns() -> bool:
    # For the fp8 delayed scaling patterns to work, we need a fix PR in
    # inductor: https://github.com/pytorch/pytorch/pull/139321
    # The try-except block ignores failed pattern registration when the
    # current torch version doesn't include that fix.
    try:
        register_fp8_delayed_scaling_patterns_inner()
    except:
        return False
    return True
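Not part of the commit, but a rough way to check that the replacement fires: compile a toy function shaped like amax_with_scaling_pattern and inspect the generated code with the inductor test helper run_and_get_code. The expected kernel layout described in the comments is an expectation under the assumption that the patterns above were registered first, not a guarantee.

import torch
from torch._inductor.utils import run_and_get_code

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max


def to_fp8_with_amax(tensor_x_inp, scale_x):
    # Mirrors amax_with_scaling_pattern with IS_E5M2=False.
    tensor_x = tensor_x_inp.to(torch.float32) * scale_x
    tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    tensor_x = tensor_x.to(torch.float8_e4m3fn)
    amax = torch.max(torch.abs(tensor_x_inp))
    return tensor_x, amax


if torch.cuda.is_available():
    # Assumes register_fp8_delayed_scaling_patterns() was called beforehand
    # (see the usage sketch near the top).
    x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
    # The scale is an independent input, as in delayed scaling, so
    # fp8_delayed_scaling_extra_check passes.
    scale = torch.tensor(2.0, device="cuda", dtype=torch.float32)
    _, code = run_and_get_code(torch.compile(to_fp8_with_amax), x, scale)
    # With the patterns registered, the post-grad pass should have rewritten
    # the single global abs-max into the tiled (per-row, then global) form,
    # allowing the first reduction to fuse with the cast kernel.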