float8nocompile: wrap float8nocompile conversion kernels in autograd func
Daniel Vega-Myhre committed Dec 19, 2024
1 parent e474839 commit 2daef80
Showing 2 changed files with 70 additions and 19 deletions.
26 changes: 9 additions & 17 deletions torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -15,16 +15,12 @@
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_float8
 from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
-from torchao.float8.float8_tensor import (
-    GemmInputRole,
-    hp_tensor_and_scale_to_float8,
-    LinearMMConfig,
-    ScaledMMConfig,
-)
+from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
 from torchao.float8.float8_utils import tensor_to_scale
 
 from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
-    hp_tensor_to_float8nocompile_dynamic,
+    Float8NoCompileConversionFunc,
+    NoopFwToFloat8NoCompileBwDynamic,
 )


@@ -72,7 +68,6 @@ def __init__(self, *args, **kwargs):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # TODO(danielvegamyhre): replace conversions with triton kernels
         # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
         input_fp8 = self.cast_input_to_float8(input)
         weight_fp8_t = self.cast_weight_to_float8_t(self.weight)
@@ -92,34 +87,31 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        # TODO(danielvegamyhre): implement this fn in scaling_utils with call to triton kernel
-        return hp_tensor_to_float8nocompile_dynamic(
+        return Float8NoCompileConversionFunc.apply(
             input,
             self.config.cast_config_input.target_dtype,
             self.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
+            GemmInputRole.INPUT,
         )
 
     def cast_weight_to_float8_t(
         self,
         weight: torch.Tensor,
     ) -> torch.Tensor:
-        # TODO(danielvegamyhre): replace conversion with triton kernel
-        weight_fp8 = hp_tensor_to_float8nocompile_dynamic(
+        weight_fp8 = Float8NoCompileConversionFunc.apply(
             weight,
             self.config.cast_config_weight.target_dtype,
             self.linear_mm_config,
-            gemm_input_role=GemmInputRole.WEIGHT,
+            GemmInputRole.WEIGHT,
         )
         return weight_fp8.t()
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         # casts grad_output to float8_e5m2 for backward
-        # TODO(danielvegamyhre): replace conversion with triton kernel
-        return NoopFwToFloat8BwDynamic.apply(
+        return NoopFwToFloat8NoCompileBwDynamic.apply(
             output,
-            self.linear_mm_config,
             self.config.cast_config_grad_output.target_dtype,
+            self.linear_mm_config,
         )
 
     @classmethod
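Note on the call sites above: the GemmInputRole is now passed positionally rather than as a keyword, presumably because torch.autograd.Function.apply has traditionally accepted positional arguments only. Below is a minimal toy sketch of the wrapping pattern, not the torchao implementation: the class name is made up, the bfloat16 cast stands in for the Triton fp8 kernel, and the backward casts the gradient back to the input dtype only so the toy runs on plain tensors (the functions in this commit pass the gradient through unchanged).

# Toy sketch (not the torchao implementation): cast in forward,
# pass the gradient through in backward.
import torch


class CastForwardNoopBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, target_dtype: torch.dtype):
        # Stand-in for triton_hp_tensor_to_float8_dynamic.
        ctx.input_dtype = tensor.dtype
        return tensor.to(target_dtype)

    @staticmethod
    def backward(ctx, g):
        # One gradient slot per forward input; the dtype argument gets None.
        # Casting back to the input dtype is only so this toy runs on plain
        # tensors; the commit's functions return the gradient unchanged.
        return g.to(ctx.input_dtype), None


x = torch.randn(16, 32, requires_grad=True)
x_lp = CastForwardNoopBackward.apply(x, torch.bfloat16)  # positional args only
x_lp.float().sum().backward()
assert x.grad.dtype == torch.float32  # gradient flowed straight through the cast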
63 changes: 61 additions & 2 deletions torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py
@@ -8,8 +8,6 @@
 Utilities for scaling high precision tensors to float8.
 """
 
-from typing import Optional
-
 import torch
 
 from torchao.float8.config import ScalingGranularity
@@ -22,6 +20,10 @@
 )
 from torchao.float8.float8_utils import tensor_to_scale
 
+from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
+    triton_hp_tensor_to_float8_dynamic,
+)
+
 # avoid division by zero when calculating scale
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
@@ -59,3 +61,60 @@ def hp_tensor_to_float8nocompile_dynamic(
         gemm_input_role,
         None,
     )
+
+
+class Float8NoCompileConversionFunc(torch.autograd.Function):
+    """
+    A differentiable conversion to fp8.
+    * forward: convert from high precision to float8
+    * backward: pass the gradient without changes
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        float8_dtype: torch.dtype,
+        linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
+    ):
+        return triton_hp_tensor_to_float8_dynamic(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+        )
+
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None, None
+
+
+class NoopFwToFloat8NoCompileBwDynamic(torch.autograd.Function):
+    """
+    A differentiable conversion to fp8.
+    * forward: no-op
+    * backward: convert to fp8_e5m2 with tensor-wise dynamic scaling
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        float8_dtype: torch.dtype,
+        linear_mm_config: LinearMMConfig,
+    ):
+        ctx.linear_mm_config = linear_mm_config
+        ctx.target_dtype = float8_dtype
+        return tensor
+
+    @staticmethod
+    def backward(ctx, gradY):
+        # cast grad output to e5m2 in backward pass
+        fp8_tensor = triton_hp_tensor_to_float8_dynamic(
+            gradY,
+            ctx.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
+        )
+        return fp8_tensor, None, None
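For completeness, a toy counterpart to the no-op-forward / cast-in-backward pattern above, again a sketch rather than the torchao code and with a made-up class name: the forward is the identity and the quantization happens only in backward. Here the gradient round-trips through bfloat16 (standing in for float8_e5m2) and is restored to its original dtype so plain autograd accepts it; the real backward instead returns the fp8 tensor produced by triton_hp_tensor_to_float8_dynamic.

# Toy sketch (not the torchao implementation): no-op forward,
# quantize the gradient in backward.
import torch


class NoopForwardCastBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, target_dtype: torch.dtype):
        ctx.target_dtype = target_dtype
        return tensor  # forward is a no-op, as in the commit

    @staticmethod
    def backward(ctx, grad_output):
        # Quantize the incoming gradient; restore the original dtype so the
        # rest of this plain-tensor graph accepts it.
        low_precision = grad_output.to(ctx.target_dtype)
        return low_precision.to(grad_output.dtype), None


w = torch.randn(4, 4, requires_grad=True)
h = w * 1.0  # a non-leaf activation, as in the linear forward
out = NoopForwardCastBackward.apply(h, torch.bfloat16)
out.sum().backward()  # w.grad passed through the simulated low-precision cast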
