fix bug in tl.store mask for kernel _to_fp8_row_major_t_and_non_t #1516

Merged · 59 commits · Jan 9, 2025
33 changes: 30 additions & 3 deletions torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -80,15 +80,20 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return output

@classmethod
-def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
+def from_float(
+    cls,
+    mod,
+    config: Float8LinearConfig,  # only default config is supported, non-defaults silently ignored
+    kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
-config (Optional[Float8LinearConfig]): configuration for conversion to float8
+config (Optional[Float8LinearConfig]): configuration for conversion to float8 (note: only
+    default config is supported, non-defaults silently ignored)
"""
-config = Float8LinearConfig()
with torch.device("meta"):
new_mod = cls(
mod.in_features,
@@ -107,6 +112,10 @@ def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
class matmul_with_args_in_hp(torch.autograd.Function):
@staticmethod
def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
# reshape to be 2D for triton kernels
orig_input_shape = input_hp.shape
input_hp = input_hp.reshape(-1, input_hp.shape[-1])

# output = input @ weight_t
input_fp8_row_major, input_fp8_col_major = ToFP8RowAndColumnMajor.apply(
input_hp,
@@ -130,12 +139,24 @@ def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
ctx.linear_mm_config = linear_mm_config
ctx.kernel_algo = kernel_algo

# reshape back to expected dims
output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
return output

@staticmethod
def backward(ctx, grad_output):
# grad_output may not be contiguous in cases like:
# output.sum().backward() where grad is all 1s, so the (M,N) view of the scalar "1"
# results in a non-contiguous tensor with stride (0,0).
if not grad_output.is_contiguous():
grad_output = grad_output.contiguous()

input_fp8_col_major, weight_hp = ctx.saved_tensors

# reshape to be 2D for triton kernels
orig_grad_output_shape = grad_output.shape
grad_output = grad_output.reshape(-1, grad_output.shape[-1])

# cast grad output to float8_e5m2 for backward
grad_output_fp8_row_major, grad_output_t_row_major = (
ToFP8RowMajorTAndNonT.apply(
@@ -162,4 +183,10 @@ def backward(ctx, grad_output):
# source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)

# reshape grad input to match original shape
grad_input = grad_input.reshape(
*orig_grad_output_shape[:-1], grad_input.shape[-1]
)

# return one gradient per forward() arg; non-tensor args (config, linear_mm_config, kernel_algo) get None
return grad_input, grad_weight, None, None, None
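The forward/backward changes above do two pieces of bookkeeping around the 2D-only triton kernels: tensors with leading batch dims are flattened to (M, K) before the fp8 casts and matmuls and reshaped back afterwards, and grad_output is forced contiguous because reductions like output.sum() hand back an expanded scalar whose (M, N) view has stride (0, 0). A minimal plain-PyTorch sketch of the same pattern (not the fp8 kernels; the helper names are illustrative):

```python
import torch

def kernel_2d_only(x2d: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # stand-in for a triton kernel that only accepts contiguous 2D (M, K) inputs
    assert x2d.dim() == 2 and x2d.is_contiguous()
    return x2d @ w.t()

def linear_any_rank(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    orig_shape = x.shape                      # e.g. (B, S, K)
    x2d = x.reshape(-1, orig_shape[-1])       # flatten leading dims -> (B*S, K)
    out2d = kernel_2d_only(x2d, w)            # (B*S, N)
    return out2d.reshape(*orig_shape[:-1], out2d.shape[-1])  # restore -> (B, S, N)

x = torch.randn(2, 32, 16)
w = torch.randn(64, 16)
assert linear_any_rank(x, w).shape == (2, 32, 64)

# why backward() calls .contiguous(): the incoming gradient of a sum() is an expanded
# scalar, so its (M, N) view has stride (0, 0) and is not contiguous
grad = torch.ones(()).expand(4, 8)
assert grad.stride() == (0, 0) and not grad.is_contiguous()
assert grad.contiguous().stride() == (8, 1)
```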
97 changes: 97 additions & 0 deletions torchao/prototype/float8nocompile/float8nocompile_linear_test.py
@@ -0,0 +1,97 @@
import pytest
import torch

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig
from torchao.prototype.float8nocompile.float8nocompile_linear import (
matmul_with_args_in_hp,
)
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
KernelAlgorithm,
)


# unit test comparing the two implementations
@pytest.mark.parametrize(
"input_shape",
[(32, 16), (1, 32, 16), (2, 32, 16)],
)
def test_matmul_with_args_in_hp(input_shape: tuple[int, ...]):
assert torch.cuda.is_available()
device = "cuda"

# high precision inputs
input_bf16 = torch.randn(
input_shape, dtype=torch.bfloat16, device=device, requires_grad=True
)
prod_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
prototype_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)

# high precision weights
# nn.Linear stores weights in transposed form
weight_bf16 = torch.randn(
(32, input_bf16.shape[-1]),
dtype=torch.bfloat16,
device=device,
requires_grad=True,
)
prod_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
prototype_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)

# default configs
config = Float8LinearConfig()
emulate = False
linear_mm_config = LinearMMConfig(
# output
ScaledMMConfig(
emulate,
config.gemm_config_output.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_input
ScaledMMConfig(
emulate,
config.gemm_config_grad_input.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
emulate,
config.gemm_config_grad_weight.use_fast_accum,
False,
config.pad_inner_dim,
),
)

# prod forward. expects transposed weight.
out_prod = manual_float8_matmul_with_args_in_hp.apply(
prod_input_bf16, prod_weight_bf16.t(), linear_mm_config, config
)

# prototype forward. expects non-transposed weight
out_prototype = matmul_with_args_in_hp.apply(
prototype_input_bf16,
prototype_weight_bf16,
config,
linear_mm_config,
KernelAlgorithm.ATOMIC_MAX,
)

# compare model outputs
assert torch.allclose(out_prod, out_prototype, atol=0, rtol=0)

out_prod.sum().backward()
out_prototype.sum().backward()

# compare input gradients
assert torch.allclose(
prod_input_bf16.grad, prototype_input_bf16.grad, atol=0, rtol=0
)

# compare weight gradients
assert torch.allclose(
prod_weight_bf16.grad, prototype_weight_bf16.grad, atol=0, rtol=0
)
@@ -8,6 +8,7 @@

import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import swap_linear_layers
from torchao.prototype.float8nocompile.float8nocompile_linear import (
Float8LinearNoCompile,
@@ -23,6 +24,7 @@
def convert_to_float8_nocompile_training(
module: nn.Module,
*,
config: Float8LinearConfig = None,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> nn.Module:
@@ -39,7 +41,12 @@ def convert_to_float8_nocompile_training(
Returns:
nn.Module: The modified module with swapped linear layers.
"""
-from_float = lambda m: Float8LinearNoCompile.from_float(m, kernel_algo=kernel_algo)
+if config is None:
+    config = Float8LinearConfig()
+
+from_float = lambda m: Float8LinearNoCompile.from_float(
+    m, config=config, kernel_algo=kernel_algo
+)
return swap_linear_layers(
module,
from_float,
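With the new argument, conversion can be called with an explicit Float8LinearConfig or with config omitted (it then falls back to the default). A usage sketch under assumptions: the import path of convert_to_float8_nocompile_training and the toy model are illustrative, and a CUDA device is assumed.

```python
import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
# assumed import path for the prototype conversion helper shown in the diff above
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 16)).to("cuda", torch.bfloat16)

# config is optional; per the docstring above, only the default config is honored for now
model = convert_to_float8_nocompile_training(model, config=Float8LinearConfig())

out = model(torch.randn(8, 32, device="cuda", dtype=torch.bfloat16))
out.sum().backward()
```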
@@ -375,8 +375,8 @@ def _to_fp8_row_major_t_and_non_t(
block_col_offs[:, None] * row_major_t_out_stride_row
+ block_row_offs[None, :] * row_major_t_out_stride_col
)
-mask = (block_row_offs[:, None] < row_major_t_num_rows) & (
-    block_col_offs[None, :] < row_major_t_num_cols
+mask = (block_col_offs[:, None] < row_major_t_num_rows) & (
+    block_row_offs[None, :] < row_major_t_num_cols
)
tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask)

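The fix above targets the transposed store in _to_fp8_row_major_t_and_non_t: the store offsets index the transposed output (block_col_offs select output rows, block_row_offs select output columns), so the bounds mask must be built in the same orientation. The old mask checked the offsets against the wrong dimensions, so edge tiles of shapes that are not block-size multiples could drop valid elements or write out of bounds. A minimal Triton sketch of the pattern (not the torchao kernel; names and block sizes are illustrative):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def transpose_store_kernel(
    in_ptr, out_ptr,
    num_rows, num_cols,  # input is (num_rows, num_cols); output is (num_cols, num_rows)
    BLOCK_ROWS: tl.constexpr, BLOCK_COLS: tl.constexpr,
):
    row_offs = tl.program_id(0) * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
    col_offs = tl.program_id(1) * BLOCK_COLS + tl.arange(0, BLOCK_COLS)

    # load a (BLOCK_ROWS, BLOCK_COLS) tile of the row-major input
    in_offs = row_offs[:, None] * num_cols + col_offs[None, :]
    load_mask = (row_offs[:, None] < num_rows) & (col_offs[None, :] < num_cols)
    tile = tl.load(in_ptr + in_offs, mask=load_mask, other=0.0)

    # store the transposed tile: col_offs index the output rows and row_offs index the
    # output columns, so the store mask must use the same (transposed) orientation
    out_offs = col_offs[:, None] * num_rows + row_offs[None, :]
    store_mask = (col_offs[:, None] < num_cols) & (row_offs[None, :] < num_rows)
    tl.store(out_ptr + out_offs, tile.trans(1, 0), mask=store_mask)


# the bug only shows up on edge tiles, i.e. when dims are not multiples of the block size
x = torch.randn(300, 130, device="cuda")
y = torch.empty(130, 300, device="cuda")
grid = (triton.cdiv(300, 64), triton.cdiv(130, 64))
transpose_store_kernel[grid](x, y, 300, 130, BLOCK_ROWS=64, BLOCK_COLS=64)
assert torch.equal(y, x.t())
```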
7 changes: 5 additions & 2 deletions torchao/prototype/float8nocompile/test/train_test.py
@@ -36,7 +36,10 @@ def model2():
return TestModel()


-def test_model_weights_and_gradients(model1, model2):
+@pytest.mark.parametrize(
+    "input_shape", [(16, 32), (1, 16, 32), (2, 16, 32), (128, 8192, 32)]
+)
+def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, ...]):
assert torch.cuda.is_available()
device = torch.device("cuda")

@@ -48,7 +51,7 @@ def test_model_weights_and_gradients(model1, model2):
convert_to_float8_nocompile_training(model1)

input_tensor = torch.randn(
-16, 32, requires_grad=True, dtype=torch.bfloat16, device=device
+*input_shape, requires_grad=True, dtype=torch.bfloat16, device=device
)
input_copy1 = input_tensor.clone().detach().requires_grad_(True)
input_copy2 = input_tensor.clone().detach().requires_grad_(True)