From 5002a80070db84d087f44539b4bbb3452e84a974 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Thu, 19 Dec 2024 09:48:08 -0800
Subject: [PATCH] float8nocompile: two pass fp8 conversion kernel

---
 .../kernels/fp8_dynamic_tensorwise.py      | 144 +++++++++++++-----
 .../kernels/fp8_dynamic_tensorwise_test.py |  14 +-
 2 files changed, 116 insertions(+), 42 deletions(-)

diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
index fc3beb9f5..f2a3ee97d 100644
--- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -15,6 +15,7 @@
 from torchao.float8.float8_tensor import Float8Tensor, GemmInputRole, LinearMMConfig
 
+EPS = 1e-12
 
 FP8_DTYPE_MAP = {
     torch.int8: tl.int8,
@@ -31,46 +32,75 @@
 
 
 @triton.jit
-def _triton_to_fp8(
+def _block_amax(
     input_ptr,
-    scale_out_ptr,
-    tensor_out_ptr,
-    fp8_dtype_min,
-    fp8_dtype_max,
-    n_elements,
+    block_amaxes_ptr,
+    num_elements,
     input_dtype: tl.constexpr,
-    tensor_out_dtype: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     EPS: tl.constexpr,
 ):
-    offs = tl.arange(0, BLOCK_SIZE)
+    # compute local amax for each block
+    block_id = tl.program_id(axis=0)
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    block_mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
+    block_amax = tl.max(tl.abs(vals), axis=0)
+    tl.store(block_amaxes_ptr + block_id, block_amax)
 
-    # get amax
-    amax = tl.zeros([1], dtype=tl.float64)
-    for i in range(0, n_elements, BLOCK_SIZE):
-        block_offs = (i * BLOCK_SIZE) + offs
-        block_mask = block_offs < n_elements
-        vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-        amax = tl.maximum(amax, tl.max(tl.abs(vals)))
-    import pdb
-    pdb.set_trace()
 
+@triton.jit
+def _fp8_scale(
+    block_amaxes_ptr,
+    scale_out_ptr,
+    num_elements,
+    fp8_dtype_max,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # calculate global amax across all blocks
+    global_amax = tl.zeros([1], dtype=tl.float64)
+    num_blocks = (num_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
+    for i in range(num_blocks):
+        block_max = tl.load(block_amaxes_ptr + i)
+        global_amax = tl.maximum(global_amax, block_max)
 
     # compute scale, must be fp32
-    scale = (fp8_dtype_max / tl.clamp(amax, min=EPS, max=float("inf"))).to(tl.float32)
-    scale_offs = tl.arange(0, 1)
-    tl.store(scale_out_ptr + scale_offs, scale)
+    scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
+        tl.float32
+    )
+    scale_off = tl.arange(0, 1)
+    tl.store(scale_out_ptr + scale_off, scale)
+
+
+@triton.jit
+def _to_fp8(
+    input_ptr,
+    scale_ptr,
+    out_ptr,
+    num_elements,
+    fp8_dtype_min,
+    fp8_dtype_max,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # load previously computed scale
+    scale = tl.load(scale_ptr)
+
+    # load block of input tensor
+    block_id = tl.program_id(axis=0)
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype)
 
     # perform conversion
-    for i in range(0, n_elements, BLOCK_SIZE):
-        block_offs = (i * BLOCK_SIZE) + offs
-        block_mask = block_offs < n_elements
-        vals = tl.load(input_ptr + block_offs, mask=block_mask)
-        vals = vals * scale
-        fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(
-            tensor_out_dtype
-        )
-        tl.store(tensor_out_ptr + block_offs, fp8_vals, mask=block_mask)
+    vals = vals * scale
+    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+    tl.store(out_ptr + block_offs, fp8_vals, mask=mask)
 
 
 def triton_hp_tensor_to_float8_dynamic(
@@ -80,32 +110,64 @@ def triton_hp_tensor_to_float8_dynamic(
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
 
+    BLOCK_SIZE = 8
+    num_elements = hp_tensor.numel()
+    orig_shape = hp_tensor.shape
+    flattened_input = hp_tensor.flatten()
+
     tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
     tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
 
-    grid = lambda meta: (triton.cdiv(hp_tensor.numel(), meta["BLOCK_SIZE"]),)
-
-    tensor_out = torch.empty_like(hp_tensor, dtype=fp8_dtype, device=hp_tensor.device)
-    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
-
     fp8_dtype_min = torch.finfo(fp8_dtype).min
     fp8_dtype_max = torch.finfo(fp8_dtype).max
 
-    _triton_to_fp8[grid](
-        hp_tensor.flatten(),
+    # allocate memory for computed scale, local block maxes, and output fp8 tensor;
+    # one amax slot per block (ceil division, so a partial trailing block gets a slot)
+    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
+    block_amaxes = torch.zeros(
+        (triton.cdiv(num_elements, BLOCK_SIZE),),
+        dtype=torch.float32,
+        device=hp_tensor.device,
+    )
+    fp8_output = torch.empty_like(
+        flattened_input, dtype=fp8_dtype, device=hp_tensor.device
+    )
+
+    # compute local amax for each block
+    grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
+    _block_amax[grid](
+        flattened_input,
+        block_amaxes,
+        num_elements,
+        input_dtype=tl_input_dtype,
+        BLOCK_SIZE=BLOCK_SIZE,
+        EPS=EPS,
+    )
+
+    # calculate global amax across all blocks and use it to compute scale
+    _fp8_scale[(1, 1, 1)](
+        block_amaxes,
         scale_out,
-        tensor_out,
-        fp8_dtype_min,
+        num_elements,
         fp8_dtype_max,
+        BLOCK_SIZE=BLOCK_SIZE,
+        EPS=EPS,
+    )
+
+    # perform conversion
+    _to_fp8[grid](
+        flattened_input,
+        scale_out,
+        fp8_output,
         num_elements,
+        fp8_dtype_min,
+        fp8_dtype_max,
         input_dtype=tl_input_dtype,
-        tensor_out_dtype=tl_output_dtype,
-        BLOCK_SIZE=8,  # TODO: tune
-        EPS=1e-12,
+        output_dtype=tl_output_dtype,
+        BLOCK_SIZE=BLOCK_SIZE,
+        EPS=EPS,
     )
 
     return Float8Tensor(
-        tensor_out,
+        fp8_output.reshape(orig_shape),
         scale_out,
         orig_dtype=hp_tensor.dtype,
         linear_mm_config=linear_mm_config,
diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
index 1e795e746..f6ebb3e90 100644
--- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
@@ -28,5 +28,17 @@ def test_fp8_triton_hp_tensor_to_float8_dynamic():
         LinearMMConfig(),
     )
 
+    def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
+        if tensor1.shape != tensor2.shape:
+            raise ValueError("Tensors must have the same shape for comparison.")
+        if tensor1.dtype != tensor2.dtype:
+            raise ValueError("Tensors must have the same dtype for comparison.")
+
+        # convert fp8 tensors to a higher precision (e.g., float32) for comparison,
+        # since the fp8 ops needed by allclose are not supported
+        tensor1_fp32 = tensor1.to(torch.float32)
+        tensor2_fp32 = tensor2.to(torch.float32)
+        return torch.allclose(tensor1_fp32, tensor2_fp32, atol=atol, rtol=rtol)
+
     assert torch.allclose(x_fp8._scale, y_fp8._scale, atol=1e-3, rtol=1e-3)
-    assert torch.allclose(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)
+    assert allclose_fp8(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)
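
Usage note: below is a minimal sketch of how the two-pass conversion path can be exercised end to end, mirroring the updated test. The eager reference path (hp_tensor_to_float8_dynamic from torchao.float8.float8_scaling_utils), the CUDA device, and the input shape are illustrative assumptions and are not part of this patch.

# Hypothetical usage sketch (assumptions noted above): compares the Triton
# two-pass kernel against the eager tensorwise conversion, as the test does.
import torch

from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic  # assumed reference path
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    triton_hp_tensor_to_float8_dynamic,
)

input_bf16 = torch.randn((4, 32), dtype=torch.bfloat16, device="cuda")

# eager reference conversion
x_fp8 = hp_tensor_to_float8_dynamic(input_bf16, torch.float8_e4m3fn, LinearMMConfig())

# two-pass Triton conversion added by this patch
y_fp8 = triton_hp_tensor_to_float8_dynamic(input_bf16, torch.float8_e4m3fn, LinearMMConfig())

# scales match directly; fp8 data is compared in float32, as in allclose_fp8 above
assert torch.allclose(x_fp8._scale, y_fp8._scale, atol=1e-3, rtol=1e-3)
assert torch.allclose(
    x_fp8._data.to(torch.float32), y_fp8._data.to(torch.float32), atol=1e-3, rtol=1e-3
)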