Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[float8nocompile] Add alternate Triton kernels for FP8 conversion which use atomic_max-based algo instead of reduction-based algo #1455

Merged
merged 3 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions torchao/prototype/float8nocompile/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def get_configs() -> List[ExperimentConfig]:
layer_sizes = [[4096, 4096]]
input_shapes = [(2**4, 4096), (2**8, 4096), (2**12, 4096), (2**16, 4096)]
high_precision_dtypes = [torch.float32, torch.bfloat16]
high_precision_dtypes = [torch.bfloat16]
configs = []
for layer_size, input_shape, high_precision_dtype in itertools.product(
layer_sizes, input_shapes, high_precision_dtypes
Expand Down Expand Up @@ -133,18 +133,20 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:

def print_results(experiments: List[Experiment]):
headers = [
"input_size",
"input_shape",
"high_precision_dtype",
"eager_time",
"compiled_time",
"float8nocompile",
]
rows = []
for experiment in experiments:
input_size = experiment.config.input_shape[0] * experiment.config.input_shape[1]
input_shape = (
f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
)
rows.append(
[
f"{input_size:.2e}",
input_shape,
experiment.config.high_precision_dtype,
experiment.result.eager_time,
experiment.result.compiled_time,
Expand Down
212 changes: 170 additions & 42 deletions torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
"""
Triton kernels for scaling high precision tensors to float8.
"""
from enum import Enum

import torch

import triton
import triton.language as tl

Expand All @@ -31,8 +31,99 @@
}


class KernelAlgorithm(Enum):
"""Enum for FP8 conversion strategy."""

# use atomic max to compute global amax between blocks
ATOMIC_MAX = "atomic_max"

# reduce shared buffer containing local block amaxes to find global amax
REDUCTION = "reduction"


kernel_configs = [
triton.Config({"BLOCK_SIZE": 128}, num_warps=1),
triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
]


# --- atomic max version of kernel ---
@triton.autotune(configs=kernel_configs, key=["input_size"])
@triton.jit
def _block_amax_atomic(
input_ptr,
amax_ptr,
num_elements,
input_dtype: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
EPS: tl.constexpr,
):
# compute local amax for each block
block_id = tl.program_id(axis=0)
block_start = block_id * BLOCK_SIZE
block_offs = block_start + tl.arange(0, BLOCK_SIZE)
block_mask = block_offs < num_elements
vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
block_amax = tl.max(tl.abs(vals))
tl.atomic_max(amax_ptr, block_amax)


@triton.jit
def _fp8_scale_atomic(
amax_ptr,
scale_out_ptr,
fp8_dtype_max,
EPS: tl.constexpr,
):
# load previously computed global amax
global_amax = tl.load(amax_ptr)

# compute scale, must be fp32
scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
tl.float32
)

# store scale for use in Float8Tensor constructor
scale_off = tl.arange(0, 1)
tl.store(scale_out_ptr + scale_off, scale)


@triton.autotune(configs=kernel_configs, key=["input_size"])
@triton.jit
def _block_amax(
def _to_fp8_atomic(
input_ptr,
scale_ptr,
amax_ptr,
out_ptr,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_dtype: tl.constexpr,
output_dtype: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
EPS: tl.constexpr,
):
block_id = tl.program_id(axis=0)

# load scale
scale = tl.load(scale_ptr)

# load block of input tensor
block_start = block_id * BLOCK_SIZE
block_offs = block_start + tl.arange(0, BLOCK_SIZE)
mask = block_offs < num_elements
vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype)

# perform conversion
vals = vals * scale
fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
tl.store(out_ptr + block_offs, fp8_vals, mask=mask)


# --- reduction version of kernel ---
@triton.jit
def _block_amax_reduction(
input_ptr,
block_amaxes_ptr,
num_elements,
Expand All @@ -46,12 +137,12 @@ def _block_amax(
block_offs = block_start + tl.arange(0, BLOCK_SIZE)
block_mask = block_offs < num_elements
vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
block_amax = tl.max(tl.abs(vals), axis=0)
block_amax = tl.max(tl.abs(vals))
tl.store(block_amaxes_ptr + block_id, block_amax)


@triton.jit
def _fp8_scale(
def _fp8_scale_reduction(
block_amaxes_ptr,
scale_out_ptr,
num_elements,
Expand All @@ -75,7 +166,7 @@ def _fp8_scale(


@triton.jit
def _to_fp8(
def _to_fp8_reduction(
input_ptr,
scale_ptr,
out_ptr,
Expand Down Expand Up @@ -108,12 +199,10 @@ def triton_hp_tensor_to_float8_dynamic(
fp8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> Float8Tensor:

assert hp_tensor.is_contiguous(), "tensor must be contiguous"

BLOCK_SIZE = 8 # TODO(danielvegamyhre): tune this for perf

num_elements = hp_tensor.numel()
orig_shape = hp_tensor.shape
flattened_input = hp_tensor.flatten()
Expand All @@ -126,47 +215,86 @@ def triton_hp_tensor_to_float8_dynamic(

# allocate memory for computed scale, local block maxes, and output fp8 tensor
scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
block_amaxes = torch.zeros(
(num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
)

fp8_output = torch.empty_like(
flattened_input, dtype=fp8_dtype, device=hp_tensor.device
)

# compute local amax for each block
grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
_block_amax[grid](
flattened_input,
block_amaxes,
num_elements,
input_dtype=tl_input_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)

# calculate global amax across all blocks and use it to compute scale
_fp8_scale[(1, 1, 1)](
block_amaxes,
scale_out,
num_elements,
fp8_dtype_max,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)
if algo == KernelAlgorithm.ATOMIC_MAX:
global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device)
# compute global amax to be used for scaling
_block_amax_atomic[grid](
flattened_input,
global_amax,
num_elements,
input_dtype=tl_input_dtype,
EPS=EPS,
)

# perform conversion
_to_fp8[grid](
flattened_input,
scale_out,
fp8_output,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)
# compute scale for fp8 conversion
_fp8_scale_atomic[1, 1, 1](
global_amax,
scale_out,
fp8_dtype_max,
EPS=EPS,
)

# perform conversion and store scale for use in Float8Tensor
_to_fp8_atomic[grid](
flattened_input,
scale_out,
global_amax,
fp8_output,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
EPS=EPS,
)
elif algo == KernelAlgorithm.REDUCTION:
max_block_size = 512
BLOCK_SIZE = min(max_block_size, num_elements)
block_amaxes = torch.zeros(
(num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
)
# compute local amax for each block
_block_amax_reduction[grid](
flattened_input,
block_amaxes,
num_elements,
input_dtype=tl_input_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)

# calculate global amax across all blocks and use it to compute scale
_fp8_scale_reduction[(1, 1, 1)](
block_amaxes,
scale_out,
num_elements,
fp8_dtype_max,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)

# perform conversion
_to_fp8_reduction[grid](
flattened_input,
scale_out,
fp8_output,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)
else:
raise ValueError(f"Unsupported kernel algorithm: {algo}")

return Float8Tensor(
fp8_output.reshape(orig_shape),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,24 @@
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
KernelAlgorithm,
triton_hp_tensor_to_float8_dynamic,
)


def test_fp8_triton_hp_tensor_to_float8_dynamic():
@pytest.mark.parametrize(
"algo", [KernelAlgorithm.ATOMIC_MAX, KernelAlgorithm.REDUCTION]
)
@pytest.mark.parametrize(
"input_shape",
[(32, 32), (512, 512), (4096, 4096)],
)
def test_fp8_triton_hp_tensor_to_float8_dynamic(
algo: KernelAlgorithm, input_shape: tuple[int, int]
):
assert torch.cuda.is_available()
device = "cuda"
input_bf16 = torch.randn((4, 4), dtype=torch.bfloat16, device=device)
input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
x_bf16 = input_bf16.clone().detach().to(device)
y_bf16 = input_bf16.clone().detach().to(device)

Expand All @@ -26,6 +36,7 @@ def test_fp8_triton_hp_tensor_to_float8_dynamic():
y_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
algo=algo,
)

def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
Expand Down
Loading