Commit 5002a80

float8nocompile: two pass fp8 conversion kernel
Daniel Vega-Myhre committed Dec 19, 2024
1 parent 18de9e0 commit 5002a80
Showing 2 changed files with 116 additions and 42 deletions.
144 changes: 103 additions & 41 deletions torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -15,6 +15,7 @@

from torchao.float8.float8_tensor import Float8Tensor, GemmInputRole, LinearMMConfig

+EPS = 1e-12

FP8_DTYPE_MAP = {
torch.int8: tl.int8,
@@ -31,46 +31,75 @@


@triton.jit
-def _triton_to_fp8(
-    input_ptr,
-    scale_out_ptr,
-    tensor_out_ptr,
-    fp8_dtype_min,
-    fp8_dtype_max,
-    n_elements,
-    input_dtype: tl.constexpr,
-    tensor_out_dtype: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    EPS: tl.constexpr,
-):
-    offs = tl.arange(0, BLOCK_SIZE)
-
-    # get amax
-    amax = tl.zeros([1], dtype=tl.float64)
-    for i in range(0, n_elements, BLOCK_SIZE):
-        block_offs = (i * BLOCK_SIZE) + offs
-        block_mask = block_offs < n_elements
-        vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-        amax = tl.maximum(amax, tl.max(tl.abs(vals)))
-    import pdb
-
-    pdb.set_trace()
-
-    # compute scale, must be fp32
-    scale = (fp8_dtype_max / tl.clamp(amax, min=EPS, max=float("inf"))).to(tl.float32)
-    scale_offs = tl.arange(0, 1)
-    tl.store(scale_out_ptr + scale_offs, scale)
-
-    # perform conversion
-    for i in range(0, n_elements, BLOCK_SIZE):
-        block_offs = (i * BLOCK_SIZE) + offs
-        block_mask = block_offs < n_elements
-        vals = tl.load(input_ptr + block_offs, mask=block_mask)
-        vals = vals * scale
-        fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(
-            tensor_out_dtype
-        )
-        tl.store(tensor_out_ptr + block_offs, fp8_vals, mask=block_mask)
+def _block_amax(
+    input_ptr,
+    block_amaxes_ptr,
+    num_elements,
+    input_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # compute local amax for each block
+    block_id = tl.program_id(axis=0)
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    block_mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
+    block_amax = tl.max(tl.abs(vals), axis=0)
+    tl.store(block_amaxes_ptr + block_id, block_amax)
+
+
+@triton.jit
+def _fp8_scale(
+    block_amaxes_ptr,
+    scale_out_ptr,
+    num_elements,
+    fp8_dtype_max,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # calculate global amax across all blocks
+    global_amax = tl.zeros([1], dtype=tl.float64)
+    num_blocks = (num_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
+    for i in range(num_blocks):
+        block_max = tl.load(block_amaxes_ptr + i)
+        global_amax = tl.maximum(global_amax, block_max)
+
+    # compute scale, must be fp32
+    scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
+        tl.float32
+    )
+    scale_off = tl.arange(0, 1)
+    tl.store(scale_out_ptr + scale_off, scale)
+
+
+@triton.jit
+def _to_fp8(
+    input_ptr,
+    scale_ptr,
+    out_ptr,
+    num_elements,
+    fp8_dtype_min,
+    fp8_dtype_max,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # load previously computed scale
+    scale = tl.load(scale_ptr)
+
+    # load block of input tensor
+    block_id = tl.program_id(axis=0)
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype)
+
+    # perform conversion
+    vals = vals * scale
+    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+    tl.store(out_ptr + block_offs, fp8_vals, mask=mask)
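
The two-pass structure splits the old single kernel into three: _block_amax writes one partial amax per block, _fp8_scale runs as a single program that serially reduces the partial maxes into a global amax and derives the tensorwise scale, and _to_fp8 applies the scale, clamps to the fp8 range, and casts. As a rough reference only, the sketch below reproduces the same tensorwise scaling scheme in eager PyTorch; the helper name reference_scale_and_cast and the choice of torch.float8_e4m3fn are illustrative assumptions, not part of this file.

import torch

EPS = 1e-12  # same epsilon the kernels use to keep the scale finite when amax is zero


def reference_scale_and_cast(hp_tensor, fp8_dtype=torch.float8_e4m3fn):
    # pass 1: amax of the whole tensor (the kernels compute this as a max over per-block maxes)
    amax = hp_tensor.abs().amax().to(torch.float64)
    # the scale maps the largest magnitude onto the fp8 dtype's largest representable value
    scale = (torch.finfo(fp8_dtype).max / torch.clamp(amax, min=EPS)).to(torch.float32)
    # pass 2: scale, clamp into the representable fp8 range, then cast
    scaled = hp_tensor.to(torch.float32) * scale
    clamped = torch.clamp(
        scaled, min=torch.finfo(fp8_dtype).min, max=torch.finfo(fp8_dtype).max
    )
    return clamped.to(fp8_dtype), scale

Launching _fp8_scale on a (1, 1, 1) grid keeps the global reduction simple at the cost of a serial loop over all blocks; with BLOCK_SIZE fixed at 8 in the wrapper below, that loop runs num_elements // 8 iterations.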


def triton_hp_tensor_to_float8_dynamic(
Expand All @@ -80,32 +110,64 @@ def triton_hp_tensor_to_float8_dynamic(
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:

+    BLOCK_SIZE = 8
+
+    num_elements = hp_tensor.numel()
+    orig_shape = hp_tensor.shape
+    flattened_input = hp_tensor.flatten()
+
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
    tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
-    grid = lambda meta: (triton.cdiv(hp_tensor.numel(), meta["BLOCK_SIZE"]),)
-
-    tensor_out = torch.empty_like(hp_tensor, dtype=fp8_dtype, device=hp_tensor.device)
-    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)

    fp8_dtype_min = torch.finfo(fp8_dtype).min
    fp8_dtype_max = torch.finfo(fp8_dtype).max

-    _triton_to_fp8[grid](
-        hp_tensor.flatten(),
-        scale_out,
-        tensor_out,
-        fp8_dtype_min,
-        fp8_dtype_max,
-        input_dtype=tl_input_dtype,
-        tensor_out_dtype=tl_output_dtype,
-        BLOCK_SIZE=8,  # TODO: tune
-        EPS=1e-12,
-    )
+    # allocate memory for computed scale, local block maxes, and output fp8 tensor
+    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
+    block_amaxes = torch.zeros(
+        (num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
+    )
+    fp8_output = torch.empty_like(
+        flattened_input, dtype=fp8_dtype, device=hp_tensor.device
+    )
+
+    # compute local amax for each block
+    grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
+    _block_amax[grid](
+        flattened_input,
+        block_amaxes,
+        num_elements,
+        input_dtype=tl_input_dtype,
+        BLOCK_SIZE=BLOCK_SIZE,
+        EPS=EPS,
+    )
+
+    # calculate global amax across all blocks and use it to compute scale
+    _fp8_scale[(1, 1, 1)](
+        block_amaxes,
+        scale_out,
+        num_elements,
+        fp8_dtype_max,
+        BLOCK_SIZE=BLOCK_SIZE,
+        EPS=EPS,
+    )
+
+    # perform conversion
+    _to_fp8[grid](
+        flattened_input,
+        scale_out,
+        fp8_output,
+        num_elements,
+        fp8_dtype_min,
+        fp8_dtype_max,
+        input_dtype=tl_input_dtype,
+        output_dtype=tl_output_dtype,
+        BLOCK_SIZE=BLOCK_SIZE,
+        EPS=EPS,
+    )

    return Float8Tensor(
-        tensor_out,
+        fp8_output.reshape(orig_shape),
scale_out,
orig_dtype=hp_tensor.dtype,
linear_mm_config=linear_mm_config,
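For orientation, a hypothetical usage sketch of the updated wrapper follows. The argument order is inferred from the function body and the test below, and it assumes a CUDA device, a PyTorch build that exposes torch.float8_e4m3fn, and the float8nocompile prototype importable from torchao; none of this snippet is part of the commit.

import torch

from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    triton_hp_tensor_to_float8_dynamic,
)

x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
x_fp8 = triton_hp_tensor_to_float8_dynamic(x, torch.float8_e4m3fn, LinearMMConfig())

# the resulting Float8Tensor carries the fp8 payload and the tensorwise scale;
# dividing the upcast payload by the scale approximately recovers the input
x_approx = x_fp8._data.to(torch.float32) / x_fp8._scale
print((x_approx - x.to(torch.float32)).abs().max())

Note that BLOCK_SIZE is hard-coded to 8 in the wrapper (the removed call site carried a "# TODO: tune" for the same constant), so each _block_amax program currently reduces only eight elements.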
@@ -28,5 +28,17 @@ def test_fp8_triton_hp_tensor_to_float8_dynamic():
        LinearMMConfig(),
    )

+    def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
+        if tensor1.shape != tensor2.shape:
+            raise ValueError("Tensors must have the same shape for comparison.")
+        if tensor1.dtype != tensor2.dtype:
+            raise ValueError("Tensors must have the same dtype for comparison.")
+
+        # convert fp8 tensors to a higher precision (e.g., float32) for comparison
+        # since fp8 ops necessary for allclose are not supported
+        tensor1_fp32 = tensor1.to(torch.float32)
+        tensor2_fp32 = tensor2.to(torch.float32)
+        return torch.allclose(tensor1_fp32, tensor2_fp32, atol=atol, rtol=rtol)
+
    assert torch.allclose(x_fp8._scale, y_fp8._scale, atol=1e-3, rtol=1e-3)
-    assert torch.allclose(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)
+    assert allclose_fp8(x_fp8._data, y_fp8._data, atol=1e-3, rtol=1e-3)
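
A small standalone illustration of why the comparison helper upcasts first: torch.allclose cannot run directly on fp8 operands (per the comment in the helper), so both sides are converted to float32 before comparing. The tensors here are made up for demonstration and assume the installed PyTorch exposes torch.float8_e4m3fn.

import torch

a = torch.randn(8, dtype=torch.float32).to(torch.float8_e4m3fn)
b = a.clone()

# upcast both operands, since the fp8 ops needed by torch.allclose are not supported
assert torch.allclose(a.to(torch.float32), b.to(torch.float32), atol=1e-3, rtol=1e-3)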
