From faf855f8e8c354700459c97a08a4456cd4b17c09 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 20 Dec 2024 14:30:39 -0800 Subject: [PATCH] refactor to make kernel algo configurable; refactor unit tests to test both algos --- .../kernels/fp8_dynamic_tensorwise.py | 179 +++++++++++++++--- .../kernels/fp8_dynamic_tensorwise_test.py | 15 +- 2 files changed, 168 insertions(+), 26 deletions(-) diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py index 86b2c866f..c0605be1c 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py @@ -7,6 +7,7 @@ """ Triton kernels for scaling high precision tensors to float8. """ +from enum import Enum import torch import triton @@ -29,6 +30,17 @@ torch.float64: tl.float64, } + +class KernelAlgorithm(Enum): + """Enum for FP8 conversion strategy.""" + + # use atomic max to compute global amax between blocks + ATOMIC_MAX = "atomic_max" + + # reduce shared buffer containing local block amaxes to find global amax + REDUCTION = "reduction" + + kernel_configs = [ triton.Config({"BLOCK_SIZE": 128}, num_warps=1), triton.Config({"BLOCK_SIZE": 256}, num_warps=2), @@ -36,9 +48,10 @@ ] +# --- atomic max version of kernel --- @triton.autotune(configs=kernel_configs, key=["input_size"]) @triton.jit -def _block_amax( +def _block_amax_atomic( input_ptr, amax_ptr, num_elements, @@ -58,7 +71,7 @@ def _block_amax( @triton.autotune(configs=kernel_configs, key=["input_size"]) @triton.jit -def _to_fp8( +def _to_fp8_atomic( input_ptr, scale_out_ptr, amax_ptr, @@ -76,6 +89,7 @@ def _to_fp8( scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to( tl.float32 ) + # only one program needs to store the scale block_id = tl.program_id(axis=0) if block_id == 0: @@ -94,11 +108,85 @@ def _to_fp8( tl.store(out_ptr + block_offs, fp8_vals, mask=mask) +# --- reduction version of kernel --- +@triton.jit +def _block_amax_reduction( + input_ptr, + block_amaxes_ptr, + num_elements, + input_dtype: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + EPS: tl.constexpr, +): + # compute local amax for each block + block_id = tl.program_id(axis=0) + block_start = block_id * BLOCK_SIZE + block_offs = block_start + tl.arange(0, BLOCK_SIZE) + block_mask = block_offs < num_elements + vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype) + block_amax = tl.max(tl.abs(vals), axis=0) + tl.store(block_amaxes_ptr + block_id, block_amax) + + +@triton.jit +def _fp8_scale_reduction( + block_amaxes_ptr, + scale_out_ptr, + num_elements, + fp8_dtype_max, + BLOCK_SIZE: tl.constexpr, + EPS: tl.constexpr, +): + # calculate global amax across all blocks + global_amax = tl.zeros([1], dtype=tl.float64) + num_blocks = tl.cdiv(num_elements, BLOCK_SIZE) + for i in range(num_blocks): + block_max = tl.load(block_amaxes_ptr + i) + global_amax = tl.maximum(global_amax, block_max) + + # compute scale, must be fp32 + scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to( + tl.float32 + ) + scale_off = tl.arange(0, 1) + tl.store(scale_out_ptr + scale_off, scale) + + +@triton.jit +def _to_fp8_reduction( + input_ptr, + scale_ptr, + out_ptr, + num_elements, + fp8_dtype_min, + fp8_dtype_max, + input_dtype: tl.constexpr, + output_dtype: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + EPS: tl.constexpr, +): + # load previously computed scale + scale = tl.load(scale_ptr) + + # load block of 
input tensor + block_id = tl.program_id(axis=0) + block_start = block_id * BLOCK_SIZE + block_offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = block_offs < num_elements + vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype) + + # perform conversion + vals = vals * scale + fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype) + tl.store(out_ptr + block_offs, fp8_vals, mask=mask) + + def triton_hp_tensor_to_float8_dynamic( hp_tensor: torch.Tensor, fp8_dtype: torch.dtype, linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, ) -> Float8Tensor: assert hp_tensor.is_contiguous(), "tensor must be contiguous" @@ -114,35 +202,78 @@ def triton_hp_tensor_to_float8_dynamic( # allocate memory for computed scale, local block maxes, and output fp8 tensor scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device) - global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device) + fp8_output = torch.empty_like( flattened_input, dtype=fp8_dtype, device=hp_tensor.device ) grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),) - # compute global amax to be used for scaling - _block_amax[grid]( - flattened_input, - global_amax, - num_elements, - input_dtype=tl_input_dtype, - EPS=EPS, - ) + if algo == KernelAlgorithm.ATOMIC_MAX: + global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device) + # compute global amax to be used for scaling + _block_amax_atomic[grid]( + flattened_input, + global_amax, + num_elements, + input_dtype=tl_input_dtype, + EPS=EPS, + ) - # perform conversion - _to_fp8[grid]( - flattened_input, - scale_out, - global_amax, - fp8_output, - num_elements, - fp8_dtype_min, - fp8_dtype_max, - input_dtype=tl_input_dtype, - output_dtype=tl_output_dtype, - EPS=EPS, - ) + # perform conversion and store scale for use in Float8Tensor + _to_fp8_atomic[grid]( + flattened_input, + scale_out, + global_amax, + fp8_output, + num_elements, + fp8_dtype_min, + fp8_dtype_max, + input_dtype=tl_input_dtype, + output_dtype=tl_output_dtype, + EPS=EPS, + ) + elif algo == KernelAlgorithm.REDUCTION: + max_block_size = 512 + BLOCK_SIZE = min(max_block_size, num_elements) + block_amaxes = torch.zeros( + (num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device + ) + # compute local amax for each block + _block_amax_reduction[grid]( + flattened_input, + block_amaxes, + num_elements, + input_dtype=tl_input_dtype, + BLOCK_SIZE=BLOCK_SIZE, + EPS=EPS, + ) + + # calculate global amax across all blocks and use it to compute scale + _fp8_scale_reduction[(1, 1, 1)]( + block_amaxes, + scale_out, + num_elements, + fp8_dtype_max, + BLOCK_SIZE=BLOCK_SIZE, + EPS=EPS, + ) + + # perform conversion + _to_fp8_reduction[grid]( + flattened_input, + scale_out, + fp8_output, + num_elements, + fp8_dtype_min, + fp8_dtype_max, + input_dtype=tl_input_dtype, + output_dtype=tl_output_dtype, + BLOCK_SIZE=BLOCK_SIZE, + EPS=EPS, + ) + else: + raise ValueError(f"Unsupported kernel algorithm: {algo}") return Float8Tensor( fp8_output.reshape(orig_shape), diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py index bc10f269c..a58748a3e 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py @@ -3,14 +3,24 @@ from 
torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic from torchao.float8.float8_tensor import LinearMMConfig from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import ( + KernelAlgorithm, triton_hp_tensor_to_float8_dynamic, ) -def test_fp8_triton_hp_tensor_to_float8_dynamic(): +@pytest.mark.parametrize( + "algo", [KernelAlgorithm.ATOMIC_MAX, KernelAlgorithm.REDUCTION] +) +@pytest.mark.parametrize( + "input_shape", + [(32, 32), (512, 512), (4096, 4096)], +) +def test_fp8_triton_hp_tensor_to_float8_dynamic( + algo: KernelAlgorithm, input_shape: tuple[int, int] +): assert torch.cuda.is_available() device = "cuda" - input_bf16 = torch.randn((4, 4), dtype=torch.bfloat16, device=device) + input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device) x_bf16 = input_bf16.clone().detach().to(device) y_bf16 = input_bf16.clone().detach().to(device) @@ -26,6 +36,7 @@ def test_fp8_triton_hp_tensor_to_float8_dynamic(): y_bf16, torch.float8_e4m3fn, LinearMMConfig(), + algo=algo, ) def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
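
Below is an illustrative usage sketch of the new `algo` parameter, showing how a caller might select between the two conversion strategies and check that they agree. It is a sketch under assumptions, not part of the patch: it assumes a CUDA device with Triton available, and that the `Float8Tensor` wrapper exposes its raw fp8 bits and scale as `_data` and `_scale` (as the existing tests in this repo do).

    import torch

    from torchao.float8.float8_tensor import LinearMMConfig
    from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
        KernelAlgorithm,
        triton_hp_tensor_to_float8_dynamic,
    )

    x = torch.randn((512, 512), dtype=torch.bfloat16, device="cuda")

    # default strategy: global amax computed via atomic max across blocks
    y_atomic = triton_hp_tensor_to_float8_dynamic(
        x,
        torch.float8_e4m3fn,
        LinearMMConfig(),
        algo=KernelAlgorithm.ATOMIC_MAX,
    )

    # alternative strategy: per-block amaxes reduced to a global amax in a second kernel
    y_reduction = triton_hp_tensor_to_float8_dynamic(
        x,
        torch.float8_e4m3fn,
        LinearMMConfig(),
        algo=KernelAlgorithm.REDUCTION,
    )

    # both strategies should produce the same scale and the same fp8 values
    assert torch.allclose(y_atomic._scale, y_reduction._scale)
    assert torch.allclose(
        y_atomic._data.to(torch.float32), y_reduction._data.to(torch.float32)
    )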
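
For reference, here is a simplified plain-PyTorch sketch of the tensorwise scaling math that both kernel variants implement in parallel (amax, then scale, then saturating cast). It is a sketch under assumptions: the `EPS` value is assumed to match the constant the kernels receive, and intermediate dtypes are simplified to fp32.

    import torch

    EPS = 1e-12  # assumed to match the EPS constant passed to the Triton kernels

    def reference_to_fp8(hp_tensor: torch.Tensor, fp8_dtype: torch.dtype):
        finfo = torch.finfo(fp8_dtype)
        # global amax over the tensor: what _block_amax_atomic, or
        # _block_amax_reduction + _fp8_scale_reduction, compute cooperatively across blocks
        amax = hp_tensor.abs().max().to(torch.float32)
        # scale must be fp32; clamp amax away from zero to avoid division by zero
        scale = finfo.max / torch.clamp(amax, min=EPS)
        # scale, saturate to the fp8 representable range, then cast
        fp8_data = (
            (hp_tensor.float() * scale).clamp(min=finfo.min, max=finfo.max).to(fp8_dtype)
        )
        return fp8_data, scale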