diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
index 0c3929bfa..86b2c866f 100644
--- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -9,7 +9,6 @@
 """
 
 import torch
-
 import triton
 import triton.language as tl
 
@@ -30,11 +29,18 @@
     torch.float64: tl.float64,
 }
 
 
+kernel_configs = [
+    triton.Config({"BLOCK_SIZE": 128}, num_warps=1),
+    triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
+    triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
+]
+
+@triton.autotune(configs=kernel_configs, key=["num_elements"])
 @triton.jit
 def _block_amax(
     input_ptr,
-    block_amaxes_ptr,
+    amax_ptr,
     num_elements,
     input_dtype: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
@@ -47,37 +53,15 @@ def _block_amax(
     block_mask = block_offs < num_elements
     vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
     block_amax = tl.max(tl.abs(vals), axis=0)
-    tl.store(block_amaxes_ptr + block_id, block_amax)
-
-
-@triton.jit
-def _fp8_scale(
-    block_amaxes_ptr,
-    scale_out_ptr,
-    num_elements,
-    fp8_dtype_max,
-    BLOCK_SIZE: tl.constexpr,
-    EPS: tl.constexpr,
-):
-    # calculate global amax across all blocks
-    global_amax = tl.zeros([1], dtype=tl.float64)
-    num_blocks = tl.cdiv(num_elements, BLOCK_SIZE)
-    for i in range(num_blocks):
-        block_max = tl.load(block_amaxes_ptr + i)
-        global_amax = tl.maximum(global_amax, block_max)
-
-    # compute scale, must be fp32
-    scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
-        tl.float32
-    )
-    scale_off = tl.arange(0, 1)
-    tl.store(scale_out_ptr + scale_off, scale)
+    tl.atomic_max(amax_ptr, block_amax)
 
 
+@triton.autotune(configs=kernel_configs, key=["num_elements"])
 @triton.jit
 def _to_fp8(
     input_ptr,
-    scale_ptr,
+    scale_out_ptr,
+    amax_ptr,
     out_ptr,
     num_elements,
     fp8_dtype_min,
@@ -87,11 +71,18 @@
     BLOCK_SIZE: tl.constexpr,
     EPS: tl.constexpr,
 ):
-    # load previously computed scale
-    scale = tl.load(scale_ptr)
+    # compute scale, must be fp32
+    global_amax = tl.load(amax_ptr)
+    scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
+        tl.float32
+    )
+    # only one program needs to store the scale
+    block_id = tl.program_id(axis=0)
+    if block_id == 0:
+        scale_offs = tl.arange(0, 1)
+        tl.store(scale_out_ptr + scale_offs, scale)
 
     # load block of input tensor
-    block_id = tl.program_id(axis=0)
     block_start = block_id * BLOCK_SIZE
     block_offs = block_start + tl.arange(0, BLOCK_SIZE)
     mask = block_offs < num_elements
@@ -109,11 +100,8 @@ def triton_hp_tensor_to_float8_dynamic(
     linear_mm_config: LinearMMConfig,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
-    assert hp_tensor.is_contiguous(), "tensor must be contiguous"
 
-    BLOCK_SIZE = 8  # TODO(danielvegamyhre): tune this for perf
-
     num_elements = hp_tensor.numel()
     orig_shape = hp_tensor.shape
     flattened_input = hp_tensor.flatten()
 
@@ -126,31 +114,19 @@
 
-    # allocate memory for computed scale, local block maxes, and output fp8 tensor
+    # allocate memory for computed scale, global amax, and output fp8 tensor
     scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
-    block_amaxes = torch.zeros(
-        (num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
-    )
+    global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device)
     fp8_output = torch.empty_like(
         flattened_input, dtype=fp8_dtype, device=hp_tensor.device
     )
 
-    # compute local amax for each block
     grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
+
+    # compute global amax to be used for scaling
     _block_amax[grid](
         flattened_input,
-        block_amaxes,
+        global_amax,
         num_elements,
         input_dtype=tl_input_dtype,
-        BLOCK_SIZE=BLOCK_SIZE,
-        EPS=EPS,
-    )
-
-    # calculate global amax across all blocks and use it to compute scale
-    _fp8_scale[(1, 1, 1)](
-        block_amaxes,
-        scale_out,
-        num_elements,
-        fp8_dtype_max,
-        BLOCK_SIZE=BLOCK_SIZE,
         EPS=EPS,
     )
 
@@ -158,13 +134,13 @@
     _to_fp8[grid](
         flattened_input,
         scale_out,
+        global_amax,
         fp8_output,
         num_elements,
         fp8_dtype_min,
         fp8_dtype_max,
         input_dtype=tl_input_dtype,
         output_dtype=tl_output_dtype,
-        BLOCK_SIZE=BLOCK_SIZE,
         EPS=EPS,
     )
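
Note (reviewer sketch, not part of the patch): the fused kernels above implement standard tensorwise dynamic float8 scaling — a global amax reduction via tl.atomic_max, then scale = fp8_dtype_max / clamp(amax, min=EPS) computed in fp32, then the input block is scaled, clamped to the fp8 range, and cast. A plain PyTorch reference of that math is useful for checking the Triton path. The snippet below is a minimal sketch under assumed names: reference_fp8_tensorwise, torch.float8_e4m3fn as the target dtype, and EPS = 1e-12 are illustrative choices, not taken from this file.

import torch

EPS = 1e-12  # assumed epsilon for this sketch; the kernel uses its own EPS constant


def reference_fp8_tensorwise(
    hp_tensor: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn
):
    """Tensorwise dynamic fp8 scaling in plain PyTorch, mirroring the kernel math."""
    fp8_max = torch.finfo(fp8_dtype).max
    fp8_min = torch.finfo(fp8_dtype).min
    # global amax over the whole tensor, in fp32
    amax = hp_tensor.abs().max().to(torch.float32)
    # scale must be fp32: fp8_max / clamp(amax, min=EPS)
    scale = fp8_max / torch.clamp(amax, min=EPS)
    # scale the input, clamp to the fp8 representable range, then cast
    fp8_data = (hp_tensor.to(torch.float32) * scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
    return fp8_data, scale

Comparing fp8_data and the scalar scale against the Triton output on a few shapes is usually enough to catch scaling or masking regressions introduced by the autotuned block sizes.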