[float8nocompile]: Triton kernels for conversion to float8 dtypes for forward pass of Float8LinearNoCompile #1445

Merged · 3 commits · Dec 19, 2024
Changes from 1 commit
2 changes: 1 addition & 1 deletion torchao/prototype/float8nocompile/.gitignore
@@ -1 +1 @@
-kernels/
+kernels/autogen/
113 changes: 113 additions & 0 deletions torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Triton kernels for scaling high precision tensors to float8.
"""

import torch

import triton
import triton.language as tl

from torchao.float8.float8_tensor import Float8Tensor, GemmInputRole, LinearMMConfig


FP8_DTYPE_MAP = {
torch.int8: tl.int8,
torch.int16: tl.int16,
torch.int32: tl.int32,
torch.int64: tl.int64,
torch.float8_e4m3fn: tl.float8e4nv,
torch.float8_e5m2: tl.float8e5,
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32,
torch.float64: tl.float64,
}


@triton.jit
def _triton_to_fp8(
input_ptr,
scale_out_ptr,
tensor_out_ptr,
fp8_dtype_min,
fp8_dtype_max,
n_elements,
input_dtype: tl.constexpr,
tensor_out_dtype: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
EPS: tl.constexpr,
):
offs = tl.arange(0, BLOCK_SIZE)

    # first pass: compute amax (max absolute value) over the full tensor
amax = tl.zeros([1], dtype=tl.float64)
Contributor:
If I'm reading this code right, the amax value will be calculated separately within each cell in the grid. I would have expected the same amax to be used for all cells in the grid; there should be a synchronization step somewhere.

for i in range(0, n_elements, BLOCK_SIZE):
        block_offs = i + offs
block_mask = block_offs < n_elements
vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
Contributor:
I would have expected a single tl.load per grid cell. If I'm understanding this code correctly, it's both launching a grid and iterating through every element of the tensor from each cell in the grid. I could be missing something.
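
For context, the conventional Triton pattern the reviewer is describing gives each program (grid cell) a single masked tl.load over just its own block; a minimal illustrative sketch, not code from this PR:

# each program handles exactly one BLOCK_SIZE chunk of the flattened tensor
pid = tl.program_id(axis=0)
block_offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
block_mask = block_offs < n_elements
vals = tl.load(input_ptr + block_offs, mask=block_mask)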

Contributor Author @danielvegamyhre · Dec 19, 2024:
Oh, I think I see the problem: I didn't update the grid dimensions when I changed the kernel implementation strategy. In this simple strategy there should only be 1 thread block / cell, which finds the global amax by loading one block at a time into SRAM in a loop, as seen here.

My original strategy was to divide the tensor into some number of blocks, compute a local amax for each block in parallel, write those to a shared buffer of some sort, then have a second step compute the global amax from those local amaxes. However, I decided to try this simple approach first to get the code functionally correct to start.

Unfortunately a SEV caused me to lose access to my devgpu, but I will verify the fix tomorrow and then make it more parallelized.
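
For reference, a minimal sketch of the launch under this single-thread-block strategy (illustrative only, not the verified fix): the host-side call in triton_hp_tensor_to_float8_dynamic would use a one-cell grid rather than the triton.cdiv-based grid.

# a single program walks the whole tensor one BLOCK_SIZE chunk at a time
grid = (1,)
_triton_to_fp8[grid](
    hp_tensor.flatten(),
    scale_out,
    tensor_out,
    fp8_dtype_min,
    fp8_dtype_max,
    num_elements,
    input_dtype=tl_input_dtype,
    tensor_out_dtype=tl_output_dtype,
    BLOCK_SIZE=8,  # TODO: tune
    EPS=1e-12,
)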

Contributor Author @danielvegamyhre · Dec 19, 2024:
@vkuzo I went back to the drawing board and rewrote this using the original strategy I had in mind (described in my previous comment), and it works now. It's a bit more complicated and requires breaking the conversion process into separate kernels, but it uses a higher degree of parallelism, so it should be reasonably performant.

However, I am wondering if it would be more efficient for each block in the final _to_fp8() kernel to locally compute the global scale from the previously computed block_amaxes, rather than have the separate _scale() kernel do it beforehand and share that scale value everywhere. My reasoning is that all instances of the _to_fp8() kernel will be waiting on the _scale() kernel to finish before they can launch, and the memory access latency of the _scale() kernel moving data back and forth between HBM <-> SRAM will be slower than each instance of _to_fp8() just computing the scale locally.

If time allows, I'd like to benchmark/profile both approaches.
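
A minimal sketch of the multi-kernel strategy described above. Only _scale(), _to_fp8(), and block_amaxes are named in the comment; the _block_amax name and all signatures here are illustrative assumptions, not the code that landed:

@triton.jit
def _block_amax(input_ptr, block_amaxes_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # each program computes the amax (max absolute value) of its own block
    pid = tl.program_id(axis=0)
    block_offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    block_mask = block_offs < n_elements
    vals = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0)
    tl.store(block_amaxes_ptr + pid, tl.max(tl.abs(vals)))


@triton.jit
def _scale(block_amaxes_ptr, scale_out_ptr, fp8_dtype_max, n_blocks, NUM_BLOCKS_POW2: tl.constexpr, EPS: tl.constexpr):
    # a single program reduces the per-block amaxes to the global scale
    offs = tl.arange(0, NUM_BLOCKS_POW2)
    mask = offs < n_blocks
    amax = tl.max(tl.load(block_amaxes_ptr + offs, mask=mask, other=0.0))
    scale = (fp8_dtype_max / tl.clamp(amax, min=EPS, max=float("inf"))).to(tl.float32)
    tl.store(scale_out_ptr, scale)


@triton.jit
def _to_fp8(input_ptr, scale_ptr, out_ptr, fp8_dtype_min, fp8_dtype_max, n_elements, out_dtype: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # each program scales and casts its own block using the shared global scale
    pid = tl.program_id(axis=0)
    block_offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    block_mask = block_offs < n_elements
    scale = tl.load(scale_ptr)
    vals = tl.load(input_ptr + block_offs, mask=block_mask)
    fp8_vals = tl.clamp(vals * scale, min=fp8_dtype_min, max=fp8_dtype_max).to(out_dtype)
    tl.store(out_ptr + block_offs, fp8_vals, mask=block_mask)

In this sketch, _block_amax and _to_fp8 would be launched over a (triton.cdiv(n_elements, BLOCK_SIZE),) grid, with _scale launched as a single program in between; that serialization point is exactly what the comment above weighs against having each _to_fp8 program reduce block_amaxes itself.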

amax = tl.maximum(amax, tl.max(tl.abs(vals)))

    # compute scale (must be fp32): map amax to the max representable fp8 value; EPS guards against division by zero for an all-zero tensor
scale = (fp8_dtype_max / tl.clamp(amax, min=EPS, max=float("inf"))).to(tl.float32)
scale_offs = tl.arange(0, 1)
tl.store(scale_out_ptr + scale_offs, scale)

# perform conversion
for i in range(0, n_elements, BLOCK_SIZE):
        block_offs = i + offs
block_mask = block_offs < n_elements
vals = tl.load(input_ptr + block_offs, mask=block_mask)
vals = vals * scale
fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(
tensor_out_dtype
)
tl.store(tensor_out_ptr + block_offs, fp8_vals, mask=block_mask)


def triton_hp_tensor_to_float8_dynamic(
hp_tensor: torch.Tensor,
fp8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:

num_elements = hp_tensor.numel()
tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
grid = lambda meta: (triton.cdiv(hp_tensor.numel(), meta["BLOCK_SIZE"]),)

tensor_out = torch.empty_like(hp_tensor, dtype=fp8_dtype, device=hp_tensor.device)
scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)

fp8_dtype_min = torch.finfo(fp8_dtype).min
fp8_dtype_max = torch.finfo(fp8_dtype).max

_triton_to_fp8[grid](
hp_tensor.flatten(),
scale_out,
tensor_out,
fp8_dtype_min,
fp8_dtype_max,
num_elements,
input_dtype=tl_input_dtype,
tensor_out_dtype=tl_output_dtype,
BLOCK_SIZE=8, # TODO: tune
EPS=1e-12,
)

return Float8Tensor(
tensor_out,
scale_out,
orig_dtype=hp_tensor.dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)
@@ -0,0 +1,32 @@
import pytest
import torch
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
triton_hp_tensor_to_float8_dynamic,
)


def test_fp8_triton_hp_tensor_to_float8_dynamic():
    if not torch.cuda.is_available():
        pytest.skip("test requires a CUDA device")
    device = "cuda"
input_bf16 = torch.randn((4, 4), dtype=torch.bfloat16, device=device)
x_bf16 = input_bf16.clone().detach().to(device)
y_bf16 = input_bf16.clone().detach().to(device)

# production implementation
x_fp8 = hp_tensor_to_float8_dynamic(
x_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
)

# float8nocompile triton implementation
y_fp8 = triton_hp_tensor_to_float8_dynamic(
y_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
)

assert torch.allclose(x_fp8._scale, y_fp8._scale, atol=1e-3, rtol=1e-3)
assert torch.allclose(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)