Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
float8 training axiswise scaling support with per-gemm-argument confi…
…guration (#940) Summary: This PR finalizes the UX for axiswise scaling for float8 training, and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet. Specifically, the additional combination we now support and test is a recipe from @lw , where we do the following: ``` output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise grad_weight_hp = input_t_hp @ grad_output_hp ``` Key characteristics of this recipe: 1. increased accuracy for `grad_weight`, which is important for real workloads 2. `output` and `weight` now only need to be scaled axiswise across a single dim compared to vanilla all-axiswise, which is more amenable to fast kernels Here is how a user can configure this: ```python # # short form # config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP) # # or, long form # # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_w_go = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) # grad_weight_hp = input_t_hp @ grad_output_hp cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED) # ensure fast_accum is on to get fast kernels gc_o = Float8GemmConfig(use_fast_accum=True) gc_gi = Float8GemmConfig(use_fast_accum=True) gc_gw = Float8GemmConfig(use_fast_accum=True) config = Float8Config( cast_config_input = cc_i, cast_config_weight = cc_w, cast_config_grad_output = cc_go, cast_config_input_for_grad_weight = cc_i_gw, cast_config_weight_for_grad_output = cc_w_go, cast_config_grad_output_for_grad_weight = cc_go_gw, gemm_config_output=gc_o, gemm_config_grad_input=gc_gi, gemm_config_grad_weight=gc_gw, ) ``` # performance Below we provide basic performance characteristics of axiswise scaling in general, and the all-axiswise and lw recipes. ## gemm performance of torch._scaled_mm baseline: tensorwise scaling ``` > python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True fast_accum name M K N ref_time_s fp8_time_s fp8_speedup 0 True 0 256 256 256 0.000004 0.000006 0.573115 1 True 1 512 512 512 0.000005 0.000007 0.659333 2 True 2 1024 1024 1024 0.000011 0.000010 1.080664 3 True 3 2048 2048 2048 0.000028 0.000017 1.596239 4 True 4 4096 4096 4096 0.000210 0.000082 2.551705 5 True 5 8192 8192 8192 0.001671 0.000680 2.457972 6 True 6 16384 16384 16384 0.015030 0.006498 2.313032 7 True 7 32768 32768 32768 0.103236 0.048097 2.146411 8 False 0 256 256 256 0.000004 0.000006 0.630061 9 False 1 512 512 512 0.000005 0.000007 0.767236 10 False 2 1024 1024 1024 0.000012 0.000008 1.391347 11 False 3 2048 2048 2048 0.000029 0.000020 1.457922 12 False 4 4096 4096 4096 0.000211 0.000101 2.100081 13 False 5 8192 8192 8192 0.001676 0.000788 2.128628 14 False 6 16384 16384 16384 0.014933 0.006351 2.351209 15 False 7 32768 32768 32768 0.103457 0.049498 2.090134 ``` experiment: axiswise-scaling ``` > python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise fast_accum name M K N ref_time_s fp8_time_s fp8_speedup 0 True 0 256 256 256 0.000004 0.000004 0.966772 1 True 1 512 512 512 0.000005 0.000004 1.095791 2 True 2 1024 1024 1024 0.000011 0.000006 1.988363 3 True 3 2048 2048 2048 0.000027 0.000015 1.890065 4 True 4 4096 4096 4096 0.000210 0.000082 2.552356 5 True 5 8192 8192 8192 0.001674 0.001092 1.533132 6 True 6 16384 16384 16384 0.015114 0.008785 1.720480 7 True 7 32768 32768 32768 0.103286 0.071456 1.445439 8 False 0 256 256 256 0.000004 0.000004 0.899054 9 False 1 512 512 512 0.000005 0.000005 1.005340 10 False 2 1024 1024 1024 0.000011 0.000006 1.692868 11 False 3 2048 2048 2048 0.000028 0.000049 0.567655 12 False 4 4096 4096 4096 0.000210 0.000341 0.616193 13 False 5 8192 8192 8192 0.001678 0.002640 0.635541 14 False 6 16384 16384 16384 0.015051 0.021557 0.698212 15 False 7 32768 32768 32768 0.103497 0.169797 0.609533 ``` ## performance on microbenchmark of ln -> linear -> sigmoid Note: for large square shapes, performance tends to be fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe. For performance of fp8_dynamic_axiswise, it seems that the gap from tensorwise is mostly due to the gemm performance being behind tensorwise. ``` > python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv fwd_M fwd_K fwd_N bf16_gemm_s fp8_gemm_s fp8_axs_gemm_time_s fp8_oh_dyn_limit ... fp8_del_s fp8_dyn_axs_s fp8_lw_s fp8_dyn_sp fp8_del_sp fp8_dyn_axs_sp fp8_lw_sp 0 256 256 256 0.000011 0.000018 0.000012 6.50457971014493e-6 ... 0.000043 0.000049 0.000030 0.465634 0.457907 0.398357 0.643088 1 512 512 512 0.000014 0.000020 0.000013 8.01831884057971e-6 ... 0.000047 0.000054 0.000034 0.489556 0.493467 0.432643 0.685842 2 1024 1024 1024 0.000033 0.000026 0.000017 1.40732753623188e-5 ... 0.000060 0.000063 0.000050 0.734123 0.741467 0.705941 0.891199 3 2048 2048 2048 0.000081 0.000055 0.000044 3.82931014492754e-5 ... 0.000147 0.000159 0.000142 0.815678 0.800811 0.739865 0.827441 4 4096 4096 4096 0.000632 0.000274 0.000247 0.000135172405797101 ... 0.000602 0.000622 0.000662 1.236320 1.261848 1.221755 1.147678 5 8192 8192 8192 0.005027 0.002216 0.003292 0.000522689623188406 ... 0.003665 0.004776 0.005720 1.432213 1.513035 1.161130 0.969448 6 16384 16384 16384 0.045113 0.018975 0.025706 0.00207275849275362 ... 0.024664 0.032254 0.038051 1.803456 1.883291 1.440118 1.220738 7 32768 32768 32768 0.312459 0.147255 0.214492 0.00827303397101449 ... 0.182645 0.240962 0.270973 1.696376 1.766307 1.338827 1.190552 ``` ## performance on torchtitan LLaMa 3 8B on 8 H100 GPUs, float8 compute only: * baseline (bf16 + compile): 6,294 wps * f8 all-tensorwise: 7,359 wps (1.17x vs baseline) * f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise) * LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline) so, looks like we have performance work to do with `LW_AXISWISE_WITH_GW_HP` in future PRs # accuracy I did a very quick check that loss curves on torchtitan LLaMa 3 8B pretraining with 8 H100 GPUs look good for bf16/f8_tensorwise/f8_axiswise/f8_lw on 0.5k iterations. I will leave longer accuracy verifications for future work. <img width="973" alt="Screenshot 2024-10-04 at 10 05 24 PM" src="https://github.com/user-attachments/assets/0d682183-41ef-4f04-992f-cd0d0fc8a65c"> Test Plan: Reviewers: Subscribers: Tasks: Tags:
- Loading branch information