
Commit

refactor to make kernel algo configurable; refactor unit tests to test both algos
danielvegamyhre committed Dec 20, 2024
1 parent 40165e8 commit faf855f
Showing 2 changed files with 168 additions and 26 deletions.
179 changes: 155 additions & 24 deletions torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -7,6 +7,7 @@
"""
Triton kernels for scaling high precision tensors to float8.
"""
from enum import Enum

import torch
import triton
@@ -29,16 +30,28 @@
    torch.float64: tl.float64,
}


class KernelAlgorithm(Enum):
"""Enum for FP8 conversion strategy."""

# use atomic max to compute global amax between blocks
ATOMIC_MAX = "atomic_max"

# reduce shared buffer containing local block amaxes to find global amax
REDUCTION = "reduction"


kernel_configs = [
triton.Config({"BLOCK_SIZE": 128}, num_warps=1),
triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
]


# --- atomic max version of kernel ---
@triton.autotune(configs=kernel_configs, key=["input_size"])
@triton.jit
def _block_amax(
def _block_amax_atomic(
    input_ptr,
    amax_ptr,
    num_elements,
@@ -58,7 +71,7 @@ def _block_amax(

@triton.autotune(configs=kernel_configs, key=["input_size"])
@triton.jit
def _to_fp8(
def _to_fp8_atomic(
    input_ptr,
    scale_out_ptr,
    amax_ptr,
@@ -76,6 +89,7 @@ def _to_fp8(
    scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
        tl.float32
    )

    # only one program needs to store the scale
    block_id = tl.program_id(axis=0)
    if block_id == 0:
@@ -94,11 +108,85 @@ def _to_fp8(
    tl.store(out_ptr + block_offs, fp8_vals, mask=mask)


# --- reduction version of kernel ---
@triton.jit
def _block_amax_reduction(
    input_ptr,
    block_amaxes_ptr,
    num_elements,
    input_dtype: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    EPS: tl.constexpr,
):
    # compute local amax for each block
    block_id = tl.program_id(axis=0)
    block_start = block_id * BLOCK_SIZE
    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
    block_mask = block_offs < num_elements
    vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
    block_amax = tl.max(tl.abs(vals), axis=0)
    tl.store(block_amaxes_ptr + block_id, block_amax)


@triton.jit
def _fp8_scale_reduction(
    block_amaxes_ptr,
    scale_out_ptr,
    num_elements,
    fp8_dtype_max,
    BLOCK_SIZE: tl.constexpr,
    EPS: tl.constexpr,
):
    # calculate global amax across all blocks
    global_amax = tl.zeros([1], dtype=tl.float64)
    num_blocks = tl.cdiv(num_elements, BLOCK_SIZE)
    for i in range(num_blocks):
        block_max = tl.load(block_amaxes_ptr + i)
        global_amax = tl.maximum(global_amax, block_max)

    # compute scale, must be fp32
    scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
        tl.float32
    )
    scale_off = tl.arange(0, 1)
    tl.store(scale_out_ptr + scale_off, scale)


@triton.jit
def _to_fp8_reduction(
    input_ptr,
    scale_ptr,
    out_ptr,
    num_elements,
    fp8_dtype_min,
    fp8_dtype_max,
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    EPS: tl.constexpr,
):
    # load previously computed scale
    scale = tl.load(scale_ptr)

    # load block of input tensor
    block_id = tl.program_id(axis=0)
    block_start = block_id * BLOCK_SIZE
    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
    mask = block_offs < num_elements
    vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype)

    # perform conversion
    vals = vals * scale
    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
    tl.store(out_ptr + block_offs, fp8_vals, mask=mask)


def triton_hp_tensor_to_float8_dynamic(
    hp_tensor: torch.Tensor,
    fp8_dtype: torch.dtype,
    linear_mm_config: LinearMMConfig,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
    algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> Float8Tensor:
    assert hp_tensor.is_contiguous(), "tensor must be contiguous"

@@ -114,35 +202,78 @@ def triton_hp_tensor_to_float8_dynamic(

    # allocate memory for computed scale, local block maxes, and output fp8 tensor
    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
    global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device)

    fp8_output = torch.empty_like(
        flattened_input, dtype=fp8_dtype, device=hp_tensor.device
    )

    grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)

    # compute global amax to be used for scaling
    _block_amax[grid](
        flattened_input,
        global_amax,
        num_elements,
        input_dtype=tl_input_dtype,
        EPS=EPS,
    )
    if algo == KernelAlgorithm.ATOMIC_MAX:
        global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device)
        # compute global amax to be used for scaling
        _block_amax_atomic[grid](
            flattened_input,
            global_amax,
            num_elements,
            input_dtype=tl_input_dtype,
            EPS=EPS,
        )

    # perform conversion
    _to_fp8[grid](
        flattened_input,
        scale_out,
        global_amax,
        fp8_output,
        num_elements,
        fp8_dtype_min,
        fp8_dtype_max,
        input_dtype=tl_input_dtype,
        output_dtype=tl_output_dtype,
        EPS=EPS,
    )
        # perform conversion and store scale for use in Float8Tensor
        _to_fp8_atomic[grid](
            flattened_input,
            scale_out,
            global_amax,
            fp8_output,
            num_elements,
            fp8_dtype_min,
            fp8_dtype_max,
            input_dtype=tl_input_dtype,
            output_dtype=tl_output_dtype,
            EPS=EPS,
        )
    elif algo == KernelAlgorithm.REDUCTION:
        max_block_size = 512
        BLOCK_SIZE = min(max_block_size, num_elements)
        block_amaxes = torch.zeros(
            (num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
        )
        # compute local amax for each block
        _block_amax_reduction[grid](
            flattened_input,
            block_amaxes,
            num_elements,
            input_dtype=tl_input_dtype,
            BLOCK_SIZE=BLOCK_SIZE,
            EPS=EPS,
        )

        # calculate global amax across all blocks and use it to compute scale
        _fp8_scale_reduction[(1, 1, 1)](
            block_amaxes,
            scale_out,
            num_elements,
            fp8_dtype_max,
            BLOCK_SIZE=BLOCK_SIZE,
            EPS=EPS,
        )

        # perform conversion
        _to_fp8_reduction[grid](
            flattened_input,
            scale_out,
            fp8_output,
            num_elements,
            fp8_dtype_min,
            fp8_dtype_max,
            input_dtype=tl_input_dtype,
            output_dtype=tl_output_dtype,
            BLOCK_SIZE=BLOCK_SIZE,
            EPS=EPS,
        )
    else:
        raise ValueError(f"Unsupported kernel algorithm: {algo}")

    return Float8Tensor(
        fp8_output.reshape(orig_shape),
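Not part of the diff: a minimal usage sketch of the new algo parameter, mirroring the call added in the test file below. The input shape, device, and variable names are arbitrary illustrations; a CUDA device is assumed, as in the test.

import torch

from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    KernelAlgorithm,
    triton_hp_tensor_to_float8_dynamic,
)

# arbitrary high-precision input tensor (illustrative shape)
x = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")

# default strategy: blocks combine their local amaxes into a global amax via atomic max
fp8_atomic = triton_hp_tensor_to_float8_dynamic(
    x, torch.float8_e4m3fn, LinearMMConfig(), algo=KernelAlgorithm.ATOMIC_MAX
)

# alternative strategy: per-block amaxes are written to a buffer, then reduced to a global amax
fp8_reduction = triton_hp_tensor_to_float8_dynamic(
    x, torch.float8_e4m3fn, LinearMMConfig(), algo=KernelAlgorithm.REDUCTION
)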
@@ -3,14 +3,24 @@
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    KernelAlgorithm,
    triton_hp_tensor_to_float8_dynamic,
)


def test_fp8_triton_hp_tensor_to_float8_dynamic():
@pytest.mark.parametrize(
"algo", [KernelAlgorithm.ATOMIC_MAX, KernelAlgorithm.REDUCTION]
)
@pytest.mark.parametrize(
"input_shape",
[(32, 32), (512, 512), (4096, 4096)],
)
def test_fp8_triton_hp_tensor_to_float8_dynamic(
    algo: KernelAlgorithm, input_shape: tuple[int, int]
):
    assert torch.cuda.is_available()
    device = "cuda"
    input_bf16 = torch.randn((4, 4), dtype=torch.bfloat16, device=device)
    input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
    x_bf16 = input_bf16.clone().detach().to(device)
    y_bf16 = input_bf16.clone().detach().to(device)

@@ -26,6 +36,7 @@ def test_fp8_triton_hp_tensor_to_float8_dynamic():
        y_bf16,
        torch.float8_e4m3fn,
        LinearMMConfig(),
        algo=algo,
    )

    def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
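For orientation only (not part of the commit): a rough pure-PyTorch restatement of the tensorwise scaling both kernel algorithms implement — compute the global amax, derive a float32 scale from the fp8 dtype's max value, then clamp and cast. The function name fp8_tensorwise_reference and the eps default are illustrative stand-ins; the library's actual EPS constant is defined outside the lines shown in this diff.

import torch

def fp8_tensorwise_reference(
    x: torch.Tensor,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    eps: float = 1e-12,  # stand-in for the module's EPS constant
):
    finfo = torch.finfo(fp8_dtype)
    amax = x.abs().max()  # global amax over the whole tensor
    scale = (finfo.max / torch.clamp(amax, min=eps)).to(torch.float32)
    fp8 = torch.clamp(x.to(torch.float32) * scale, finfo.min, finfo.max).to(fp8_dtype)
    return fp8, scale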
