float8 training axiswise scaling support with per-gemm-argument configuration (#940)

Summary:

This PR finalizes the UX for axiswise scaling in float8 training and adds per-gemm-argument configurability to `Float8Linear` to enable exploration of future recipes. Not all combinations are supported yet. The one additional combination now supported and tested is a recipe from @lw, which does the following:

```
output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
grad_weight_hp = input_t_hp @ grad_output_hp
```

Key characteristics of this recipe:
1. increased accuracy for `grad_weight`, which is important for real workloads
2. `input`, `weight`, and `grad_output` now only need to be scaled axiswise across a single dim, compared to vanilla all-axiswise, which is more amenable to fast kernels
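
To make the tensorwise versus axiswise distinction concrete, here is a minimal pure-Python sketch, not torchao code. It assumes the common convention `scale = fp8_max / amax` with the float8 e4m3 maximum of 448. The helper names `tensorwise_scale` and `rowwise_scales` are hypothetical and exist only for illustration.

```python
# Illustrative sketch only; these helpers are hypothetical and not part of torchao.
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def tensorwise_scale(mat):
    """One scale for the whole matrix, derived from its global abs-max."""
    amax = max(abs(v) for row in mat for v in row)
    return E4M3_MAX / amax


def rowwise_scales(mat):
    """Axiswise scaling along one dim: one scale per row, from each row's abs-max."""
    return [E4M3_MAX / max(abs(v) for v in row) for row in mat]


# The second row holds an outlier that dominates the global abs-max.
x = [
    [1.0, -2.0],
    [100.0, 400.0],
]

print("tensorwise:", tensorwise_scale(x))
print("rowwise:   ", rowwise_scales(x))
```

With a single tensorwise scale, the small-magnitude first row is scaled by the same factor as the outlier row and uses little of the float8 range. With per-row scales, each row is stretched to the full range independently.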

Here is how a user can configure this:

```python
from torchao.float8.config import (
    CastConfig,
    Float8GemmConfig,
    Float8LinearConfig,
    Float8LinearRecipeName,
    ScalingGranularity,
    ScalingType,
    recipe_name_to_linear_config,
)

#
# short form
#

config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)

#
# or, long form
#

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
# (both arguments stay in high precision, so casting is disabled)
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)

# ensure fast_accum is on to get fast kernels
gc_o = Float8GemmConfig(use_fast_accum=True)
gc_gi = Float8GemmConfig(use_fast_accum=True)
gc_gw = Float8GemmConfig(use_fast_accum=True)

config = Float8LinearConfig(
    cast_config_input=cc_i,
    cast_config_weight=cc_w,
    cast_config_grad_output=cc_go,
    cast_config_input_for_grad_weight=cc_i_gw,
    cast_config_weight_for_grad_input=cc_w_gi,
    cast_config_grad_output_for_grad_weight=cc_go_gw,
    gemm_config_output=gc_o,
    gemm_config_grad_input=gc_gi,
    gemm_config_grad_weight=gc_gw,
)
```
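
The long form above boils down to choosing a cast for each argument of each of the three gemms. The sketch below restates the recipe as plain Python data so the structure is easy to see. The dictionary layout and the `casts_needed` helper are hypothetical and are not torchao objects.

```python
# Illustrative stand-in: plain strings instead of torchao CastConfig objects.
# Keys are the three gemms of a linear layer; values map each gemm argument
# to the precision/scaling the recipe uses for it.
LW_RECIPE = {
    "output": {"input": "fp8_axiswise", "weight": "fp8_axiswise"},
    "grad_input": {"grad_output": "fp8_axiswise", "weight": "fp8_tensorwise"},
    "grad_weight": {"input": "high_precision", "grad_output": "high_precision"},
}


def casts_needed(recipe):
    """Collect the distinct float8 casts each tensor needs across all gemms."""
    needed = {}
    for gemm_args in recipe.values():
        for tensor, cast in gemm_args.items():
            casts = needed.setdefault(tensor, set())
            if cast != "high_precision":
                casts.add(cast)
    return needed


for tensor, casts in sorted(casts_needed(LW_RECIPE).items()):
    print(f"{tensor}: {sorted(casts)}")
```

Under this reading, `input` and `grad_output` are each cast to float8 once, while `weight` is cast twice: axiswise for the forward gemm and tensorwise for the `grad_input` gemm.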

# performance

Below we provide basic performance characteristics of axiswise scaling in general, and the all-axiswise and lw recipes.

## gemm performance of torch._scaled_mm

baseline: tensorwise scaling

```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True
    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000006     0.573115
1         True     1    512    512    512    0.000005    0.000007     0.659333
2         True     2   1024   1024   1024    0.000011    0.000010     1.080664
3         True     3   2048   2048   2048    0.000028    0.000017     1.596239
4         True     4   4096   4096   4096    0.000210    0.000082     2.551705
5         True     5   8192   8192   8192    0.001671    0.000680     2.457972
6         True     6  16384  16384  16384    0.015030    0.006498     2.313032
7         True     7  32768  32768  32768    0.103236    0.048097     2.146411
8        False     0    256    256    256    0.000004    0.000006     0.630061
9        False     1    512    512    512    0.000005    0.000007     0.767236
10       False     2   1024   1024   1024    0.000012    0.000008     1.391347
11       False     3   2048   2048   2048    0.000029    0.000020     1.457922
12       False     4   4096   4096   4096    0.000211    0.000101     2.100081
13       False     5   8192   8192   8192    0.001676    0.000788     2.128628
14       False     6  16384  16384  16384    0.014933    0.006351     2.351209
15       False     7  32768  32768  32768    0.103457    0.049498     2.090134                
```

experiment: axiswise-scaling

```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise

    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000004     0.966772
1         True     1    512    512    512    0.000005    0.000004     1.095791
2         True     2   1024   1024   1024    0.000011    0.000006     1.988363
3         True     3   2048   2048   2048    0.000027    0.000015     1.890065
4         True     4   4096   4096   4096    0.000210    0.000082     2.552356
5         True     5   8192   8192   8192    0.001674    0.001092     1.533132
6         True     6  16384  16384  16384    0.015114    0.008785     1.720480
7         True     7  32768  32768  32768    0.103286    0.071456     1.445439
8        False     0    256    256    256    0.000004    0.000004     0.899054
9        False     1    512    512    512    0.000005    0.000005     1.005340
10       False     2   1024   1024   1024    0.000011    0.000006     1.692868
11       False     3   2048   2048   2048    0.000028    0.000049     0.567655
12       False     4   4096   4096   4096    0.000210    0.000341     0.616193
13       False     5   8192   8192   8192    0.001678    0.002640     0.635541
14       False     6  16384  16384  16384    0.015051    0.021557     0.698212
15       False     7  32768  32768  32768    0.103497    0.169797     0.609533

```
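
As a rough way to read the two tables together, the following sketch divides the axiswise `fp8_time_s` by the tensorwise `fp8_time_s` for the larger `fast_accum=True` shapes. The numbers are copied from the tables above; the variable names are only for illustration.

```python
# (M = K = N, tensorwise fp8_time_s, axiswise fp8_time_s), fast_accum=True rows
# copied from the two benchmark tables.
rows = [
    (4096, 0.000082, 0.000082),
    (8192, 0.000680, 0.001092),
    (16384, 0.006498, 0.008785),
    (32768, 0.048097, 0.071456),
]

# A ratio above 1 means the axiswise gemm took longer than the tensorwise gemm.
slowdown = {n: axs / tw for n, tw, axs in rows}

for n, ratio in slowdown.items():
    print(f"{n:>6}: axiswise / tensorwise = {ratio:.2f}x")
```

For these rows the axiswise gemm matches tensorwise at 4096 and is roughly 1.35x to 1.6x slower from 8192 upward.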

## performance on microbenchmark of ln -> linear -> sigmoid

Note: for large square shapes, performance tends to rank fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe. The gap between fp8_dynamic_axiswise and tensorwise appears to come mostly from the axiswise gemm being slower than the tensorwise gemm.

```
> python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv

   fwd_M  fwd_K  fwd_N  bf16_gemm_s  fp8_gemm_s  fp8_axs_gemm_time_s      fp8_oh_dyn_limit  ... fp8_del_s fp8_dyn_axs_s  fp8_lw_s  fp8_dyn_sp  fp8_del_sp  fp8_dyn_axs_sp  fp8_lw_sp
0    256    256    256     0.000011    0.000018             0.000012   6.50457971014493e-6  ...  0.000043      0.000049  0.000030    0.465634    0.457907        0.398357   0.643088
1    512    512    512     0.000014    0.000020             0.000013   8.01831884057971e-6  ...  0.000047      0.000054  0.000034    0.489556    0.493467        0.432643   0.685842
2   1024   1024   1024     0.000033    0.000026             0.000017   1.40732753623188e-5  ...  0.000060      0.000063  0.000050    0.734123    0.741467        0.705941   0.891199
3   2048   2048   2048     0.000081    0.000055             0.000044   3.82931014492754e-5  ...  0.000147      0.000159  0.000142    0.815678    0.800811        0.739865   0.827441
4   4096   4096   4096     0.000632    0.000274             0.000247  0.000135172405797101  ...  0.000602      0.000622  0.000662    1.236320    1.261848        1.221755   1.147678
5   8192   8192   8192     0.005027    0.002216             0.003292  0.000522689623188406  ...  0.003665      0.004776  0.005720    1.432213    1.513035        1.161130   0.969448
6  16384  16384  16384     0.045113    0.018975             0.025706   0.00207275849275362  ...  0.024664      0.032254  0.038051    1.803456    1.883291        1.440118   1.220738
7  32768  32768  32768     0.312459    0.147255             0.214492   0.00827303397101449  ...  0.182645      0.240962  0.270973    1.696376    1.766307        1.338827   1.190552

```

## performance on torchtitan LLaMa 3 8B on 8 H100 GPUs, float8 compute only:

* baseline (bf16 + compile): 6,294 wps
* f8 all-tensorwise: 7,359 wps (1.17x vs baseline)
* f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise)
* LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline)

So `LW_AXISWISE_WITH_GW_HP` still needs performance work, which is left to future PRs.
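
The multipliers in the list above are each recipe's throughput divided by the bf16 baseline. A small sketch of that arithmetic, using the wps figures quoted above (the dictionary keys are informal labels):

```python
# Throughput figures (wps), copied from the list above.
wps = {
    "bf16 + compile": 6294,
    "f8 all-tensorwise": 7359,
    "f8 all-axiswise": 7135,
    "LW_AXISWISE_WITH_GW_HP": 6506,
}

baseline = wps["bf16 + compile"]
speedup = {name: value / baseline for name, value in wps.items()}

for name, ratio in speedup.items():
    print(f"{name}: {ratio:.2f}x vs baseline")

# How far the lw recipe trails all-tensorwise.
lw_vs_tensorwise = wps["LW_AXISWISE_WITH_GW_HP"] / wps["f8 all-tensorwise"]
print(f"lw / all-tensorwise: {lw_vs_tensorwise:.2f}")
```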

# accuracy

I did a very quick check that loss curves for torchtitan LLaMa 3 8B pretraining on 8 H100 GPUs look good for bf16/f8_tensorwise/f8_axiswise/f8_lw over 500 iterations. Longer accuracy verification is left for future work.

<img width="973" alt="Screenshot 2024-10-04 at 10 05 24 PM" src="https://github.com/user-attachments/assets/0d682183-41ef-4f04-992f-cd0d0fc8a65c">


vkuzo authored and jainapurva committed Oct 15, 2024
1 parent dd3bc3b commit 101d731
Showing 10 changed files with 566 additions and 429 deletions.
53 changes: 43 additions & 10 deletions benchmarks/float8/float8_roofline.py
@@ -70,6 +70,7 @@
     ScalingType,
     CastConfig,
 )
+from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName


 class LNLinearSigmoid(torch.nn.Module):
@@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
         else:
             # cache does not exist yet, create it
             cache = dict()
+    else:
+        cache = dict()
     key = f"{M},{K},{N},{fast_accum}"
     if key in cache:
         return cache[key]
@@ -153,13 +156,18 @@ def do_matmul(A, B):
         )
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

+    scale_a = torch.ones(M, 1, device=device)
+    scale_b = torch.ones(1, N, device=device)
+    fast_accum = True # for axiswise
+    f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s]
+        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
         with open(cache_filename, 'w') as f:
             json.dump(cache, f)

-    return bf16_time_s, f8_time_s
+    return bf16_time_s, f8_time_s, f8_axs_time_s

 def run(
     outfile: str,
@@ -231,13 +239,15 @@ def run(
     headers = [
         'fwd_M', 'fwd_K', 'fwd_N',
         # gemm microbenchmarks
-        'bf16_gemm_s', 'fp8_gemm_s',
+        'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s',
         # roofline memory overhead estimates
         'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
         'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
         # actual e2e measurements
-        'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
-        'fp8_dyn_speedup', 'fp8_del_speedup',
+        'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s',
+        # 'fp8_lw_s',
+        'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp',
+        # 'fp8_lw_sp',
     ]
     results = []

@@ -248,15 +258,18 @@ def run(
             break

         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
-            bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
-            bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
+            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
+            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
+            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            # for now, assume axiswise gemm is similar to tensorwise
+            fp8_axs_gemm_time_s = fp8_gemm_time_s

         fp8_mem_time_dyn_limit_s = \
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -291,23 +304,43 @@ def run(
             cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
             cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
         )
-        m_fp8_del = convert_to_float8_training(m_orig)
+        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
         m_fp8_del = torch.compile(m_fp8_del)
         fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)

+        # get the float8 dynamic axiswise scaling gpu kernel time
+        torch._dynamo.reset()
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
+        m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
+        fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+
+        # get the lw recipe scaling gpu kernel time
+        # TODO(future PR): enable below once basic performance issues
+        # are fixed
+        # torch._dynamo.reset()
+        # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
+        # m_fp8_lw = torch.compile(m_fp8_lw)
+        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
+
         results.append([
             M_val, K_val, N_val,
             # gemm microbenchmarks
-            bf16_time_val, fp8_gemm_time_s,
+            bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s,
             # roofline overhead estimates
             fp8_mem_time_dyn_limit_s,
             fp8_mem_time_dyn_nolimit_s,
             fp8_mem_time_del_limit_s,
             fp8_mem_time_del_nolimit_s,
             # e2e numbers
             bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
+            fp8_dyn_axs_time_actual_s,
+            # fp8_lw_time_actual_s,
             bf16_time_actual_s / fp8_dyn_time_actual_s,
             bf16_time_actual_s / fp8_del_time_actual_s,
+            bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
+            # bf16_time_actual_s / fp8_lw_time_actual_s,
         ])

     df = pd.DataFrame(results, columns=headers)
53 changes: 13 additions & 40 deletions benchmarks/float8/profile_linear_float8.py
@@ -27,12 +27,15 @@
     Float8LinearConfig,
     ScalingType,
     ScalingGranularity,
+    Float8LinearRecipeName,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
 )
+from torchao.testing.float8.test_utils import get_test_float8_linear_config
 from torch.profiler import profile, ProfilerActivity, record_function
 from utils import (
     kernel_name_to_category,
@@ -257,7 +260,7 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
-    scaling_granularity: str = "tensorwise",
+    recipe_name: Optional[str] = None,
     model_type: str = "linear",
     dtype_filter: str = "both",
     add_inductor_metadata_to_trace: bool = True,
@@ -269,47 +272,17 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
-    scaling_granularity = ScalingGranularity(scaling_granularity)

-    if scaling_type_input is ScalingType.STATIC:
-        cast_config_input=CastConfig(
-            scaling_type=scaling_type_input,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
+    if recipe_name is None:
+        config = get_test_float8_linear_config(
+            scaling_type_input,
+            scaling_type_weight,
+            scaling_type_grad_output,
+            emulate=False,
         )
-    else:
-        cast_config_input=CastConfig(
-            scaling_type=scaling_type_input,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_weight is ScalingType.STATIC:
-        cast_config_weight=CastConfig(
-            scaling_type=scaling_type_weight,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_weight=CastConfig(
-            scaling_type=scaling_type_weight,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_grad_output is ScalingType.STATIC:
-        cast_config_grad_output=CastConfig(
-            scaling_type=scaling_type_grad_output,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_grad_output=CastConfig(
-            scaling_type=scaling_type_grad_output,
-            scaling_granularity=scaling_granularity,
-        )
-
-    config = Float8LinearConfig(
-        cast_config_input=cast_config_input,
-        cast_config_weight=cast_config_weight,
-        cast_config_grad_output=cast_config_grad_output,
-    )
+    elif recipe_name is not None:
+        recipe_name = Float8LinearRecipeName(recipe_name)
+        config = recipe_name_to_linear_config(recipe_name)

     scaling_repr = "_".join(
         [