From 18de9e04399f12375cd9d5c185c4b2d3a9ebacb4 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 18 Dec 2024 15:58:07 -0800
Subject: [PATCH 1/3] float8nocompile: triton kernel for conversion to fp8e4m3

---
 torchao/prototype/float8nocompile/.gitignore |   2 +-
 .../kernels/fp8_dynamic_tensorwise.py        | 113 ++++++++++++++++++
 .../kernels/fp8_dynamic_tensorwise_test.py   |  32 +++++
 3 files changed, 146 insertions(+), 1 deletion(-)
 create mode 100644 torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
 create mode 100644 torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py

diff --git a/torchao/prototype/float8nocompile/.gitignore b/torchao/prototype/float8nocompile/.gitignore
index 174f6de301..77d8c1ac39 100644
--- a/torchao/prototype/float8nocompile/.gitignore
+++ b/torchao/prototype/float8nocompile/.gitignore
@@ -1 +1 @@
-kernels/
+kernels/autogen/
diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
new file mode 100644
index 0000000000..fc3beb9f53
--- /dev/null
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Triton kernels for scaling high precision tensors to float8.
+"""
+
+import torch
+
+import triton
+import triton.language as tl
+
+from torchao.float8.float8_tensor import Float8Tensor, GemmInputRole, LinearMMConfig
+
+
+FP8_DTYPE_MAP = {
+    torch.int8: tl.int8,
+    torch.int16: tl.int16,
+    torch.int32: tl.int32,
+    torch.int64: tl.int64,
+    torch.float8_e4m3fn: tl.float8e4nv,
+    torch.float8_e5m2: tl.float8e5,
+    torch.float16: tl.float16,
+    torch.bfloat16: tl.bfloat16,
+    torch.float32: tl.float32,
+    torch.float64: tl.float64,
+}
+
+
+@triton.jit
+def _triton_to_fp8(
+    input_ptr,
+    scale_out_ptr,
+    tensor_out_ptr,
+    fp8_dtype_min,
+    fp8_dtype_max,
+    n_elements,
+    input_dtype: tl.constexpr,
+    tensor_out_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    offs = tl.arange(0, BLOCK_SIZE)
+
+    # get amax
+    amax = tl.zeros([1], dtype=tl.float64)
+    for i in range(0, n_elements, BLOCK_SIZE):
+        block_offs = (i * BLOCK_SIZE) + offs
+        block_mask = block_offs < n_elements
+        vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
+        amax = tl.maximum(amax, tl.max(tl.abs(vals)))
+    import pdb
+
+    pdb.set_trace()
+
+    # compute scale, must be fp32
+    scale = (fp8_dtype_max / tl.clamp(amax, min=EPS, max=float("inf"))).to(tl.float32)
+    scale_offs = tl.arange(0, 1)
+    tl.store(scale_out_ptr + scale_offs, scale)
+
+    # perform conversion
+    for i in range(0, n_elements, BLOCK_SIZE):
+        block_offs = (i * BLOCK_SIZE) + offs
+        block_mask = block_offs < n_elements
+        vals = tl.load(input_ptr + block_offs, mask=block_mask)
+        vals = vals * scale
+        fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(
+            tensor_out_dtype
+        )
+        tl.store(tensor_out_ptr + block_offs, fp8_vals, mask=block_mask)
+
+
+def triton_hp_tensor_to_float8_dynamic(
+    hp_tensor: torch.Tensor,
+    fp8_dtype: torch.dtype,
+    linear_mm_config: LinearMMConfig,
+    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+) -> Float8Tensor:
+
+    num_elements = hp_tensor.numel()
+    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
+    tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
+    grid = lambda meta: (triton.cdiv(hp_tensor.numel(), meta["BLOCK_SIZE"]),)
+
+    tensor_out = torch.empty_like(hp_tensor, dtype=fp8_dtype, device=hp_tensor.device)
+    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
+
+    fp8_dtype_min = torch.finfo(fp8_dtype).min
+    fp8_dtype_max = torch.finfo(fp8_dtype).max
+
+    _triton_to_fp8[grid](
+        hp_tensor.flatten(),
+        scale_out,
+        tensor_out,
+        fp8_dtype_min,
+        fp8_dtype_max,
+        num_elements,
+        input_dtype=tl_input_dtype,
+        tensor_out_dtype=tl_output_dtype,
+        BLOCK_SIZE=8,  # TODO: tune
+        EPS=1e-12,
+    )
+
+    return Float8Tensor(
+        tensor_out,
+        scale_out,
+        orig_dtype=hp_tensor.dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
new file mode 100644
index 0000000000..1e795e746e
--- /dev/null
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
@@ -0,0 +1,32 @@
+import pytest
+import torch
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
+from torchao.float8.float8_tensor import LinearMMConfig
+from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
+    triton_hp_tensor_to_float8_dynamic,
+)
+
+
+def test_fp8_triton_hp_tensor_to_float8_dynamic():
+    assert torch.cuda.is_available()
+    device = "cuda"
+    input_bf16 = torch.randn((4, 4), dtype=torch.bfloat16, device=device)
+    x_bf16 = input_bf16.clone().detach().to(device)
+    y_bf16 = input_bf16.clone().detach().to(device)
+
+    # production implementation
+    x_fp8 = hp_tensor_to_float8_dynamic(
+        x_bf16,
+        torch.float8_e4m3fn,
+        LinearMMConfig(),
+    )
+
+    # float8nocompile triton implementation
+    y_fp8 = triton_hp_tensor_to_float8_dynamic(
+        y_bf16,
+        torch.float8_e4m3fn,
+        LinearMMConfig(),
+    )
+
+    assert torch.allclose(x_fp8._scale, y_fp8._scale, atol=1e-3, rtol=1e-3)
+    assert torch.allclose(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)

From 867538ecd8f2093ae4ed442c20a1a64738b0283e Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Thu, 19 Dec 2024 09:48:08 -0800
Subject: [PATCH 2/3] float8nocompile: two pass fp8 conversion kernel

---
 .../kernels/fp8_dynamic_tensorwise.py        | 144 +++++++++++++-----
 .../kernels/fp8_dynamic_tensorwise_test.py   |  14 +-
 2 files changed, 116 insertions(+), 42 deletions(-)

diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
index fc3beb9f53..f2a3ee97da 100644
--- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -15,6 +15,7 @@
 
 from torchao.float8.float8_tensor import Float8Tensor, GemmInputRole, LinearMMConfig
 
+EPS = 1e-12
 
 FP8_DTYPE_MAP = {
     torch.int8: tl.int8,
@@ -31,46 +32,75 @@
 
 
 @triton.jit
-def _triton_to_fp8(
+def _block_amax(
     input_ptr,
-    scale_out_ptr,
-    tensor_out_ptr,
-    fp8_dtype_min,
-    fp8_dtype_max,
-    n_elements,
+    block_amaxes_ptr,
+    num_elements,
     input_dtype: tl.constexpr,
-    tensor_out_dtype: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     EPS: tl.constexpr,
 ):
-    offs = tl.arange(0, BLOCK_SIZE)
+    # compute local amax for each block
+    block_id = tl.program_id(axis=0)
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    block_mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
+    block_amax = tl.max(tl.abs(vals), axis=0)
+    tl.store(block_amaxes_ptr + block_id, block_amax)
 
-    # get amax
-    amax = tl.zeros([1], dtype=tl.float64)
-    for i in range(0, n_elements, BLOCK_SIZE):
-        block_offs = (i * BLOCK_SIZE) + offs
-        block_mask = block_offs < n_elements
-        vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-        amax = tl.maximum(amax, tl.max(tl.abs(vals)))
-    import pdb
 
-    pdb.set_trace()
+@triton.jit
+def _fp8_scale(
+    block_amaxes_ptr,
+    scale_out_ptr,
+    num_elements,
+    fp8_dtype_max,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # calculate global amax across all blocks
+    global_amax = tl.zeros([1], dtype=tl.float64)
+    num_blocks = (num_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
+    for i in range(num_blocks):
+        block_max = tl.load(block_amaxes_ptr + i)
+        global_amax = tl.maximum(global_amax, block_max)
 
     # compute scale, must be fp32
-    scale = (fp8_dtype_max / tl.clamp(amax, min=EPS, max=float("inf"))).to(tl.float32)
-    scale_offs = tl.arange(0, 1)
-    tl.store(scale_out_ptr + scale_offs, scale)
+    scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
+        tl.float32
+    )
+    scale_off = tl.arange(0, 1)
+    tl.store(scale_out_ptr + scale_off, scale)
+
+
+@triton.jit
+def _to_fp8(
+    input_ptr,
+    scale_ptr,
+    out_ptr,
+    num_elements,
+    fp8_dtype_min,
+    fp8_dtype_max,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # load previously computed scale
+    scale = tl.load(scale_ptr)
+
+    # load block of input tensor
+    block_id = tl.program_id(axis=0)
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype)
 
     # perform conversion
-    for i in range(0, n_elements, BLOCK_SIZE):
-        block_offs = (i * BLOCK_SIZE) + offs
-        block_mask = block_offs < n_elements
-        vals = tl.load(input_ptr + block_offs, mask=block_mask)
-        vals = vals * scale
-        fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(
-            tensor_out_dtype
-        )
-        tl.store(tensor_out_ptr + block_offs, fp8_vals, mask=block_mask)
+    vals = vals * scale
+    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+    tl.store(out_ptr + block_offs, fp8_vals, mask=mask)
 
 
 def triton_hp_tensor_to_float8_dynamic(
@@ -80,32 +110,64 @@ def triton_hp_tensor_to_float8_dynamic(
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
 
+    BLOCK_SIZE = 8
+
     num_elements = hp_tensor.numel()
+    orig_shape = hp_tensor.shape
+    flattened_input = hp_tensor.flatten()
+
     tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
     tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
-    grid = lambda meta: (triton.cdiv(hp_tensor.numel(), meta["BLOCK_SIZE"]),)
-
-    tensor_out = torch.empty_like(hp_tensor, dtype=fp8_dtype, device=hp_tensor.device)
-    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
 
     fp8_dtype_min = torch.finfo(fp8_dtype).min
     fp8_dtype_max = torch.finfo(fp8_dtype).max
 
-    _triton_to_fp8[grid](
-        hp_tensor.flatten(),
+    # allocate memory for computed scale, local block maxes, and output fp8 tensor
+    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
+    block_amaxes = torch.zeros(
+        (num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
+    )
+    fp8_output = torch.empty_like(
+        flattened_input, dtype=fp8_dtype, device=hp_tensor.device
+    )
+
+    # compute local amax for each block
+    grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
+    _block_amax[grid](
+        flattened_input,
+        block_amaxes,
+        num_elements,
+        input_dtype=tl_input_dtype,
+        BLOCK_SIZE=BLOCK_SIZE,
+        EPS=EPS,
+    )
+
+    # calculate global amax across all blocks and use it to compute scale
+    _fp8_scale[(1, 1, 1)](
+        block_amaxes,
         scale_out,
-        tensor_out,
-        fp8_dtype_min,
+        num_elements,
         fp8_dtype_max,
+        BLOCK_SIZE=BLOCK_SIZE,
+        EPS=EPS,
+    )
+
+    # perform conversion
+    _to_fp8[grid](
+        flattened_input,
+        scale_out,
+        fp8_output,
         num_elements,
+        fp8_dtype_min,
+        fp8_dtype_max,
         input_dtype=tl_input_dtype,
-        tensor_out_dtype=tl_output_dtype,
-        BLOCK_SIZE=8,  # TODO: tune
-        EPS=1e-12,
+        output_dtype=tl_output_dtype,
+        BLOCK_SIZE=BLOCK_SIZE,
+        EPS=EPS,
     )
 
     return Float8Tensor(
-        tensor_out,
+        fp8_output.reshape(orig_shape),
         scale_out,
         orig_dtype=hp_tensor.dtype,
         linear_mm_config=linear_mm_config,
diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
index 1e795e746e..979fb347a0 100644
--- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
@@ -28,5 +28,17 @@ def test_fp8_triton_hp_tensor_to_float8_dynamic():
         LinearMMConfig(),
     )
 
+    def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
+        # convert fp8 tensors to a higher precision (e.g., float32) for comparison
+        # since torch.allclose does not support fp8 tensors
+        if tensor1.shape != tensor2.shape:
+            raise ValueError("Tensors must have the same shape for comparison.")
+        if tensor1.dtype != tensor2.dtype:
+            raise ValueError("Tensors must have the same dtype for comparison.")
+
+        tensor1_fp32 = tensor1.to(torch.float32)
+        tensor2_fp32 = tensor2.to(torch.float32)
+        return torch.allclose(tensor1_fp32, tensor2_fp32, atol=atol, rtol=rtol)
+
     assert torch.allclose(x_fp8._scale, y_fp8._scale, atol=1e-3, rtol=1e-3)
-    assert torch.allclose(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)
+    assert allclose_fp8(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)

From 7c47b03001ffbdb4042cd9d14de581afb60c76a8 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Thu, 19 Dec 2024 14:21:49 -0800
Subject: [PATCH 3/3] address comments

---
 .../float8nocompile/kernels/fp8_dynamic_tensorwise.py | 6 ++++--
 .../kernels/fp8_dynamic_tensorwise_test.py            | 8 ++++++++
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
index f2a3ee97da..0c3929bfa0 100644
--- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -61,7 +61,7 @@ def _fp8_scale(
 ):
     # calculate global amax across all blocks
     global_amax = tl.zeros([1], dtype=tl.float64)
-    num_blocks = (num_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
+    num_blocks = tl.cdiv(num_elements, BLOCK_SIZE)
     for i in range(num_blocks):
         block_max = tl.load(block_amaxes_ptr + i)
         global_amax = tl.maximum(global_amax, block_max)
@@ -110,7 +110,9 @@ def triton_hp_tensor_to_float8_dynamic(
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
 
-    BLOCK_SIZE = 8
+    assert hp_tensor.is_contiguous(), "tensor must be contiguous"
+
+    BLOCK_SIZE = 8  # TODO(danielvegamyhre): tune this for perf
 
     num_elements = hp_tensor.numel()
     orig_shape = hp_tensor.shape
diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
index 979fb347a0..bc10f269cd 100644
--- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
+++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py
@@ -42,3 +42,11 @@ def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
 
     assert torch.allclose(x_fp8._scale, y_fp8._scale, atol=1e-3, rtol=1e-3)
     assert allclose_fp8(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)
+
+    # assert that error is raised when input tensor is not contiguous
+    with pytest.raises(AssertionError, match="tensor must be contiguous"):
+        triton_hp_tensor_to_float8_dynamic(
+            y_bf16.t(),  # transpose so tensor memory layout is no longer contiguous
+            torch.float8_e4m3fn,
+            LinearMMConfig(),
+        )
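
For readers less familiar with tensorwise dynamic float8 scaling, the three kernels in the final version of this patch compute the same quantities as the eager-mode PyTorch sketch below. This is illustrative only: the helper name is hypothetical, the eps value mirrors the EPS constant above, and the production reference path in torchao remains hp_tensor_to_float8_dynamic.

    import torch

    def reference_scale_to_fp8(
        hp_tensor: torch.Tensor,
        fp8_dtype: torch.dtype = torch.float8_e4m3fn,
        eps: float = 1e-12,
    ):
        # pass 1: global amax (the Triton version computes per-block maxes in
        # _block_amax, then _fp8_scale reduces them to a single amax)
        amax = hp_tensor.abs().max().to(torch.float64)

        # the scale maps the observed dynamic range onto the fp8 representable
        # range; clamping by eps keeps an all-zero tensor from producing an
        # infinite scale
        scale = (torch.finfo(fp8_dtype).max / torch.clamp(amax, min=eps)).to(
            torch.float32
        )

        # pass 2: scale, saturate to the fp8 range, and cast (what _to_fp8 does
        # one block at a time)
        fp8_data = (
            (hp_tensor.to(torch.float32) * scale)
            .clamp(min=torch.finfo(fp8_dtype).min, max=torch.finfo(fp8_dtype).max)
            .to(fp8_dtype)
        )
        return fp8_data, scale

Splitting the work into two passes is what allows the conversion to be parallelized: the scale depends on a global reduction over the whole tensor, so the cast cannot begin until the amax is known. The first commit computed the amax and performed the cast inside one kernel; the final version uses three launches, with _block_amax and _to_fp8 running one program per block and _fp8_scale as a single small program that reduces the per-block maxes into the scale.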