
Commit 70e147d

fix linter errors
ghstack-source-id: 619ae59ae76acfda5492b56241f74878a4d1b04f
ghstack-comment-id: 2576566853
Pull Request resolved: #1525
danielvegamyhre committed Jan 8, 2025
1 parent 3cf3f1c commit 70e147d
Showing 4 changed files with 95 additions and 70 deletions.
59 changes: 33 additions & 26 deletions torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -17,6 +17,7 @@
    ToFP8ColumnMajor,
    ToFP8ColumnMajorT,
    ToFP8RowAndColumnMajor,
+    ToFP8RowMajor,
    ToFP8RowMajorTAndNonT,
)
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
@@ -39,7 +40,9 @@ def __init__(self, *args, **kwargs):
        """
        self.config = kwargs.pop("config")
        self.kernel_algo = kwargs.pop("kernel_algo")
-        self.use_activation_checkpointing = kwargs.pop("use_activation_checkpointing", False)
+        self.use_activation_checkpointing = kwargs.pop(
+            "use_activation_checkpointing", False
+        )
        super().__init__(*args, **kwargs)

        self.linear_mm_config = LinearMMConfig(
@@ -69,7 +72,7 @@ def __init__(self, *args, **kwargs):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.use_activation_checkpointing:
            output = checkpoint(
-                matmul_with_args_in_hp.apply,
+                matmul_with_args_in_hp.apply,
                input,
                self.weight,
                self.config,
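For context on the call above: torch.utils.checkpoint.checkpoint takes a callable plus its arguments, runs it without keeping intermediate activations, and recomputes them during backward. A minimal generic sketch, not this module's exact code (the block and shapes are illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

def expensive_block(x: torch.Tensor) -> torch.Tensor:
    # intermediates created here are recomputed during backward
    # instead of being stored for the whole forward pass
    return torch.relu(x @ x.t())

x = torch.randn(8, 8, requires_grad=True)
y = checkpoint(expensive_block, x, use_reentrant=False)  # non-reentrant variant
y.sum().backward()
```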
@@ -90,9 +93,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

    @classmethod
    def from_float(
-        cls,
-        mod,
-        config: Float8LinearConfig,
+        cls,
+        mod,
+        config: Float8LinearConfig,
        kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
        use_activation_checkpointing: bool = False,
    ):
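A hedged usage sketch of this entry point, matching the signature above; the layer dimensions are illustrative, and the names are the ones imported in this file:

```python
import torch.nn as nn

lin = nn.Linear(64, 32, bias=False)
fp8_lin = Float8LinearNoCompile.from_float(
    lin,
    config=Float8LinearConfig(),
    kernel_algo=KernelAlgorithm.ATOMIC_MAX,
    use_activation_checkpointing=True,
)
```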
@@ -122,18 +125,22 @@ def from_float(
class matmul_with_args_in_hp(torch.autograd.Function):
    @staticmethod
    def forward(
-        ctx,
-        input_hp: torch.Tensor,
-        weight_hp: torch.Tensor,
-        config: Float8LinearConfig,
-        linear_mm_config: LinearMMConfig,
-        kernel_algo: KernelAlgorithm,
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        config: Float8LinearConfig,
+        linear_mm_config: LinearMMConfig,
+        kernel_algo: KernelAlgorithm,
        use_activation_checkpointing: bool,
    ):
        if use_activation_checkpointing:
-            return matmul_with_args_in_hp._forward_with_ac(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo)
+            return matmul_with_args_in_hp._forward_with_ac(
+                ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo
+            )
        else:
-            return matmul_with_args_in_hp._forward_no_ac(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo)
+            return matmul_with_args_in_hp._forward_no_ac(
+                ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo
+            )

    @staticmethod
    def backward(ctx, grad_output):
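For readers unfamiliar with the pattern matmul_with_args_in_hp follows: a torch.autograd.Function pairs a staticmethod forward with a staticmethod backward and is always invoked through .apply. A generic sketch, unrelated to the float8 logic in this file:

```python
import torch

class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # tensors needed later would be stashed with ctx.save_for_backward(...)
        return 2.0 * x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # one gradient is returned per input to forward
        return 2.0 * grad_output

out = ScaleByTwo.apply(torch.randn(4, requires_grad=True))
out.sum().backward()
```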
@@ -144,12 +151,12 @@ def backward(ctx, grad_output):

    @staticmethod
    def _forward_no_ac(
-        ctx,
-        input_hp: torch.Tensor,
-        weight_hp: torch.Tensor,
-        config: Float8LinearConfig,
-        linear_mm_config: LinearMMConfig,
-        kernel_algo: KernelAlgorithm,
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        config: Float8LinearConfig,
+        linear_mm_config: LinearMMConfig,
+        kernel_algo: KernelAlgorithm,
    ):
        # reshape to be 2D for triton kernels
        orig_input_shape = input_hp.shape
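The reshape-to-2D idiom in the comment above collapses all leading dimensions so a kernel that only handles 2D inputs can run, then restores them afterwards. A standalone sketch of the same idiom, with a plain matmul standing in for the triton kernel path:

```python
import torch

def run_2d_kernel_on_nd_input(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    orig_shape = x.shape
    x2d = x.reshape(-1, orig_shape[-1])  # (..., K) -> (N, K)
    out2d = x2d @ w.t()                  # 2D-only compute path
    return out2d.reshape(*orig_shape[:-1], out2d.shape[-1])  # restore leading dims

x = torch.randn(2, 16, 32)
w = torch.randn(8, 32)
assert run_2d_kernel_on_nd_input(x, w).shape == (2, 16, 8)
```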
@@ -233,12 +240,12 @@ def _backward_no_ac(ctx, grad_output):

    @staticmethod
    def _forward_with_ac(
-        ctx,
-        input_hp: torch.Tensor,
-        weight_hp: torch.Tensor,
-        config: Float8LinearConfig,
-        linear_mm_config: LinearMMConfig,
-        kernel_algo: KernelAlgorithm,
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        config: Float8LinearConfig,
+        linear_mm_config: LinearMMConfig,
+        kernel_algo: KernelAlgorithm,
    ):
        # reshape to be 2D for triton kernels
        orig_input_shape = input_hp.shape
@@ -271,7 +278,7 @@ def _forward_with_ac(

        # reshape back to expected dims
        output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
-        return output
+        return output

    @staticmethod
    def _backward_with_ac(ctx, grad_output):
83 changes: 50 additions & 33 deletions torchao/prototype/float8nocompile/float8nocompile_linear_test.py
@@ -1,67 +1,84 @@
import pytest

import torch
from torch.autograd.function import FunctionCtx
-from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp

from torchao.float8.config import Float8LinearConfig
+from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig
-from torch.autograd import gradcheck
+from torchao.prototype.float8nocompile.float8nocompile_linear import (
+    matmul_with_args_in_hp,
+)
+from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
+    KernelAlgorithm,
+)

-from torchao.prototype.float8nocompile.float8nocompile_linear import matmul_with_args_in_hp
-from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import KernelAlgorithm


# unit test comparing the two implementations
@pytest.mark.parametrize(
    "input_shape",
-    [(32, 16), (1,32,16), (2,32,16)],
+    [(32, 16), (1, 32, 16), (2, 32, 16)],
)
def test_matmul_with_args_in_hp(input_shape: tuple[int, int]):
    assert torch.cuda.is_available()
    device = "cuda"

    # high precision inputs
-    input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device, requires_grad=True)
+    input_bf16 = torch.randn(
+        input_shape, dtype=torch.bfloat16, device=device, requires_grad=True
+    )
    x_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
    y_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)

    # high precision weights
    # nn.Linear stores weights in transposed form
-    weight_bf16 = torch.randn((32, input_bf16.shape[-1]), dtype=torch.bfloat16, device=device, requires_grad=True)
+    weight_bf16 = torch.randn(
+        (32, input_bf16.shape[-1]),
+        dtype=torch.bfloat16,
+        device=device,
+        requires_grad=True,
+    )
    x_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
    y_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
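The clone().detach().requires_grad_(True) chains above create independent leaf tensors with identical values, so each implementation under test accumulates its own gradients. Condensed:

```python
import torch

base = torch.randn(4, 4, requires_grad=True)
a = base.clone().detach().requires_grad_(True)  # new leaf, same values as base
b = base.clone().detach().requires_grad_(True)  # second, fully independent leaf
(a * 2).sum().backward()
assert b.grad is None  # gradients do not leak between the two leaves
```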

    # default configs
    config = Float8LinearConfig()
    emulate = False
    linear_mm_config = linear_mm_config = LinearMMConfig(
-        # output
-        ScaledMMConfig(
-            emulate,
-            config.gemm_config_output.use_fast_accum,
-            False,
-            config.pad_inner_dim,
-        ),
-        # grad_input
-        ScaledMMConfig(
-            emulate,
-            config.gemm_config_grad_input.use_fast_accum,
-            False,
-            config.pad_inner_dim,
-        ),
-        # grad_weight
-        ScaledMMConfig(
-            emulate,
-            config.gemm_config_grad_weight.use_fast_accum,
-            False,
-            config.pad_inner_dim,
-        ),
-    )
+        # output
+        ScaledMMConfig(
+            emulate,
+            config.gemm_config_output.use_fast_accum,
+            False,
+            config.pad_inner_dim,
+        ),
+        # grad_input
+        ScaledMMConfig(
+            emulate,
+            config.gemm_config_grad_input.use_fast_accum,
+            False,
+            config.pad_inner_dim,
+        ),
+        # grad_weight
+        ScaledMMConfig(
+            emulate,
+            config.gemm_config_grad_weight.use_fast_accum,
+            False,
+            config.pad_inner_dim,
+        ),
+    )

    # prod forward. expects transposed weight.
-    out_prod = manual_float8_matmul_with_args_in_hp.apply(x_input_bf16, x_weight_bf16.t(), linear_mm_config, config)
+    out_prod = manual_float8_matmul_with_args_in_hp.apply(
+        x_input_bf16, x_weight_bf16.t(), linear_mm_config, config
+    )

    # prototype forward. expects non-transposed weight
-    out_prototype = matmul_with_args_in_hp.apply(y_input_bf16, y_weight_bf16, config, linear_mm_config, KernelAlgorithm.ATOMIC_MAX)
+    out_prototype = matmul_with_args_in_hp.apply(
+        y_input_bf16,
+        y_weight_bf16,
+        config,
+        linear_mm_config,
+        KernelAlgorithm.ATOMIC_MAX,
+    )

    # compare
    assert torch.allclose(out_prod, out_prototype, atol=1e-3, rtol=1e-3)
@@ -8,6 +8,7 @@

import torch.nn as nn

+from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import swap_linear_layers
from torchao.prototype.float8nocompile.float8nocompile_linear import (
    Float8LinearNoCompile,
@@ -45,9 +46,9 @@ def convert_to_float8_nocompile_training(
        config = Float8LinearConfig()

    from_float = lambda m: Float8LinearNoCompile.from_float(
-        m,
-        config=config,
-        kernel_algo=kernel_algo,
+        m,
+        config=config,
+        kernel_algo=kernel_algo,
        use_activation_checkpointing=use_activation_checkpointing,
    )
    return swap_linear_layers(
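A usage sketch of the helper after this change; the model is illustrative, and the call mirrors the one in train_test.py below:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(32, 64, bias=False),
    nn.ReLU(),
    nn.Linear(64, 16, bias=False),
)
convert_to_float8_nocompile_training(
    model,
    use_activation_checkpointing=True,  # the new flag, threaded to each swapped layer
)
```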
16 changes: 8 additions & 8 deletions torchao/prototype/float8nocompile/test/train_test.py
@@ -36,13 +36,11 @@ def model2():
    return TestModel()


-@pytest.mark.parametrize(
-    "input_shape", [(16,32), (1,16,32), (2,16,32)]
-)
-@pytest.mark.parametrize(
-    "use_activation_checkpointing", [True, False]
-)
-def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, int], use_activation_checkpointing: bool):
+@pytest.mark.parametrize("input_shape", [(16, 32), (1, 16, 32), (2, 16, 32)])
+@pytest.mark.parametrize("use_activation_checkpointing", [True, False])
+def test_model_weights_and_gradients(
+    model1, model2, input_shape: tuple[int, int], use_activation_checkpointing: bool
+):
    assert torch.cuda.is_available()
    device = torch.device("cuda")
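Stacking the two parametrize marks expands to their cross-product: three input shapes times two checkpointing settings, six test cases in all.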

@@ -51,7 +49,9 @@ def test_model_weights_and_gradients(

    # compare production float8 linear conversion with no-compile version
    convert_to_float8_training(model2)
-    convert_to_float8_nocompile_training(model1, use_activation_checkpointing=use_activation_checkpointing)
+    convert_to_float8_nocompile_training(
+        model1, use_activation_checkpointing=use_activation_checkpointing
+    )

    input_tensor = torch.randn(
        *input_shape, requires_grad=True, dtype=torch.bfloat16, device=device
