Skip to content

Commit

Permalink
refactor float8nocompile kernel so autotune is easily usable
Browse files Browse the repository at this point in the history
  • Loading branch information
danielvegamyhre committed Dec 20, 2024
1 parent 3bac905 commit 40165e8
Showing 1 changed file with 27 additions and 51 deletions.
78 changes: 27 additions & 51 deletions torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
"""

import torch

import triton
import triton.language as tl

Expand All @@ -30,11 +29,18 @@
torch.float64: tl.float64,
}

kernel_configs = [
triton.Config({"BLOCK_SIZE": 128}, num_warps=1),
triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
]


@triton.autotune(configs=kernel_configs, key=["input_size"])
@triton.jit
def _block_amax(
input_ptr,
block_amaxes_ptr,
amax_ptr,
num_elements,
input_dtype: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
Expand All @@ -47,37 +53,15 @@ def _block_amax(
block_mask = block_offs < num_elements
vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
block_amax = tl.max(tl.abs(vals), axis=0)
tl.store(block_amaxes_ptr + block_id, block_amax)


@triton.jit
def _fp8_scale(
block_amaxes_ptr,
scale_out_ptr,
num_elements,
fp8_dtype_max,
BLOCK_SIZE: tl.constexpr,
EPS: tl.constexpr,
):
# calculate global amax across all blocks
global_amax = tl.zeros([1], dtype=tl.float64)
num_blocks = tl.cdiv(num_elements, BLOCK_SIZE)
for i in range(num_blocks):
block_max = tl.load(block_amaxes_ptr + i)
global_amax = tl.maximum(global_amax, block_max)

# compute scale, must be fp32
scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
tl.float32
)
scale_off = tl.arange(0, 1)
tl.store(scale_out_ptr + scale_off, scale)
tl.atomic_max(amax_ptr, block_amax)


@triton.autotune(configs=kernel_configs, key=["input_size"])
@triton.jit
def _to_fp8(
input_ptr,
scale_ptr,
scale_out_ptr,
amax_ptr,
out_ptr,
num_elements,
fp8_dtype_min,
Expand All @@ -87,11 +71,18 @@ def _to_fp8(
BLOCK_SIZE: tl.constexpr,
EPS: tl.constexpr,
):
# load previously computed scale
scale = tl.load(scale_ptr)
# compute scale, must be fp32
global_amax = tl.load(amax_ptr)
scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
tl.float32
)
# only one program needs to store the scale
block_id = tl.program_id(axis=0)
if block_id == 0:
scale_offs = tl.arange(0, 1)
tl.store(scale_out_ptr + scale_offs, scale)

# load block of input tensor
block_id = tl.program_id(axis=0)
block_start = block_id * BLOCK_SIZE
block_offs = block_start + tl.arange(0, BLOCK_SIZE)
mask = block_offs < num_elements
Expand All @@ -109,11 +100,8 @@ def triton_hp_tensor_to_float8_dynamic(
linear_mm_config: LinearMMConfig,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:

assert hp_tensor.is_contiguous(), "tensor must be contiguous"

BLOCK_SIZE = 8 # TODO(danielvegamyhre): tune this for perf

num_elements = hp_tensor.numel()
orig_shape = hp_tensor.shape
flattened_input = hp_tensor.flatten()
Expand All @@ -126,45 +114,33 @@ def triton_hp_tensor_to_float8_dynamic(

# allocate memory for computed scale, local block maxes, and output fp8 tensor
scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
block_amaxes = torch.zeros(
(num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
)
global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device)
fp8_output = torch.empty_like(
flattened_input, dtype=fp8_dtype, device=hp_tensor.device
)

# compute local amax for each block
grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)

# compute global amax to be used for scaling
_block_amax[grid](
flattened_input,
block_amaxes,
global_amax,
num_elements,
input_dtype=tl_input_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)

# calculate global amax across all blocks and use it to compute scale
_fp8_scale[(1, 1, 1)](
block_amaxes,
scale_out,
num_elements,
fp8_dtype_max,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)

# perform conversion
_to_fp8[grid](
flattened_input,
scale_out,
global_amax,
fp8_output,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)

Expand Down

0 comments on commit 40165e8

Please sign in to comment.