add fused transpose and non-transpose kernel and use it for grad output #1497

Merged (33 commits) on Jan 8, 2025
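The change in a nutshell: the backward pass needs grad_output in fp8 twice, once in row-major layout and once transposed. Previously these were produced by two separate cast kernels, each re-reading grad_output from global memory; the fused Triton kernel added here reads it once and writes both layouts. A rough eager-mode sketch of what the fused cast computes (illustrative only; it assumes the usual amax-based tensorwise scaling, and the function name and eps value are not from this PR):

import torch

def fused_cast_sketch(x: torch.Tensor, fp8_dtype=torch.float8_e5m2, eps: float = 1e-12):
    # x: 2D high-precision tensor (e.g. a bf16 grad_output)
    fp8_min = torch.finfo(fp8_dtype).min
    fp8_max = torch.finfo(fp8_dtype).max
    # tensorwise dynamic scale: fp8_max / amax, with amax clamped away from zero
    amax = x.abs().max().float()
    scale = fp8_max / torch.clamp(amax, min=eps)
    # scale, saturate to the fp8 range, and cast once
    fp8_vals = torch.clamp(x.float() * scale, min=fp8_min, max=fp8_max).to(fp8_dtype)
    # materialize both layouts from the same values (the Triton kernel does this tile by tile)
    return fp8_vals.contiguous(), fp8_vals.t().contiguous(), scale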
25 changes: 10 additions & 15 deletions torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -16,8 +16,7 @@
ToFP8ColumnMajor,
ToFP8ColumnMajorT,
ToFP8RowAndColumnMajor,
ToFP8RowMajor,
ToFP8RowMajorT,
ToFP8RowMajorTAndNonT,
)
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
KernelAlgorithm,
@@ -138,12 +137,14 @@ def backward(ctx, grad_output):
input_fp8_col_major, weight_hp = ctx.saved_tensors

# cast grad output to float8_e5m2 for backward
grad_output_fp8_row_major = ToFP8RowMajor.apply(
grad_output,
ctx.config.cast_config_grad_output.target_dtype,
ctx.linear_mm_config,
GemmInputRole.GRAD_OUTPUT,
ctx.kernel_algo,
grad_output_fp8_row_major, grad_output_t_row_major = (
ToFP8RowMajorTAndNonT.apply(
grad_output,
ctx.config.cast_config_grad_output.target_dtype,
ctx.linear_mm_config,
GemmInputRole.GRAD_OUTPUT,
ctx.kernel_algo,
)
)

# grad_input = grad_output @ weight
@@ -159,12 +160,6 @@
# grad_weight = grad_output_t @ input
# apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
# source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
grad_output_t_row_major = ToFP8RowMajorT.apply(
grad_output,
ctx.config.cast_config_grad_output.target_dtype,
ctx.linear_mm_config,
GemmInputRole.GRAD_OUTPUT,
ctx.kernel_algo,
)
grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)

return grad_input, grad_weight, None, None, None
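For context on why both layouts are needed (a shape sketch with illustrative sizes, not code from this PR): grad_input = grad_output @ weight consumes grad_output as-is, while grad_weight = grad_output_t @ input consumes its transpose.

import torch

# illustrative sizes only: M = tokens, K = in_features, N = out_features
M, K, N = 8, 16, 32
grad_output = torch.randn(M, N)      # dL/dY
weight = torch.randn(N, K)           # W
inp = torch.randn(M, K)              # X
grad_input = grad_output @ weight    # (M, N) @ (N, K) -> (M, K)
grad_weight = grad_output.t() @ inp  # (N, M) @ (M, K) -> (N, K)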
36 changes: 32 additions & 4 deletions torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py
Expand Up @@ -10,17 +10,15 @@

import torch

from torchao.float8.float8_tensor import (
GemmInputRole,
LinearMMConfig,
)
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
KernelAlgorithm,
hp_to_fp8_col_major,
hp_to_fp8_col_major_t,
hp_to_fp8_row_and_col_major,
hp_to_fp8_row_major,
hp_to_fp8_row_major_t,
hp_to_fp8_row_major_t_and_non_t,
)


@@ -172,3 +170,33 @@ def forward(
@staticmethod
def backward(ctx, g):
return g, None, None, None, None


class ToFP8RowMajorTAndNonT(torch.autograd.Function):
"""
A differentiable conversion to fp8.
* forward: convert from high precision to float8, producing both a row-major (non-transposed) output and its transpose in row-major layout
* backward: pass the gradient without changes
"""

@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
float8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
gemm_input_role: GemmInputRole,
kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
):
fp8_row_major, fp8_row_major_t = hp_to_fp8_row_major_t_and_non_t(
tensor,
float8_dtype,
linear_mm_config,
gemm_input_role,
algo=kernel_algo,
)
return fp8_row_major, fp8_row_major_t

@staticmethod
def backward(ctx, g):
return g, None, None, None, None
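For reference, a minimal call site mirroring the backward pass in float8nocompile_linear.py above (the variable names are illustrative):

fp8_grad_out, fp8_grad_out_t = ToFP8RowMajorTAndNonT.apply(
    grad_output,                 # high-precision gradient w.r.t. the linear output
    torch.float8_e5m2,           # grad outputs are cast to e5m2
    linear_mm_config,
    GemmInputRole.GRAD_OUTPUT,
    KernelAlgorithm.ATOMIC_MAX,  # amax reduction strategy
)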
158 changes: 158 additions & 0 deletions torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -305,6 +305,82 @@ def _to_fp8_row_and_col_major(
tl.store(col_major_out_ptr + col_major_offs, fp8_vals, mask=mask)


@triton.autotune(
configs=kernel_configs_2D,
key=["num_elements"],
)
@triton.jit
def _to_fp8_row_major_t_and_non_t(
input_ptr,
row_major_out_ptr,
row_major_t_out_ptr,
scale_ptr,
num_elements: int,
fp8_dtype_min: float,
fp8_dtype_max: float,
input_num_rows: int,
input_num_cols: int,
input_stride_row: int,
input_stride_col: int,
row_major_out_stride_row: int,
row_major_out_stride_col: int,
row_major_t_out_stride_row: int,
row_major_t_out_stride_col: int,
input_dtype: tl.constexpr,
output_dtype: tl.constexpr,
BLOCK_SIZE_ROWS: tl.constexpr,
BLOCK_SIZE_COLS: tl.constexpr,
EPS: tl.constexpr,
):
"""
Reads a row-major, high precision input tensor and writes 2 output tensors:
1) fp8 tensor in row-major layout
2) the transpose of the input, also cast to fp8 and in row-major layout
"""
block_row_id = tl.program_id(axis=0)
block_col_id = tl.program_id(axis=1)

# load scaling factor
scale = tl.load(scale_ptr).to(tl.float32)

# load block of input tensor
block_row_start = block_row_id * BLOCK_SIZE_ROWS
block_col_start = block_col_id * BLOCK_SIZE_COLS
block_row_offs = block_row_start + tl.arange(0, BLOCK_SIZE_ROWS)
block_col_offs = block_col_start + tl.arange(0, BLOCK_SIZE_COLS)
input_offs = (
block_row_offs[:, None] * input_stride_row
+ block_col_offs[None, :] * input_stride_col
)
mask = (block_row_offs[:, None] < input_num_rows) & (
block_col_offs[None, :] < input_num_cols
)
vals = tl.load(input_ptr + input_offs, mask=mask).to(input_dtype)

# perform conversion
vals = vals * scale
fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)

# write row-major output
row_major_offs = (
block_row_offs[:, None] * row_major_out_stride_row
+ block_col_offs[None, :] * row_major_out_stride_col
)
tl.store(row_major_out_ptr + row_major_offs, fp8_vals, mask=mask)

# write transposed row-major output
row_major_t_num_rows = input_num_cols
row_major_t_num_cols = input_num_rows
row_major_t_offs = (
block_col_offs[:, None] * row_major_t_out_stride_row
+ block_row_offs[None, :] * row_major_t_out_stride_col
)
# mask for the transposed tile: the row/col offsets are swapped relative to the load,
# so the bounds checks must be swapped as well
mask = (block_col_offs[:, None] < row_major_t_num_rows) & (
block_row_offs[None, :] < row_major_t_num_cols
)
tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask)
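
The tile mapping implemented by the transposed store can be sanity-checked in plain PyTorch: the block read from input[r0:r0+B, c0:c0+B] is written, transposed, to output_t[c0:c0+B, r0:r0+B]. A small sketch of that mapping (illustrative; assumes square blocks and dimensions divisible by B):

import torch

B = 4                                        # stand-in for BLOCK_SIZE_ROWS == BLOCK_SIZE_COLS
x = torch.arange(8 * 12).reshape(8, 12)
out_t = torch.empty(12, 8, dtype=x.dtype)
for r0 in range(0, x.shape[0], B):
    for c0 in range(0, x.shape[1], B):
        tile = x[r0:r0 + B, c0:c0 + B]        # block loaded by one program
        out_t[c0:c0 + B, r0:r0 + B] = tile.T  # transposed store location
assert torch.equal(out_t, x.T)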


@triton.autotune(configs=kernel_configs_1D, key=["num_elements"])
@triton.jit
def _amax_atomic(
@@ -701,6 +777,88 @@ def hp_to_fp8_row_and_col_major(
return fp8_tensor_row_major, fp8_tensor_col_major


def hp_to_fp8_row_major_t_and_non_t(
hp_tensor: torch.Tensor,
fp8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> tuple[Float8Tensor, Float8Tensor]:
assert hp_tensor.is_contiguous(), "input tensor must be contiguous"

tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]

fp8_dtype_min = torch.finfo(fp8_dtype).min
fp8_dtype_max = torch.finfo(fp8_dtype).max

# compute scaling factor for tensor
scale = _hp_tensor_to_scale(
hp_tensor,
tl_input_dtype,
fp8_dtype_max,
algo,
)

# perform fp8 conversion
input_num_rows, input_num_cols = hp_tensor.shape
transposed_num_rows, transposed_num_cols = input_num_cols, input_num_rows
num_elements = hp_tensor.numel()

# preallocate necessary output tensors
fp8_output_row_major = torch.empty(
(input_num_rows, input_num_cols), dtype=fp8_dtype, device=hp_tensor.device
)
fp8_output_row_major_t = torch.empty(
(transposed_num_rows, transposed_num_cols),
dtype=fp8_dtype,
device=hp_tensor.device,
)

# launch triton kernel to perform conversion
grid = lambda meta: (
triton.cdiv(input_num_rows, meta["BLOCK_SIZE_ROWS"]),
triton.cdiv(input_num_cols, meta["BLOCK_SIZE_COLS"]),
)
_to_fp8_row_major_t_and_non_t[grid](
hp_tensor,
fp8_output_row_major,
fp8_output_row_major_t,
scale,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_num_rows,
input_num_cols,
hp_tensor.stride(0),
hp_tensor.stride(1),
fp8_output_row_major.stride(0),
fp8_output_row_major.stride(1),
fp8_output_row_major_t.stride(0),
fp8_output_row_major_t.stride(1),
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
EPS=EPS,
)

# wrap outputs in Float8Tensors
fp8_tensor_row_major = Float8Tensor(
fp8_output_row_major,
scale,
orig_dtype=hp_tensor.dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)
fp8_tensor_row_major_t = Float8Tensor(
fp8_output_row_major_t,
scale,
orig_dtype=hp_tensor.dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)
return fp8_tensor_row_major, fp8_tensor_row_major_t
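
A quick check on the launch geometry (the block sizes here are assumed for illustration; in practice they come from the autotuner): a 512x512 input with 128x128 blocks launches a 4x4 grid, i.e. 16 programs, each writing one tile to both outputs.

import triton

# illustrative only: 512x512 input, 128x128 blocks -> (4, 4) grid
grid = (triton.cdiv(512, 128), triton.cdiv(512, 128))
assert grid == (4, 4)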


def _hp_tensor_to_scale(
hp_tensor: torch.Tensor,
tl_input_dtype: tl.core.dtype,
@@ -11,6 +11,7 @@
hp_to_fp8_row_and_col_major,
hp_to_fp8_row_major,
hp_to_fp8_row_major_t,
hp_to_fp8_row_major_t_and_non_t,
)


@@ -335,3 +336,77 @@ def test_fp8_hp_to_fp8_row_and_col_major(
torch.float8_e4m3fn,
LinearMMConfig(),
)


@pytest.mark.parametrize(
"algo",
[KernelAlgorithm.REDUCTION, KernelAlgorithm.ATOMIC_MAX],
)
@pytest.mark.parametrize(
"input_shape",
[(2, 4), (32, 16), (512, 512)],
)
def test_fp8_hp_to_fp8_row_major_t_and_non_t(
input_shape: tuple[int, int], algo: KernelAlgorithm
):
assert torch.cuda.is_available()
device = "cuda"
input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
x_bf16 = input_bf16.clone().detach().to(device)
y_bf16 = input_bf16.clone().detach().to(device)

# production implementation
x_fp8_row_major = hp_tensor_to_float8_dynamic(
x_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
)
x_fp8_row_major_t = x_fp8_row_major.t().contiguous()

# float8nocompile triton implementation
y_fp8_row_major, y_fp8_row_major_t = hp_to_fp8_row_major_t_and_non_t(
y_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
algo=algo,
)

# check scales
assert torch.eq(x_fp8_row_major._scale, y_fp8_row_major._scale)
assert torch.eq(x_fp8_row_major_t._scale, y_fp8_row_major_t._scale)

# check data
assert torch.all(torch.eq(x_fp8_row_major._data, y_fp8_row_major._data))
assert torch.all(torch.eq(x_fp8_row_major_t._data, y_fp8_row_major_t._data))

# check shapes
assert x_fp8_row_major.shape == y_fp8_row_major.shape
assert x_fp8_row_major_t.shape == y_fp8_row_major_t.shape

# check strides
assert x_fp8_row_major.stride() == y_fp8_row_major.stride()
assert x_fp8_row_major_t.stride() == y_fp8_row_major_t.stride()

# check memory layout
assert is_row_major(x_fp8_row_major.stride())
assert is_row_major(y_fp8_row_major.stride())
assert is_row_major(x_fp8_row_major_t.stride())
assert is_row_major(y_fp8_row_major_t.stride())

# check underlying memory layout
assert (
x_fp8_row_major._data.storage().tolist()
== y_fp8_row_major._data.storage().tolist()
)
assert (
x_fp8_row_major_t._data.storage().tolist()
== y_fp8_row_major_t._data.storage().tolist()
)

# assert that error is raised when input tensor is not contiguous
with pytest.raises(AssertionError, match="tensor must be contiguous"):
hp_to_fp8_row_major_t_and_non_t(
y_bf16.t(), # transpose so tensor memory layout is no longer contiguous
torch.float8_e4m3fn,
LinearMMConfig(),
)