add handling for batch dim in float8nocompile
ghstack-source-id: 47996349f565c237f81124afe8b500d221d1448b
ghstack-comment-id: 2574237839
Pull Request resolved: #1512
danielvegamyhre committed Jan 8, 2025
1 parent 070345d commit 367aea3
Showing 4 changed files with 109 additions and 5 deletions.
25 changes: 23 additions & 2 deletions torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -80,15 +80,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return output

@classmethod
def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
def from_float(cls, mod, config: Float8LinearConfig, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
Args:
mod (torch.nn.Linear): nn.Linear to convert
config (Optional[Float8LinearConfig]): configuration for conversion to float8
"""
config = Float8LinearConfig()
with torch.device("meta"):
new_mod = cls(
mod.in_features,
@@ -107,6 +106,10 @@ def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_M
class matmul_with_args_in_hp(torch.autograd.Function):
@staticmethod
def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
# reshape to be 2D for triton kernels
orig_input_shape = input_hp.shape
input_hp = input_hp.reshape(-1, input_hp.shape[-1])

# output = input @ weight_t
input_fp8_row_major, input_fp8_col_major = ToFP8RowAndColumnMajor.apply(
input_hp,
@@ -130,12 +133,24 @@ def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
ctx.linear_mm_config = linear_mm_config
ctx.kernel_algo = kernel_algo

# reshape back to expected dims
output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
return output

@staticmethod
def backward(ctx, grad_output):
# grad_output may not be contiguous in cases like:
# output.sum().backward() where grad is all 1s, so the (M,N) view of the scalar "1"
# results in a non-contiguous tensor with stride (0,0).
if not grad_output.is_contiguous():
grad_output = grad_output.contiguous()

input_fp8_col_major, weight_hp = ctx.saved_tensors

# reshape to be 2D for triton kernels
orig_grad_output_shape = grad_output.shape
grad_output = grad_output.reshape(-1, grad_output.shape[-1])

# cast grad output to float8_e5m2 for backward
grad_output_fp8_row_major, grad_output_t_row_major = (
ToFP8RowMajorTAndNonT.apply(
@@ -162,4 +177,10 @@ def backward(ctx, grad_output):
# source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)

# reshape grad input to match original shape
grad_input = grad_input.reshape(
*orig_grad_output_shape[:-1], grad_input.shape[-1]
)

# grad input shape
return grad_input, grad_weight, None, None, None
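
The change above handles batched (higher than 2D) inputs by flattening all leading dims into a single row dimension before the Triton kernels run, restoring the original shape afterwards, and by making grad_output contiguous before the backward kernels. A minimal sketch of both ideas in plain PyTorch (no Triton, no float8 casting; sizes are illustrative, not taken from the commit):

import torch

B, M, K, N = 2, 32, 16, 64                                 # illustrative sizes
x = torch.randn(B, M, K)                                   # batched, higher-than-2D input
w = torch.randn(N, K)                                      # nn.Linear-style weight: (out_features, in_features)

# forward: flatten all leading dims into one row dim, matmul, then restore them
orig_shape = x.shape
x_2d = x.reshape(-1, orig_shape[-1])                       # (B*M, K)
out_2d = x_2d @ w.t()                                      # (B*M, N)
out = out_2d.reshape(*orig_shape[:-1], out_2d.shape[-1])   # (B, M, N)

# backward: the grad that out.sum().backward() feeds in is an expanded view of the
# scalar 1.0, so its strides are all zero and it is not contiguous
grad_out = torch.ones(()).expand(B, M, N)                  # stand-in for that grad_output
print(grad_out.stride(), grad_out.is_contiguous())         # (0, 0, 0) False
if not grad_out.is_contiguous():
    grad_out = grad_out.contiguous()                       # the Triton kernels expect a dense layout
grad_out_2d = grad_out.reshape(-1, grad_out.shape[-1])     # (B*M, N), reshaped back at the end
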
73 changes: 73 additions & 0 deletions torchao/prototype/float8nocompile/float8nocompile_linear_test.py
@@ -0,0 +1,73 @@
import pytest

import torch
from torch.autograd.function import FunctionCtx
from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp
from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig
from torch.autograd import gradcheck

from torchao.prototype.float8nocompile.float8nocompile_linear import matmul_with_args_in_hp
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import KernelAlgorithm

# unit test comparing the two implementations
@pytest.mark.parametrize(
"input_shape",
[(32, 16), (1,32,16), (2,32,16)],
)
def test_matmul_with_args_in_hp(input_shape: tuple[int, ...]):
assert torch.cuda.is_available()
device = "cuda"

# high precision inputs
input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device, requires_grad=True)
x_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
y_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)

# high precision weights
# nn.Linear stores weights in transposed form
weight_bf16 = torch.randn((32, input_bf16.shape[-1]), dtype=torch.bfloat16, device=device, requires_grad=True)
x_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
y_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)

# default configs
config = Float8LinearConfig()
emulate = False
linear_mm_config = LinearMMConfig(
# output
ScaledMMConfig(
emulate,
config.gemm_config_output.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_input
ScaledMMConfig(
emulate,
config.gemm_config_grad_input.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
emulate,
config.gemm_config_grad_weight.use_fast_accum,
False,
config.pad_inner_dim,
),
)

# prod forward. expects transposed weight.
out_prod = manual_float8_matmul_with_args_in_hp.apply(x_input_bf16, x_weight_bf16.t(), linear_mm_config, config)

# prototype forward. expects non-transposed weight
out_prototype = matmul_with_args_in_hp.apply(y_input_bf16, y_weight_bf16, config, linear_mm_config, KernelAlgorithm.ATOMIC_MAX)

# compare
assert torch.allclose(out_prod, out_prototype, atol=1e-3, rtol=1e-3)

out_prod.sum().backward()
out_prototype.sum().backward()

assert torch.allclose(x_input_bf16.grad, y_input_bf16.grad, atol=1e-3, rtol=1e-3)
assert torch.allclose(x_weight_bf16.grad, y_weight_bf16.grad, atol=1e-3, rtol=1e-3)
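
The comments above note that nn.Linear stores its weight as (out_features, in_features), which is why the production autograd function is handed weight.t() while the prototype takes the stored weight directly. A small, self-contained illustration of that convention (plain PyTorch, no float8):

import torch
import torch.nn.functional as F

batch, in_features, out_features = 32, 16, 64   # illustrative sizes
x = torch.randn(batch, in_features)
w = torch.randn(out_features, in_features)      # how nn.Linear stores its weight

out_linear = F.linear(x, w)                     # F.linear computes x @ w.t() internally
out_matmul = x @ w.t()                          # the explicit matmul against the transposed weight
assert torch.allclose(out_linear, out_matmul)
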
9 changes: 8 additions & 1 deletion (float8nocompile conversion utilities; the file path header is not shown in this view)
@@ -8,6 +8,7 @@

import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import swap_linear_layers
from torchao.prototype.float8nocompile.float8nocompile_linear import (
Float8LinearNoCompile,
@@ -23,6 +24,7 @@
def convert_to_float8_nocompile_training(
module: nn.Module,
*,
config: Float8LinearConfig = None,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> nn.Module:
@@ -39,7 +41,12 @@ def convert_to_float8_nocompile_training(
Returns:
nn.Module: The modified module with swapped linear layers.
"""
from_float = lambda m: Float8LinearNoCompile.from_float(m, kernel_algo=kernel_algo)
if config is None:
config = Float8LinearConfig()

from_float = lambda m: Float8LinearNoCompile.from_float(
m, config=config, kernel_algo=kernel_algo
)
return swap_linear_layers(
module,
from_float,
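
With config now threaded through from_float, callers can pass an explicit Float8LinearConfig or omit it and get the default constructed above. A usage sketch (the import path of the conversion helper is an assumption, since the header naming this changed file is missing from this view):

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
# NOTE: assumed import path; this view does not show the changed file's path header
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import KernelAlgorithm

model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 16)).to(torch.bfloat16).to("cuda")

# config is optional: omitting it falls back to a default Float8LinearConfig()
convert_to_float8_nocompile_training(
    model,
    config=Float8LinearConfig(),
    kernel_algo=KernelAlgorithm.ATOMIC_MAX,
)
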
7 changes: 5 additions & 2 deletions torchao/prototype/float8nocompile/test/train_test.py
@@ -36,7 +36,10 @@ def model2():
return TestModel()


def test_model_weights_and_gradients(model1, model2):
@pytest.mark.parametrize(
"input_shape", [(16,32), (1,16,32), (2,16,32)]
)
def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, ...]):
assert torch.cuda.is_available()
device = torch.device("cuda")

@@ -48,7 +51,7 @@ def test_model_weights_and_gradients(model1, model2):
convert_to_float8_nocompile_training(model1)

input_tensor = torch.randn(
16, 32, requires_grad=True, dtype=torch.bfloat16, device=device
*input_shape, requires_grad=True, dtype=torch.bfloat16, device=device
)
input_copy1 = input_tensor.clone().detach().requires_grad_(True)
input_copy2 = input_tensor.clone().detach().requires_grad_(True)
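
The rest of this test body is not shown here; a minimal sketch of the comparison it presumably performs, reusing the names defined above (model1 is the converted model, model2 the bf16 reference; the tolerances are illustrative, not taken from the commit):

# hypothetical continuation of the test body above, not the commit's actual code
out1 = model1(input_copy1)
out2 = model2(input_copy2)
out1.sum().backward()
out2.sum().backward()

assert torch.allclose(out1, out2, atol=1e-3, rtol=1e-3)
assert torch.allclose(input_copy1.grad, input_copy2.grad, atol=1e-3, rtol=1e-3)
for p1, p2 in zip(model1.parameters(), model2.parameters()):
    assert torch.allclose(p1.grad, p2.grad, atol=1e-3, rtol=1e-3)
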
