support for activation checkpointing in float8nocompile
ghstack-source-id: f5d90a779adc52d4529356e22a32a39467df16fa
ghstack-comment-id: 2576321936
Pull Request resolved: #1517
danielvegamyhre committed Jan 8, 2025
1 parent 878c886 commit 715b728
Showing 4 changed files with 234 additions and 60 deletions.
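
For context, here is a minimal usage sketch of the new flag (not part of this commit). The class name `Float8LinearNoCompile` and the CUDA requirement are assumptions; the `from_float` signature, config, and kernel-algorithm import follow the diff below.

# Hypothetical usage sketch. Assumptions: the prototype linear class is named
# Float8LinearNoCompile (not visible in this hunk) and a CUDA device is
# available for the triton cast kernels.
import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.prototype.float8nocompile.float8nocompile_linear import (
    Float8LinearNoCompile,  # assumed class name
)
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
    KernelAlgorithm,
)

linear = nn.Linear(16, 32, bias=False, dtype=torch.bfloat16, device="cuda")
fp8_linear = Float8LinearNoCompile.from_float(
    linear,
    config=Float8LinearConfig(),
    kernel_algo=KernelAlgorithm.ATOMIC_MAX,
    use_activation_checkpointing=True,  # new flag added in this commit
)

x = torch.randn(8, 16, dtype=torch.bfloat16, device="cuda", requires_grad=True)
out = fp8_linear(x)
out.sum().backward()  # fp8 casts of the saved hp tensors are redone in backward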
192 changes: 171 additions & 21 deletions torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -9,13 +9,15 @@
"""

import torch
from torch.utils.checkpoint import checkpoint

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
ToFP8ColumnMajor,
ToFP8ColumnMajorT,
ToFP8RowAndColumnMajor,
ToFP8RowMajor,
ToFP8RowMajorTAndNonT,
)
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
@@ -36,51 +38,67 @@ def __init__(self, *args, **kwargs):
Additional arguments on top of `torch.nn.Linear`'s arguments:
* `config`: Float8LinearConfig
"""
config = kwargs.pop("config")
kernel_algo = kwargs.pop("kernel_algo")
emulate = config.emulate
self.config = kwargs.pop("config")
self.kernel_algo = kwargs.pop("kernel_algo")
self.use_activation_checkpointing = kwargs.pop(
"use_activation_checkpointing", False
)
super().__init__(*args, **kwargs)

self.config = config
self.kernel_algo = kernel_algo

self.linear_mm_config = LinearMMConfig(
# output
ScaledMMConfig(
emulate,
self.config.emulate,
self.config.gemm_config_output.use_fast_accum,
False,
self.config.pad_inner_dim,
),
# grad_input
ScaledMMConfig(
emulate,
self.config.emulate,
self.config.gemm_config_grad_input.use_fast_accum,
False,
self.config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
emulate,
self.config.emulate,
self.config.gemm_config_grad_weight.use_fast_accum,
False,
self.config.pad_inner_dim,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
# TODO(danielvegamyhre): support for FSDP once dependencies are implemented
output = matmul_with_args_in_hp.apply(
input,
self.weight,
self.config,
self.linear_mm_config,
self.kernel_algo,
)
if self.use_activation_checkpointing:
output = checkpoint(
matmul_with_args_in_hp.apply,
input,
self.weight,
self.config,
self.linear_mm_config,
self.kernel_algo,
self.use_activation_checkpointing,
)
else:
output = matmul_with_args_in_hp.apply(
input,
self.weight,
self.config,
self.linear_mm_config,
self.kernel_algo,
self.use_activation_checkpointing,
)
return output

@classmethod
def from_float(cls, mod, config: Float8LinearConfig, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
def from_float(
cls,
mod,
config: Float8LinearConfig,
kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
use_activation_checkpointing: bool = False,
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -95,6 +113,7 @@ def from_float(cls, mod, config: Float8LinearConfig, kernel_algo: KernelAlgorith
bias=False,
config=config,
kernel_algo=kernel_algo,
use_activation_checkpointing=use_activation_checkpointing,
)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
@@ -105,7 +124,40 @@ def from_float(cls, mod, config: Float8LinearConfig, kernel_algo: KernelAlgorith

class matmul_with_args_in_hp(torch.autograd.Function):
@staticmethod
def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
def forward(
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
kernel_algo: KernelAlgorithm,
use_activation_checkpointing: bool,
):
if use_activation_checkpointing:
return matmul_with_args_in_hp._forward_with_ac(
ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo
)
else:
return matmul_with_args_in_hp._forward_no_ac(
ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo
)

@staticmethod
def backward(ctx, grad_output):
if ctx.use_activation_checkpointing:
return matmul_with_args_in_hp._backward_with_ac(ctx, grad_output)
else:
return matmul_with_args_in_hp._backward_no_ac(ctx, grad_output)

@staticmethod
def _forward_no_ac(
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
kernel_algo: KernelAlgorithm,
):
# reshape to be 2D for triton kernels
orig_input_shape = input_hp.shape
input_hp = input_hp.reshape(-1, input_hp.shape[-1])
@@ -132,13 +184,14 @@ def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
ctx.config = config
ctx.linear_mm_config = linear_mm_config
ctx.kernel_algo = kernel_algo
ctx.use_activation_checkpointing = False

# reshape back to expected dims
output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
return output

@staticmethod
def backward(ctx, grad_output):
def _backward_no_ac(ctx, grad_output):
# grad_output may not be contiguous in cases like:
# output.sum().backward() where grad is all 1s, so the (M,N) view of the scalar "1"
# results in a non-contiguous tensor with stride (0,0).
@@ -172,15 +225,112 @@ def backward(ctx, grad_output):
)
grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major)

# reshape grad input to match original shape
grad_input = grad_input.reshape(
*orig_grad_output_shape[:-1], grad_input.shape[-1]
)

# grad_weight = grad_output_t @ input
# apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
# source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)

# grad input shape
return grad_input, grad_weight, None, None, None, None

@staticmethod
def _forward_with_ac(
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
kernel_algo: KernelAlgorithm,
):
# reshape to be 2D for triton kernels
orig_input_shape = input_hp.shape
input_hp = input_hp.reshape(-1, input_hp.shape[-1])

# output = input @ weight_t
input_fp8_row_major = ToFP8RowMajor.apply(
input_hp,
config.cast_config_input.target_dtype,
linear_mm_config,
GemmInputRole.INPUT,
kernel_algo,
)
weight_t_fp8_col_major = ToFP8ColumnMajorT.apply(
weight_hp,
config.cast_config_weight.target_dtype,
linear_mm_config,
GemmInputRole.WEIGHT,
kernel_algo,
)
output = torch.mm(input_fp8_row_major, weight_t_fp8_col_major)

# with AC we will save only the original hp input tensor and weight for backward,
# and do the necessary fp8 conversions during the backward pass.
ctx.save_for_backward(input_hp, weight_hp)
ctx.config = config
ctx.linear_mm_config = linear_mm_config
ctx.kernel_algo = kernel_algo
ctx.use_activation_checkpointing = True

# reshape back to expected dims
output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
return output

@staticmethod
def _backward_with_ac(ctx, grad_output):
# grad_output may not be contiguous in cases like:
# output.sum().backward() where grad is all 1s, so the (M,N) view of the scalar "1"
# results in a non-contiguous tensor with stride (0,0).
if not grad_output.is_contiguous():
grad_output = grad_output.contiguous()

input_hp, weight_hp = ctx.saved_tensors

# reshape to be 2D for triton kernels
orig_grad_output_shape = grad_output.shape
grad_output = grad_output.reshape(-1, grad_output.shape[-1])

# cast grad output to float8_e5m2 for backward
grad_output_fp8_row_major, grad_output_t_row_major = (
ToFP8RowMajorTAndNonT.apply(
grad_output,
ctx.config.cast_config_grad_output.target_dtype,
ctx.linear_mm_config,
GemmInputRole.GRAD_OUTPUT,
ctx.kernel_algo,
)
)

# grad_input = grad_output @ weight
weight_fp8_col_major = ToFP8ColumnMajor.apply(
weight_hp,
ctx.config.cast_config_weight.target_dtype,
ctx.linear_mm_config,
GemmInputRole.WEIGHT,
ctx.kernel_algo,
)
grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major)

# reshape grad input to match original shape
grad_input = grad_input.reshape(
*orig_grad_output_shape[:-1], grad_input.shape[-1]
)

# grad_weight = grad_output_t @ input
# apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
# source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
input_fp8_col_major = ToFP8ColumnMajor.apply(
input_hp,
ctx.config.cast_config_input.target_dtype,
ctx.linear_mm_config,
GemmInputRole.INPUT,
ctx.kernel_algo,
)
grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)

# grad input shape
return grad_input, grad_weight, None, None, None
return grad_input, grad_weight, None, None, None, None
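
In the changes above, `forward` wraps `matmul_with_args_in_hp.apply` in `torch.utils.checkpoint.checkpoint` when the flag is set, and the autograd function itself takes the `_forward_with_ac` / `_backward_with_ac` path, which saves only the high-precision tensors and redoes the fp8 casts during backward. For reference, a minimal standalone sketch of `checkpoint`'s recompute-in-backward behavior (independent of this diff; passing `use_reentrant=False` is a choice made for this example, not something the commit does):

# Standalone illustration of torch.utils.checkpoint, not taken from this diff:
# intermediates of `block` are discarded after the forward pass and recomputed
# when backward() needs them, trading extra compute for activation memory.
import torch
from torch.utils.checkpoint import checkpoint


def block(x, w):
    return torch.relu(x @ w.t())


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
y = checkpoint(block, x, w, use_reentrant=False)
y.sum().backward()
assert x.grad is not None and w.grad is not None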
83 changes: 50 additions & 33 deletions torchao/prototype/float8nocompile/float8nocompile_linear_test.py
@@ -1,67 +1,84 @@
import pytest

import torch
from torch.autograd.function import FunctionCtx
from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig
from torch.autograd import gradcheck
from torchao.prototype.float8nocompile.float8nocompile_linear import (
matmul_with_args_in_hp,
)
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
KernelAlgorithm,
)

from torchao.prototype.float8nocompile.float8nocompile_linear import matmul_with_args_in_hp
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import KernelAlgorithm

# unit test comparing the two implementations
@pytest.mark.parametrize(
"input_shape",
[(32, 16), (1,32,16), (2,32,16)],
[(32, 16), (1, 32, 16), (2, 32, 16)],
)
def test_matmul_with_args_in_hp(input_shape: tuple[int, int]):
assert torch.cuda.is_available()
device = "cuda"

# high precision inputs
input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device, requires_grad=True)
input_bf16 = torch.randn(
input_shape, dtype=torch.bfloat16, device=device, requires_grad=True
)
x_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
y_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)

# high precision weights
# nn.Linear stores weights in transposed form
weight_bf16 = torch.randn((32, input_bf16.shape[-1]), dtype=torch.bfloat16, device=device, requires_grad=True)
weight_bf16 = torch.randn(
(32, input_bf16.shape[-1]),
dtype=torch.bfloat16,
device=device,
requires_grad=True,
)
x_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
y_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)

# default configs
config = Float8LinearConfig()
emulate = False
linear_mm_config = LinearMMConfig(
# output
ScaledMMConfig(
emulate,
config.gemm_config_output.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_input
ScaledMMConfig(
emulate,
config.gemm_config_grad_input.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
emulate,
config.gemm_config_grad_weight.use_fast_accum,
False,
config.pad_inner_dim,
),
)
# output
ScaledMMConfig(
emulate,
config.gemm_config_output.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_input
ScaledMMConfig(
emulate,
config.gemm_config_grad_input.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
emulate,
config.gemm_config_grad_weight.use_fast_accum,
False,
config.pad_inner_dim,
),
)

# prod forward. expects transposed weight.
out_prod = manual_float8_matmul_with_args_in_hp.apply(x_input_bf16, x_weight_bf16.t(), linear_mm_config, config)
out_prod = manual_float8_matmul_with_args_in_hp.apply(
x_input_bf16, x_weight_bf16.t(), linear_mm_config, config
)

# prototype forward. expects non-transposed weight
out_prototype = matmul_with_args_in_hp.apply(y_input_bf16, y_weight_bf16, config, linear_mm_config, KernelAlgorithm.ATOMIC_MAX)
out_prototype = matmul_with_args_in_hp.apply(
y_input_bf16,
y_weight_bf16,
config,
linear_mm_config,
KernelAlgorithm.ATOMIC_MAX,
)

# compare
assert torch.allclose(out_prod, out_prototype, atol=1e-3, rtol=1e-3)
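
A hypothetical follow-up check, not taken from the visible diff: drive the prototype autograd function through its new sixth argument (use_activation_checkpointing) and confirm the forward output matches the non-checkpointed path, reusing the tensors, config, and tolerances defined in the test above.

# Hypothetical: exercise the activation-checkpointing path of the prototype function.
out_prototype_ac = matmul_with_args_in_hp.apply(
    y_input_bf16,
    y_weight_bf16,
    config,
    linear_mm_config,
    KernelAlgorithm.ATOMIC_MAX,
    True,  # use_activation_checkpointing
)
assert torch.allclose(out_prototype, out_prototype_ac, atol=1e-3, rtol=1e-3)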
