From 101d731e2f86e15dfa141530de894b7853adb107 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 7 Oct 2024 15:02:40 -0700 Subject: [PATCH] float8 training axiswise scaling support with per-gemm-argument configuration (#940) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: This PR finalizes the UX for axiswise scaling for float8 training, and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet. Specifically, the additional combination we now support and test is a recipe from @lw , where we do the following: ``` output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise grad_weight_hp = input_t_hp @ grad_output_hp ``` Key characteristics of this recipe: 1. increased accuracy for `grad_weight`, which is important for real workloads 2. `output` and `weight` now only need to be scaled axiswise across a single dim compared to vanilla all-axiswise, which is more amenable to fast kernels Here is how a user can configure this: ```python # # short form # config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP) # # or, long form # # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_w_go = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) # grad_weight_hp = input_t_hp @ grad_output_hp cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED) # ensure fast_accum is on to get fast kernels gc_o = Float8GemmConfig(use_fast_accum=True) gc_gi = 
Float8GemmConfig(use_fast_accum=True) gc_gw = Float8GemmConfig(use_fast_accum=True) config = Float8LinearConfig( cast_config_input = cc_i, cast_config_weight = cc_w, cast_config_grad_output = cc_go, cast_config_input_for_grad_weight = cc_i_gw, cast_config_weight_for_grad_output = cc_w_go, cast_config_grad_output_for_grad_weight = cc_go_gw, gemm_config_output=gc_o, gemm_config_grad_input=gc_gi, gemm_config_grad_weight=gc_gw, ) ``` # performance Below we provide basic performance characteristics of axiswise scaling in general, and the all-axiswise and lw recipes. ## gemm performance of torch._scaled_mm baseline: tensorwise scaling ``` > python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True fast_accum name M K N ref_time_s fp8_time_s fp8_speedup 0 True 0 256 256 256 0.000004 0.000006 0.573115 1 True 1 512 512 512 0.000005 0.000007 0.659333 2 True 2 1024 1024 1024 0.000011 0.000010 1.080664 3 True 3 2048 2048 2048 0.000028 0.000017 1.596239 4 True 4 4096 4096 4096 0.000210 0.000082 2.551705 5 True 5 8192 8192 8192 0.001671 0.000680 2.457972 6 True 6 16384 16384 16384 0.015030 0.006498 2.313032 7 True 7 32768 32768 32768 0.103236 0.048097 2.146411 8 False 0 256 256 256 0.000004 0.000006 0.630061 9 False 1 512 512 512 0.000005 0.000007 0.767236 10 False 2 1024 1024 1024 0.000012 0.000008 1.391347 11 False 3 2048 2048 2048 0.000029 0.000020 1.457922 12 False 4 4096 4096 4096 0.000211 0.000101 2.100081 13 False 5 8192 8192 8192 0.001676 0.000788 2.128628 14 False 6 16384 16384 16384 0.014933 0.006351 2.351209 15 False 7 32768 32768 32768 0.103457 0.049498 2.090134 ``` experiment: axiswise-scaling ``` > python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise fast_accum name M K N ref_time_s fp8_time_s fp8_speedup 0 True 0 256 256 256 0.000004 0.000004 0.966772 1 True 1 512 512 512 0.000005 0.000004 1.095791 2 True 2 1024 1024 1024 0.000011 0.000006 1.988363 3 True 3 2048 2048 
2048 0.000027 0.000015 1.890065 4 True 4 4096 4096 4096 0.000210 0.000082 2.552356 5 True 5 8192 8192 8192 0.001674 0.001092 1.533132 6 True 6 16384 16384 16384 0.015114 0.008785 1.720480 7 True 7 32768 32768 32768 0.103286 0.071456 1.445439 8 False 0 256 256 256 0.000004 0.000004 0.899054 9 False 1 512 512 512 0.000005 0.000005 1.005340 10 False 2 1024 1024 1024 0.000011 0.000006 1.692868 11 False 3 2048 2048 2048 0.000028 0.000049 0.567655 12 False 4 4096 4096 4096 0.000210 0.000341 0.616193 13 False 5 8192 8192 8192 0.001678 0.002640 0.635541 14 False 6 16384 16384 16384 0.015051 0.021557 0.698212 15 False 7 32768 32768 32768 0.103497 0.169797 0.609533 ``` ## performance on microbenchmark of ln -> linear -> sigmoid Note: for large square shapes, performance tends to be fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe. For performance of fp8_dynamic_axiswise, it seems that the gap from tensorwise is mostly due to the gemm performance being behind tensorwise. ``` > python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv fwd_M fwd_K fwd_N bf16_gemm_s fp8_gemm_s fp8_axs_gemm_time_s fp8_oh_dyn_limit ... fp8_del_s fp8_dyn_axs_s fp8_lw_s fp8_dyn_sp fp8_del_sp fp8_dyn_axs_sp fp8_lw_sp 0 256 256 256 0.000011 0.000018 0.000012 6.50457971014493e-6 ... 0.000043 0.000049 0.000030 0.465634 0.457907 0.398357 0.643088 1 512 512 512 0.000014 0.000020 0.000013 8.01831884057971e-6 ... 0.000047 0.000054 0.000034 0.489556 0.493467 0.432643 0.685842 2 1024 1024 1024 0.000033 0.000026 0.000017 1.40732753623188e-5 ... 0.000060 0.000063 0.000050 0.734123 0.741467 0.705941 0.891199 3 2048 2048 2048 0.000081 0.000055 0.000044 3.82931014492754e-5 ... 0.000147 0.000159 0.000142 0.815678 0.800811 0.739865 0.827441 4 4096 4096 4096 0.000632 0.000274 0.000247 0.000135172405797101 ... 0.000602 0.000622 0.000662 1.236320 1.261848 1.221755 1.147678 5 8192 8192 8192 0.005027 0.002216 0.003292 0.000522689623188406 ... 
0.003665 0.004776 0.005720 1.432213 1.513035 1.161130 0.969448 6 16384 16384 16384 0.045113 0.018975 0.025706 0.00207275849275362 ... 0.024664 0.032254 0.038051 1.803456 1.883291 1.440118 1.220738 7 32768 32768 32768 0.312459 0.147255 0.214492 0.00827303397101449 ... 0.182645 0.240962 0.270973 1.696376 1.766307 1.338827 1.190552 ``` ## performance on torchtitan LLaMa 3 8B on 8 H100 GPUs, float8 compute only: * baseline (bf16 + compile): 6,294 wps * f8 all-tensorwise: 7,359 wps (1.17x vs baseline) * f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise) * LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline) so, looks like we have performance work to do with `LW_AXISWISE_WITH_GW_HP` in future PRs # accuracy I did a very quick check that loss curves on torchtitan LLaMa 3 8B pretraining with 8 H100 GPUs look good for bf16/f8_tensorwise/f8_axiswise/f8_lw on 0.5k iterations. I will leave longer accuracy verifications for future work. Screenshot 2024-10-04 at 10 05 24 PM Test Plan: Reviewers: Subscribers: Tasks: Tags: --- benchmarks/float8/float8_roofline.py | 53 +++++- benchmarks/float8/profile_linear_float8.py | 53 ++---- test/float8/test_base.py | 128 ++++++------- test/float8/test_compile.py | 174 +++++------------- test/float8/test_numerics_integration.py | 127 +++++-------- torchao/float8/config.py | 177 ++++++++++++++++-- torchao/float8/float8_linear.py | 204 ++++++++++++--------- torchao/float8/float8_ops.py | 14 ++ torchao/float8/float8_scaling_utils.py | 15 ++ torchao/testing/float8/test_utils.py | 50 +++++ 10 files changed, 566 insertions(+), 429 deletions(-) create mode 100644 torchao/testing/float8/test_utils.py diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 2f04b8ee8..19c6cc21b 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -70,6 +70,7 @@ ScalingType, CastConfig, ) +from torchao.float8.config import 
recipe_name_to_linear_config, Float8LinearRecipeName class LNLinearSigmoid(torch.nn.Module): @@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None): else: # cache does not exist yet, create it cache = dict() + else: + cache = dict() key = f"{M},{K},{N},{fast_accum}" if key in cache: return cache[key] @@ -153,13 +156,18 @@ def do_matmul(A, B): ) f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) + scale_a = torch.ones(M, 1, device=device) + scale_b = torch.ones(1, N, device=device) + fast_accum = True # for axiswise + f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) + # save to cache if needed if cache_filename is not None: - cache[key] = [bf16_time_s, f8_time_s] + cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s] with open(cache_filename, 'w') as f: json.dump(cache, f) - return bf16_time_s, f8_time_s + return bf16_time_s, f8_time_s, f8_axs_time_s def run( outfile: str, @@ -231,13 +239,15 @@ def run( headers = [ 'fwd_M', 'fwd_K', 'fwd_N', # gemm microbenchmarks - 'bf16_gemm_s', 'fp8_gemm_s', + 'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s', # roofline memory overhead estimates 'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit', 'fp8_oh_del_limit', 'fp8_oh_del_nolimit', # actual e2e measurements - 'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s', - 'fp8_dyn_speedup', 'fp8_del_speedup', + 'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s', + # 'fp8_lw_s', + 'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp', + # 'fp8_lw_sp', ] results = [] @@ -248,15 +258,18 @@ def run( break if gemm_time_strategy == "benchmarks": - bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename) - bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename) - bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename) + bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename) + bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, 
gemm_cache_filename) + bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename) bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3 fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3 + fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs else: assert gemm_time_strategy == "roofline", "unsupported" bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + # for now, assume axiswise gemm is similar to tensorwise + fp8_axs_gemm_time_s = fp8_gemm_time_s fp8_mem_time_dyn_limit_s = \ fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val) @@ -291,14 +304,30 @@ def run( cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), ) - m_fp8_del = convert_to_float8_training(m_orig) + m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config) m_fp8_del = torch.compile(m_fp8_del) fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x) + # get the float8 dynamic axiswise scaling gpu kernel time + torch._dynamo.reset() + config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE) + m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config) + m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs) + fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x) + + # get the lw recipe scaling gpu kernel time + # TODO(future PR): enable below once basic performance issues + # are fixed + # torch._dynamo.reset() + # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP) + # m_fp8_lw = convert_to_float8_training(m_orig, config=config) + # m_fp8_lw = torch.compile(m_fp8_lw) + # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x) + results.append([ M_val, K_val, N_val, # gemm microbenchmarks - bf16_time_val, fp8_gemm_time_s, + bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s, # 
roofline overhead estimates fp8_mem_time_dyn_limit_s, fp8_mem_time_dyn_nolimit_s, @@ -306,8 +335,12 @@ def run( fp8_mem_time_del_nolimit_s, # e2e numbers bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s, + fp8_dyn_axs_time_actual_s, + # fp8_lw_time_actual_s, bf16_time_actual_s / fp8_dyn_time_actual_s, bf16_time_actual_s / fp8_del_time_actual_s, + bf16_time_actual_s / fp8_dyn_axs_time_actual_s, + # bf16_time_actual_s / fp8_lw_time_actual_s, ]) df = pd.DataFrame(results, columns=headers) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 6afefa009..f4f2813a3 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -27,12 +27,15 @@ Float8LinearConfig, ScalingType, ScalingGranularity, + Float8LinearRecipeName, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, sync_float8_amax_and_scale_history, ) +from torchao.testing.float8.test_utils import get_test_float8_linear_config from torch.profiler import profile, ProfilerActivity, record_function from utils import ( kernel_name_to_category, @@ -257,7 +260,7 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", - scaling_granularity: str = "tensorwise", + recipe_name: Optional[str] = None, model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, @@ -269,47 +272,17 @@ def main( scaling_type_input = ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) - scaling_granularity = ScalingGranularity(scaling_granularity) - if scaling_type_input is ScalingType.STATIC: - cast_config_input=CastConfig( - scaling_type=scaling_type_input, - static_scale=torch.tensor([1.0], device="cuda"), - 
scaling_granularity=scaling_granularity, + if recipe_name is None: + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + emulate=False, ) - else: - cast_config_input=CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight=CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_weight=CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output=CastConfig( - scaling_type=scaling_type_grad_output, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_grad_output=CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) - - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - ) + elif recipe_name is not None: + recipe_name = Float8LinearRecipeName(recipe_name) + config = recipe_name_to_linear_config(recipe_name) scaling_repr = "_".join( [ diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 0aab91f55..478f89bf4 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -8,6 +8,7 @@ import itertools import random import re +from typing import List, Tuple import unittest import warnings @@ -27,6 +28,8 @@ Float8LinearConfig, ScalingGranularity, ScalingType, + Float8LinearRecipeName, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -35,7 +38,10 @@ sync_float8_amax_and_scale_history, ) from torchao.float8.float8_python_api import addmm_float8_unwrapped 
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic +from torchao.float8.float8_scaling_utils import ( + hp_tensor_to_float8_dynamic, + get_maybe_axiswise_dim, +) from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -51,6 +57,7 @@ FP8_TYPES, tensor_to_scale, ) +from torchao.testing.float8.test_utils import get_test_float8_linear_config random.seed(0) torch.manual_seed(0) @@ -59,6 +66,8 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" @@ -205,9 +214,17 @@ def test_axiswise_reshape(self): a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) + @pytest.mark.parametrize( + "a_granularity,b_granularity", + [ + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE), + (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE), + (ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE), + ] + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0") - def test_axiswise_gemm(self, a_shape): + def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") @@ -218,18 +235,20 @@ def test_axiswise_gemm(self, a_shape): e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=-1, + scaling_granularity=a_granularity, + axiswise_dim=get_maybe_axiswise_dim(-1, a_granularity), ) a_fp8 = a_fp8.reshape(-1, a_shape[-1]) + b_fp8 = 
hp_tensor_to_float8_dynamic( b, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=-1, # will be transposed + scaling_granularity=b_granularity, + axiswise_dim=get_maybe_axiswise_dim(-1, b_granularity), ) + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) a = a.reshape(-1, a_shape[-1]) c_ref = torch.mm(a, b.t()) @@ -322,79 +341,64 @@ def _test_linear_impl( ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], - ) - @pytest.mark.parametrize( - "scaling_granularity", - [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + [ScalingType.DELAYED, ScalingType.DYNAMIC], ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_linear( + def test_linear_from_config_params( self, x_shape, emulate: bool, scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, linear_dtype: torch.dtype, linear_bias: bool, ): - if scaling_granularity is ScalingGranularity.AXISWISE: - if ( - scaling_type_input != ScalingType.DYNAMIC or - scaling_type_weight != ScalingType.DYNAMIC or - scaling_type_grad_output != ScalingType.DYNAMIC or - linear_dtype != torch.bfloat16 or - (not is_cuda_9_0) - ): - pytest.skip() - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if 
scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + emulate, + ) - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - emulate=emulate, + self._test_linear_impl( + x, + m_ref, + config, ) + + # Note: there are now too many config combinations to test all of + # them, so this function factors out some of the recipes which are annoying + # to combine with the main testing function. + # TODO(future PR): make this cleaner. 
+ @pytest.mark.parametrize( + "recipe_name", + [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], + ) + @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) + @pytest.mark.parametrize("linear_bias", [True, False]) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_linear_from_recipe( + self, + recipe_name, + x_shape, + linear_bias: bool, + ): + if torch.cuda.get_device_capability() < (9, 0): + warnings.warn( + f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" + ) + pytest.skip() + + linear_dtype = torch.bfloat16 + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) + config = recipe_name_to_linear_config(recipe_name) self._test_linear_impl( x, m_ref, @@ -504,7 +508,7 @@ def test_repr(self): config=config, ) s = m.__repr__() - assert "i:dyn,w:del,go:dyn" in s + assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available") def test_inference_mode(self): diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 317743288..7c445f880 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import copy import random +from typing import List, Tuple import sys import unittest from io import StringIO @@ -22,7 +23,8 @@ CastConfig, Float8LinearConfig, ScalingType, - ScalingGranularity, + Float8LinearRecipeName, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -40,6 +42,7 @@ ScaledMMConfig, ) from torchao.float8.float8_utils import e4m3_dtype +from torchao.testing.float8.test_utils import get_test_float8_linear_config from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend @@ -59,7 +62,8 @@ def _test_compile_base( x_shape = (16, 16) linear_dtype = torch.bfloat16 - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() + x_ref = copy.deepcopy(x) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) m_fp8 = Float8Linear.from_float( @@ -71,7 +75,7 @@ def _test_compile_base( m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) y_fp8 = m_fp8(x) y_fp8.sum().backward() - y_ref = m_ref(x) + y_ref = m_ref(x_ref) y_ref.sum().backward() # TODO(future PR): can also test fp8 eager vs compile here with a tigher # tolerance @@ -80,74 +84,7 @@ def _test_compile_base( m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1 ) torch.testing.assert_close(m_fp8.bias.grad, m_ref.bias.grad, atol=8e-2, rtol=8e-2) - -def _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, - emulate, -): - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is 
ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) - - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - emulate=emulate, - ) - return config - - -def is_supported( - scaling_granularity, - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - dtype, -) -> bool: - if scaling_granularity is ScalingGranularity.AXISWISE: - if ( - scaling_type_input != ScalingType.DYNAMIC or - scaling_type_weight != ScalingType.DYNAMIC or - scaling_type_grad_output != ScalingType.DYNAMIC or - dtype != torch.bfloat16 or - (not is_H100) - ): - return False - return True + torch.testing.assert_close(x.grad, x_ref.grad, atol=8e-2, rtol=8e-2) @pytest.mark.parametrize("fullgraph", [True]) @@ -160,11 +97,8 @@ def is_supported( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) -@pytest.mark.parametrize( - "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] -) @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not 
available") def test_eager_only( fullgraph, @@ -172,24 +106,13 @@ def test_eager_only( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): - if not is_supported( - scaling_granularity, - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - dtype, - ): - pytest.skip() - torch._dynamo.reset() - config = _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, emulate, ) _test_compile_base( @@ -211,10 +134,7 @@ def test_eager_only( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) -@pytest.mark.parametrize( - "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] -) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( fullgraph, @@ -222,24 +142,13 @@ def test_aot_eager( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): - if not is_supported( - scaling_granularity, - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - dtype, - ): - pytest.skip() - torch._dynamo.reset() - config = _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, emulate, ) _test_compile_base( @@ -261,35 +170,21 @@ def test_aot_eager( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, 
ScalingType.DYNAMIC, ScalingType.STATIC] ) -@pytest.mark.parametrize( - "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] -) @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -def test_inductor( +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_inductor_from_config_params( fullgraph, emulate: bool, scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): - if not is_supported( - scaling_granularity, - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - dtype, - ): - pytest.skip() - torch._dynamo.reset() - config = _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, emulate, ) _test_compile_base( @@ -299,6 +194,27 @@ def test_inductor( dtype, ) +# Note: there are now too many config combinations to test all of +# them, so this function factors out some of the recipes which are annoying +# to combine with the main testing function. +# TODO(future PR): make this cleaner. 
+@pytest.mark.parametrize( + "recipe_name", + [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], +) +@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available") +def test_inductor_from_recipe(recipe_name): + torch._dynamo.reset() + config = recipe_name_to_linear_config(recipe_name) + fullgraph = True + dtype = torch.bfloat16 + _test_compile_base( + "inductor", + fullgraph, + config, + dtype, + ) + class TestGraphBreaks(DynamoTestCase): class MockLinear(torch.nn.Module): diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 07fcddaad..a91b784c8 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -24,6 +24,8 @@ Float8LinearConfig, ScalingType, ScalingGranularity, + Float8LinearRecipeName, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -31,6 +33,7 @@ sync_float8_amax_and_scale_history, ) from torchao.float8.float8_utils import compute_error, IS_ROCM +from torchao.testing.float8.test_utils import get_test_float8_linear_config is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) @@ -84,44 +87,9 @@ def init_weights(self, init_std: float): class TestFloat8NumericsIntegrationTest: - @pytest.mark.parametrize( - "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], - ) - @pytest.mark.parametrize( - "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], - ) - @pytest.mark.parametrize( - "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], - ) - @pytest.mark.parametrize( - "scaling_granularity", - [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], - ) - @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 
compatible machine") - @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") - def test_encoder_fw_bw( - self, - scaling_type_input: ScalingType, - scaling_type_weight: ScalingType, - scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, - ): - # TODO(later): maybe add float16 back if it becomes important - data_dtype = torch.bfloat16 - - if scaling_granularity is ScalingGranularity.AXISWISE: - if ( - scaling_type_input != ScalingType.DYNAMIC or - scaling_type_weight != ScalingType.DYNAMIC or - scaling_type_grad_output != ScalingType.DYNAMIC or - data_dtype != torch.bfloat16 or - (not is_cuda_9_0) - ): - pytest.skip() + def _test_impl(self, config: Float8LinearConfig) -> None: + data_dtype = torch.bfloat16 # LLaMa 3 70B shapes model_ref = ( FeedForward( @@ -137,44 +105,6 @@ def test_encoder_fw_bw( # for now just test the encoder to simplify things model_fp8 = copy.deepcopy(model_ref) - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) - - config = Float8LinearConfig( - cast_config_input=cast_config_input, - 
cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - ) - convert_to_float8_training( model_fp8, config=config, @@ -212,9 +142,9 @@ def test_encoder_fw_bw( out_sqnr = compute_error(model_ref_out, model_fp8_out) any_static_scaling = ( - scaling_type_input is ScalingType.STATIC - or scaling_type_weight is ScalingType.STATIC - or scaling_type_grad_output is ScalingType.STATIC + config.cast_config_input.scaling_type is ScalingType.STATIC + or config.cast_config_weight.scaling_type is ScalingType.STATIC + or config.cast_config_grad_output.scaling_type is ScalingType.STATIC ) if any_static_scaling: assert out_sqnr > 10.0 @@ -236,6 +166,47 @@ def test_encoder_fw_bw( sqnr = compute_error(ref_grad, cur_grad) assert sqnr > grad_sqnr_threshold + @pytest.mark.parametrize( + "scaling_type_input", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + ) + @pytest.mark.parametrize( + "scaling_type_weight", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + ) + @pytest.mark.parametrize( + "scaling_type_grad_output", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + ) + @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") + @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") + def test_encoder_fw_bw_from_config_params( + self, + scaling_type_input: ScalingType, + scaling_type_weight: ScalingType, + scaling_type_grad_output: ScalingType, + ): + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + emulate=False, + ) + self._test_impl(config) + + @pytest.mark.parametrize( + "recipe_name", + [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], + ) + @pytest.mark.skipif(not is_cuda_9_0, reason="requires SM90 compatible machine") + @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") + def 
test_encoder_fw_bw_from_recipe( + self, + recipe_name: str, + ): + config = recipe_name_to_linear_config(recipe_name) + self._test_impl(config) + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 0ed6d2622..556d2a4d4 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -15,15 +15,20 @@ class ScalingType(enum.Enum): DELAYED = "delayed" DYNAMIC = "dynamic" STATIC = "static" + # ScalingType.DISABLED means "skip scaling for this tensor, leave it in + # its original precision". + DISABLED = "disabled" def short_str(self): if self is ScalingType.DELAYED: return "del" elif self is ScalingType.DYNAMIC: return "dyn" - else: - assert self is ScalingType.STATIC + elif self is ScalingType.STATIC: return "sta" + else: + assert self is ScalingType.DISABLED + return "dis" class ScalingGranularity(enum.Enum): @@ -48,13 +53,16 @@ def short_str(self): @dataclass(frozen=True) class CastConfig: """ - Configuration for casting a single tensor to float8 + Configuration for maybe casting a single tensor to float8 """ scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None + def short_str(self): + return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}" + def __post_init__(self): if self.scaling_type is ScalingType.STATIC: assert ( @@ -107,15 +115,34 @@ class Float8LinearConfig: """ # - # Per-tensor configuration for `input`, `weight`, `grad_output` + # Per-tensor configuration for casting of `input`, `weight`, `grad_output` + # for the operands of gemms calculating `output`, `grad_weight`, and `grad_input`. + # + # Note: + # 1. if `cast_config_input_for_grad_weight` is None, then + # `cast_config_input` is used for scaling `input` for both gemms that + # use `input`. + # 2. if `cast_config_input_for_grad_weight` is specified, then + # a.
`cast_config_input` is used for scaling `input` for the gemm that calculates + # `output` + # b. `cast_config_input_for_grad_weight` is used for scaling `input` for + # the gemm that calculates `grad_weight` + # 3. the same behavior holds for `cast_config_weight` and `cast_config_grad_output`. # + # `input` cast_config_input: CastConfig = CastConfig() + cast_config_input_for_grad_weight: Optional[CastConfig] = None + # `weight` cast_config_weight: CastConfig = CastConfig() + cast_config_weight_for_grad_input: Optional[CastConfig] = None + # `grad_output` cast_config_grad_output: CastConfig = CastConfig() + cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None # # Per-gemm configuration for gemms calculating `output`, `grad_input` and # `grad_weight` + # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling # gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() @@ -174,28 +201,140 @@ class Float8LinearConfig: force_recompute_fp8_weight_in_bwd: bool = False def __post_init__(self): + # Populate the additional cast overrides, if the user did not specify them + # Note: this hacks around the frozen-ness of this dataclass + # by using `object.__setattr__`. This is fine, as what we really need + # is for this object to be frozen after `__post_init__` for torch.compile + # to work. 
+ # Source of hack: https://stackoverflow.com/a/65959419/ + if self.cast_config_input_for_grad_weight is None: + object.__setattr__(self, "cast_config_input_for_grad_weight", self.cast_config_input) + if self.cast_config_weight_for_grad_input is None: + object.__setattr__(self, "cast_config_weight_for_grad_input", self.cast_config_weight) + if self.cast_config_grad_output_for_grad_weight is None: + object.__setattr__(self, "cast_config_grad_output_for_grad_weight", self.cast_config_grad_output) + # float8 all-gather only supports tensorwise, in the future may support blockwise if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE: assert not self.enable_fsdp_float8_all_gather, \ f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}" - # for now, axiswise granularity is all-or-nothing - # TODO(future PR): enable more granular setting per-gemm-input - has_any_axiswise_scaling = ( - self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or - self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or - self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE - ) - has_all_axiswise_scaling = ( - self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and - self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and - self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE - ) - if has_any_axiswise_scaling: - assert has_all_axiswise_scaling, \ - "for now, axiswise scaling must be enabled for either all casts or none of the casts" + # save some characters in the compatibility checks below + cc_i = self.cast_config_input + cc_w = self.cast_config_weight + cc_go = self.cast_config_grad_output + cc_i_gw = self.cast_config_input_for_grad_weight + cc_w_gi = self.cast_config_weight_for_grad_input + cc_go_gw = self.cast_config_grad_output_for_grad_weight + # 
for now, we only have gemm kernels where both operands are either both + # in high precision, or both in float8. In the future, this may be relaxed. + # TODO(future): make the float8 check more precise with the specific dtypes. + for cc1, cc2, gemm_name in ( + (cc_i, cc_w, "output"), + (cc_go, cc_w_gi, "grad_input"), + (cc_i_gw, cc_go_gw, "grad_weight"), + ): + is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED + is_disabled_2 = cc2.scaling_type is ScalingType.DISABLED + assert is_disabled_1 == is_disabled_2, \ + f"incompatible operand precision for {gemm_name}" + # If True, use 'fnuz' float8 types for calculations. # Currently, ROCm only supports fnuz variants. # TODO(future PR): move this to Float8LinearConfig use_fnuz_dtype = False + + +# Pre-made recipes for common configurations +# TODO(future PR): go through a round of design on this, and eventually expose +# as a top level public API. +class Float8LinearRecipeName(enum.Enum): + ALL_TENSORWISE = "all_tensorwise" + ALL_AXISWISE = "all_axiswise" + LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp" + + +def recipe_name_to_linear_config( + recipe_name: Float8LinearRecipeName, +) -> Float8LinearConfig: + """ + Input: `Float8LinearRecipeName` value + Output: a `Float8LinearConfig` configured to implement the recipe + """ + + if recipe_name is Float8LinearRecipeName.ALL_TENSORWISE: + # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel + return Float8LinearConfig() + + elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE: + # dynamic axiswise scaling with the CUTLASS rowwise kernel + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + + # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only + # fast with `use_fast_accum=True`.
Note that rowwise scaling is more + # accurate than tensorwise scaling, so the overall impact on accuracy + # of tensorwise vs rowwise taking this flag into account will vary. + gc_o = Float8GemmConfig(use_fast_accum=True) + gc_gi = Float8GemmConfig(use_fast_accum=True) + gc_gw = Float8GemmConfig(use_fast_accum=True) + + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + gemm_config_output=gc_o, + gemm_config_grad_input=gc_gi, + gemm_config_grad_weight=gc_gw, + ) + + elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: + + # lw's recipe for a modification on all-axiswise: + # + # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 + # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise + # grad_weight_hp = input_t_hp @ grad_output_hp + # + # key characteristics: + # * increased accuracy for grad_weight + # * `input`, `weight` and `grad_output` now only need to be scaled + # axiswise across a single dim compared to vanilla all-axiswise, + # which is more amenable to fast kernels + + # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + + # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise + cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) + + # grad_weight_hp = input_t_hp @ grad_output_hp + cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) + cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED) + + # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only + # fast with `use_fast_accum=True`. Note that rowwise scaling is more + # accurate than tensorwise scaling, so the overall impact on accuracy + # of tensorwise vs rowwise taking this flag into account will vary. 
+ gc_o = Float8GemmConfig(use_fast_accum=True) + gc_gi = Float8GemmConfig(use_fast_accum=True) + gc_gw = Float8GemmConfig(use_fast_accum=True) + + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + cast_config_input_for_grad_weight=cc_i_gw, + cast_config_weight_for_grad_input=cc_w_gi, + cast_config_grad_output_for_grad_weight=cc_go_gw, + gemm_config_output=gc_o, + gemm_config_grad_input=gc_gi, + gemm_config_grad_weight=gc_gw, + ) + + else: + raise AssertionError(f"unknown recipe_name {recipe_name}") diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 9aaffa99c..22ceff316 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -23,6 +23,7 @@ hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, hp_tensor_to_float8_static, + get_maybe_axiswise_dim, NoopFwToFloat8E5M2BwDelayed, NoopFwToFloat8E5M2BwDynamic, NoopFwToFloat8E5M2BwStatic, @@ -37,8 +38,8 @@ ) from torchao.float8.float8_utils import ( - e4m3_dtype, - e5m2_dtype, + e4m3_dtype, + e5m2_dtype, tensor_to_amax, tensor_to_scale, ) @@ -122,54 +123,56 @@ class manual_float8_matmul_with_args_in_hp(torch.autograd.Function): and other granularities in a separate PR. 
""" - # TODO(this PR): types of inputs @staticmethod def forward( ctx, input_hp: torch.Tensor, weight_hp_t: torch.Tensor, linear_mm_config: LinearMMConfig, - input_scaling_granularity: ScalingGranularity, - weight_scaling_granularity: ScalingGranularity, - grad_output_scaling_granularity: ScalingGranularity, + config: Float8LinearConfig, ): ctx.save_for_backward(input_hp, weight_hp_t) ctx.linear_mm_config = linear_mm_config - ctx.input_scaling_granularity = input_scaling_granularity - ctx.weight_scaling_granularity = weight_scaling_granularity - ctx.grad_output_scaling_granularity = grad_output_scaling_granularity - - input_fp8 = hp_tensor_to_float8_dynamic( - input_hp, - e4m3_dtype, - linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=input_scaling_granularity, - axiswise_dim=-1, - ) + ctx.config = config - weight_fp8_t = hp_tensor_to_float8_dynamic( - weight_hp_t, - e4m3_dtype, - linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=weight_scaling_granularity, - axiswise_dim=0, - ) + c = config + + if c.cast_config_input.scaling_type is ScalingType.DISABLED: + input_maybe_fp8 = input_hp + else: + input_maybe_fp8 = hp_tensor_to_float8_dynamic( + input_hp, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=c.cast_config_input.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cast_config_input.scaling_granularity), + ) + + if c.cast_config_weight.scaling_type is ScalingType.DISABLED: + weight_maybe_fp8_t = weight_hp_t + else: + weight_maybe_fp8_t = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=c.cast_config_weight.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(0, c.cast_config_weight.scaling_granularity), + ) # the reshapes are needed in order to make the shapes compatible with # torch.mm - orig_shape = input_fp8.shape - input_fp8_reshaped = 
input_fp8.reshape(-1, orig_shape[-1]) - res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) + orig_shape = input_maybe_fp8.shape + input_maybe_fp8_reshaped = input_maybe_fp8.reshape(-1, orig_shape[-1]) + res_bits = torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t) res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) return res_bits @staticmethod def backward(ctx, grad_output): input_hp, weight_hp_t = ctx.saved_tensors - - # TODO scaling + c = ctx.config # the reshapes are needed in order to make the shapes compatible with # torch.mm @@ -182,26 +185,37 @@ def backward(ctx, grad_output): # calculate grad_input # - grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( - grad_output_reshaped, - e5m2_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=ctx.grad_output_scaling_granularity, - axiswise_dim=-1, - ) - weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( - weight_hp_t, - e4m3_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=ctx.weight_scaling_granularity, - axiswise_dim=-1, # will be transposed - ) + if c.cast_config_grad_output.scaling_type is ScalingType.DISABLED: + grad_output_reshaped_maybe_fp8_dim0 = grad_output_reshaped + else: + grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=c.cast_config_grad_output.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cast_config_grad_output.scaling_granularity), + ) + + if c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED: + weight_t_maybe_fp8_dim0 = weight_hp_t + else: + # Note: we need https://github.com/pytorch/pytorch/issues/136267 + # to be solved to have a chance to reuse max(abs(weight, dim=...)) + # from the forward to get max(abs(weight)) here without reading + # the entire tensor. 
+ weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cast_config_weight_for_grad_input.scaling_granularity), + ) grad_input = torch.mm( - grad_output_reshaped_fp8_dim0, - weight_t_fp8_dim0.t(), + grad_output_reshaped_maybe_fp8_dim0, + weight_t_maybe_fp8_dim0.t(), ) grad_input = grad_input.reshape( *grad_output_orig_shape[:-1], grad_input.shape[-1] @@ -214,29 +228,38 @@ def backward(ctx, grad_output): # calculate grad_weight # - grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( - grad_output_reshaped, - e5m2_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=ctx.grad_output_scaling_granularity, - axiswise_dim=0, # will be transposed - ) - input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( - input_hp_reshaped, - e4m3_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=ctx.input_scaling_granularity, - axiswise_dim=0, - ) + if c.cast_config_grad_output_for_grad_weight.scaling_type is ScalingType.DISABLED: + grad_output_reshaped_maybe_fp8_dim1 = grad_output_reshaped + else: + grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(0, c.cast_config_grad_output_for_grad_weight.scaling_granularity), + ) + + if c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED: + input_reshaped_maybe_fp8_dim1 = input_hp_reshaped + else: + input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( + input_hp_reshaped, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + 
scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity, + axiswise_dim=get_maybe_axiswise_dim(0, c.cast_config_input_for_grad_weight.scaling_granularity), + ) grad_weight = torch.mm( - grad_output_reshaped_fp8_dim1.t(), - input_reshaped_fp8_dim1, + grad_output_reshaped_maybe_fp8_dim1.t(), + input_reshaped_maybe_fp8_dim1, ) - return grad_input, grad_weight.t(), None, None, None, None + empty_grads = None, None + + return grad_input, grad_weight.t(), *empty_grads class Float8Linear(torch.nn.Linear): @@ -531,14 +554,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.has_any_delayed_scaling: self.float8_pre_forward(input) - # TODO(this PR): reuse with config, make a property - has_all_axiswise_scaling = ( - self.config.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and - self.config.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and - self.config.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + has_any_axiswise_scaling = ( + self.config.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_input_for_grad_weight.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_weight_for_grad_input.scaling_granularity is ScalingGranularity.AXISWISE or + self.config.cast_config_grad_output_for_grad_weight.scaling_granularity is ScalingGranularity.AXISWISE ) - if not has_all_axiswise_scaling: + if not has_any_axiswise_scaling: input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight, # weight_scale should be saved. 
@@ -568,9 +593,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input, self.weight.t(), self.linear_mm_config, - self.config.cast_config_input.scaling_granularity, - self.config.cast_config_weight.scaling_granularity, - self.config.cast_config_grad_output.scaling_granularity, + self.config, ) if self.bias is not None: @@ -580,21 +603,20 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.float8_post_forward() return output - def scaling_type_repr(self): - # add scaling type settings without using too many characters - # example: "i:del,w:del,go:dyn" - return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}" - - def scaling_granularity_repr(self): - # add scaling granularity settings without using too many characters - # example: "i:ten,w:ten,g:ten" or "i:axs,w:axs,g:axs" - gi = self.config.cast_config_input.scaling_granularity.short_str() - gw = self.config.cast_config_weight.scaling_granularity.short_str() - ggo = self.config.cast_config_grad_output.scaling_granularity.short_str() - return f"i:{gi},w:{gw},go:{ggo}" - def extra_repr(self): - s = f'{super().extra_repr()}, scaling_type="{self.scaling_type_repr()}", scaling_granularity="{self.scaling_granularity_repr()}"' + c = self.config + ci = f"i:{c.cast_config_input.short_str()}" + cw = f"w:{c.cast_config_weight.short_str()}" + cgo = f"go:{c.cast_config_grad_output.short_str()}" + parts = [ci, cw, cgo] + if c.cast_config_input_for_grad_weight != c.cast_config_input: + parts.append(f"i_gw:{c.cast_config_input_for_grad_weight.short_str()}") + if c.cast_config_weight_for_grad_input != c.cast_config_weight: + parts.append(f"w_gi:{c.cast_config_weight_for_grad_input.short_str()}") + if c.cast_config_grad_output_for_grad_weight != c.cast_config_grad_output: + parts.append(f"go_gw:{c.cast_config_grad_output_for_grad_weight.short_str()}") + cast_config_str = ",".join(parts) + s = f'{super().extra_repr()}, 
cast_configs="{cast_config_str}"' return s @classmethod diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index b97d03211..8f5bc768e 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -251,6 +251,20 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): if is_row_major(b_data.stride()): b_data = b_data.t().contiguous().t() b_scale = b._scale + + # Today, torch._scaled_mm only supports both operands using the + # same granularity. The code below checks for cases where one + # operand is scaled axiswise and one tensorwise. If this case is found, + # we reshape the tensorwise scale to be repeated along the needed axis, + # so that torch._scaled_mm can call the axiswise-axiswise kernel. + # Note: using shape/size info does not work with compile here, which is + # why we are inferring the scaling granularity from the presence of + # axiswise_dim. + if a._axiswise_dim is None and b._axiswise_dim is not None: + a_scale = a_scale.repeat(a_data.shape[0]).reshape(-1, 1) + elif a._axiswise_dim is not None and b._axiswise_dim is None: + b_scale = b_scale.repeat(b_data.shape[1]).reshape(1, -1) + return a_data, a_scale, b_data, b_scale diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 3207c0c9f..fc22a4e35 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -143,6 +143,21 @@ def hp_tensor_to_float8_static( ) +def get_maybe_axiswise_dim( + axiswise_dim: int, + scaling_granularity: ScalingGranularity, +) -> Optional[int]: + """ + Convenience function which takes in an axiswise dim which is only relevant + for axiswise scaling, and a scaling granularity. The output is pass-through + if scaling granularity is axiswise, and None otherwise. This is done to keep the + logic for choosing the axiswise dim out of the scaling function.
+ """ + if scaling_granularity is ScalingGranularity.AXISWISE: + return axiswise_dim + return None + + def _maybe_initialize_amaxes_scales_for_float8_cast( x, cur_amax, diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py new file mode 100644 index 000000000..7f37c3f30 --- /dev/null +++ b/torchao/testing/float8/test_utils.py @@ -0,0 +1,50 @@ +import torch +from torchao.float8.config import ( + ScalingGranularity, + ScalingType, + CastConfig, + Float8LinearConfig, +) + + +def get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + emulate: bool, +): + static_scale_one = torch.tensor([1.0], device="cuda") + + if scaling_type_input is ScalingType.STATIC: + static_scale_input = static_scale_one + else: + static_scale_input = None + if scaling_type_weight is ScalingType.STATIC: + static_scale_weight = static_scale_one + else: + static_scale_weight = None + if scaling_type_grad_output is ScalingType.STATIC: + static_scale_grad_output = static_scale_one + else: + static_scale_grad_output = None + + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + static_scale=static_scale_input, + ) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + static_scale=static_scale_weight, + ) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + static_scale=static_scale_grad_output, + ) + + config = Float8LinearConfig( + cast_config_input=cast_config_input, + cast_config_weight=cast_config_weight, + cast_config_grad_output=cast_config_grad_output, + emulate=emulate, + ) + return config