[float8nocompile]: Triton kernels for conversion to float8 dtypes for forward pass of Float8LinearNoCompile #1445
Changes from 1 commit
@@ -1 +1 @@
-kernels/
+kernels/autogen/
torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py (new file):

@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Triton kernels for scaling high precision tensors to float8.
"""

import torch

import triton
import triton.language as tl

from torchao.float8.float8_tensor import Float8Tensor, GemmInputRole, LinearMMConfig

# mapping from torch dtypes to the corresponding triton dtypes
FP8_DTYPE_MAP = {
    torch.int8: tl.int8,
    torch.int16: tl.int16,
    torch.int32: tl.int32,
    torch.int64: tl.int64,
    torch.float8_e4m3fn: tl.float8e4nv,
    torch.float8_e5m2: tl.float8e5,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
    torch.float32: tl.float32,
    torch.float64: tl.float64,
}

@triton.jit
def _triton_to_fp8(
    input_ptr,
    scale_out_ptr,
    tensor_out_ptr,
    fp8_dtype_min,
    fp8_dtype_max,
    n_elements,
    input_dtype: tl.constexpr,
    tensor_out_dtype: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    EPS: tl.constexpr,
):
    offs = tl.arange(0, BLOCK_SIZE)

    # get amax by looping over the tensor one block at a time
    amax = tl.zeros([1], dtype=tl.float64)
    for i in range(0, n_elements, BLOCK_SIZE):
        # `i` already advances by BLOCK_SIZE, so the block offsets are i + offs
        block_offs = i + offs
        block_mask = block_offs < n_elements
        # masked-off lanes load 0.0, so they cannot affect the amax
        vals = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(input_dtype)
        amax = tl.maximum(amax, tl.max(tl.abs(vals)))

    # compute scale, must be fp32
    scale = (fp8_dtype_max / tl.clamp(amax, min=EPS, max=float("inf"))).to(tl.float32)
    scale_offs = tl.arange(0, 1)
    tl.store(scale_out_ptr + scale_offs, scale)

    # perform conversion
    for i in range(0, n_elements, BLOCK_SIZE):
        block_offs = i + offs
        block_mask = block_offs < n_elements
        vals = tl.load(input_ptr + block_offs, mask=block_mask)
        vals = vals * scale
        fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(
            tensor_out_dtype
        )
        tl.store(tensor_out_ptr + block_offs, fp8_vals, mask=block_mask)

Inline review thread attached to the amax loop above:

Reviewer: I would have expected a single [...]

Author: Oh, I think I see the problem: I didn't update the grid dimensions when I changed the kernel implementation strategy. In this simple strategy there should only be 1 thread block / cell, which finds the global amax by loading one block at a time into SRAM in a loop, as seen here. My original strategy was to divide the tensor into some number of blocks, compute a local amax for each block in parallel, write those to a shared buffer of some sort, then have a second step compute the global amax from those local amaxes. However, I decided to try this simple approach first just to get the code functionally correct. Unfortunately a SEV caused me to lose access to my devgpu, but I will verify the fix tomorrow and then make it more parallelized.

Author: @vkuzo I went back to the drawing board and rewrote this using the original strategy I had in mind (described in my previous comment) and it works now. It's a bit more complicated and requires breaking the conversion process up into separate kernels, but it works and uses a higher degree of parallelism, so it should be reasonably performant. However, I am wondering if it would be more efficient for each block in the final [...] If time allows I'd like to benchmark/profile both approaches.

def triton_hp_tensor_to_float8_dynamic(
    hp_tensor: torch.Tensor,
    fp8_dtype: torch.dtype,
    linear_mm_config: LinearMMConfig,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:
    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
    tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
    grid = lambda meta: (triton.cdiv(hp_tensor.numel(), meta["BLOCK_SIZE"]),)

    tensor_out = torch.empty_like(hp_tensor, dtype=fp8_dtype, device=hp_tensor.device)
    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)

    fp8_dtype_min = torch.finfo(fp8_dtype).min
    fp8_dtype_max = torch.finfo(fp8_dtype).max

    _triton_to_fp8[grid](
        hp_tensor.flatten(),
        scale_out,
        tensor_out,
        fp8_dtype_min,
        fp8_dtype_max,
        num_elements,
        input_dtype=tl_input_dtype,
        tensor_out_dtype=tl_output_dtype,
        BLOCK_SIZE=8,  # TODO: tune
        EPS=1e-12,
    )

    return Float8Tensor(
        tensor_out,
        scale_out,
        orig_dtype=hp_tensor.dtype,
        linear_mm_config=linear_mm_config,
        gemm_input_role=gemm_input_role,
    )
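
The author's second reply in the review thread above describes the more parallel strategy that is not part of this commit: compute a per-block amax in parallel, reduce those to a global scale, then scale and cast in a second kernel. The sketch below only illustrates that idea; the kernel names (`_block_amax`, `_scale_and_cast`) and helper are hypothetical, a host-side `max()` stands in for a dedicated reduction kernel, and a contiguous input is assumed.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _block_amax(input_ptr, block_amaxes_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # step 1: each program computes the amax of its own block and writes it to a
    # per-block buffer, so no cross-block synchronization is needed here
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    vals = tl.load(input_ptr + offs, mask=mask, other=0.0)
    tl.store(block_amaxes_ptr + pid, tl.max(tl.abs(vals)).to(tl.float32))


@triton.jit
def _scale_and_cast(
    input_ptr, scale_ptr, out_ptr, fp8_min, fp8_max, n_elements,
    out_dtype: tl.constexpr, BLOCK_SIZE: tl.constexpr,
):
    # step 3: each program scales and casts its own block using the global scale
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    scale = tl.load(scale_ptr)
    vals = tl.load(input_ptr + offs, mask=mask) * scale
    fp8_vals = tl.clamp(vals, min=fp8_min, max=fp8_max).to(out_dtype)
    tl.store(out_ptr + offs, fp8_vals, mask=mask)


def to_fp8_e4m3_two_step(hp: torch.Tensor, block_size: int = 256, eps: float = 1e-12):
    """Tensorwise dynamic scaling to float8_e4m3fn using parallel per-block amaxes."""
    hp_flat = hp.contiguous().flatten()
    n = hp_flat.numel()
    n_blocks = triton.cdiv(n, block_size)
    finfo = torch.finfo(torch.float8_e4m3fn)

    block_amaxes = torch.empty(n_blocks, dtype=torch.float32, device=hp.device)
    _block_amax[(n_blocks,)](hp_flat, block_amaxes, n, BLOCK_SIZE=block_size)

    # step 2: reduce the per-block amaxes to a global scale (done on the host
    # here; a dedicated reduction kernel would avoid the extra host sync)
    scale = (finfo.max / block_amaxes.max().clamp(min=eps)).float().reshape(1)

    out = torch.empty(hp.shape, dtype=torch.float8_e4m3fn, device=hp.device)
    _scale_and_cast[(n_blocks,)](
        hp_flat, scale, out, finfo.min, finfo.max, n,
        out_dtype=tl.float8e4nv,  # triton dtype for torch.float8_e4m3fn
        BLOCK_SIZE=block_size,
    )
    return out, scale
```

The key difference from `_triton_to_fp8` above is that the grid covers the whole tensor, so blocks are processed in parallel rather than in a serial loop inside a single program, at the cost of an extra kernel launch and a reduction step.
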
Test for the Triton conversion kernel (new file):

@@ -0,0 +1,32 @@
import pytest
import torch
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    triton_hp_tensor_to_float8_dynamic,
)


def test_fp8_triton_hp_tensor_to_float8_dynamic():
    assert torch.cuda.is_available()
    device = "cuda"
    input_bf16 = torch.randn((4, 4), dtype=torch.bfloat16, device=device)
    x_bf16 = input_bf16.clone().detach().to(device)
    y_bf16 = input_bf16.clone().detach().to(device)

    # production implementation
    x_fp8 = hp_tensor_to_float8_dynamic(
        x_bf16,
        torch.float8_e4m3fn,
        LinearMMConfig(),
    )

    # float8nocompile triton implementation
    y_fp8 = triton_hp_tensor_to_float8_dynamic(
        y_bf16,
        torch.float8_e4m3fn,
        LinearMMConfig(),
    )

    assert torch.allclose(x_fp8._scale, y_fp8._scale, atol=1e-3, rtol=1e-3)
    assert torch.allclose(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)
Review comment on the amax computation in _triton_to_fp8 above:

Reviewer: If I'm reading this code right, the amax value will be calculated per a single cell in the grid. I would have expected the same amax to be used for all cells in the grid; there should be a synchronization step somewhere.
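
The author also mentions wanting to benchmark/profile both approaches if time allows. A minimal sketch of what that comparison could look like, using `triton.testing.do_bench`; the tensor shape is arbitrary, and the two entry points compared are the same conversion functions exercised by the test above.

```python
import torch
from triton.testing import do_bench

from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    triton_hp_tensor_to_float8_dynamic,
)

if __name__ == "__main__":
    x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

    # do_bench reports runtime in milliseconds
    eager_ms = do_bench(
        lambda: hp_tensor_to_float8_dynamic(x, torch.float8_e4m3fn, LinearMMConfig())
    )
    triton_ms = do_bench(
        lambda: triton_hp_tensor_to_float8_dynamic(x, torch.float8_e4m3fn, LinearMMConfig())
    )
    print(f"production (eager): {eager_ms:.3f} ms | float8nocompile triton: {triton_ms:.3f} ms")
```
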