address comments
danielvegamyhre committed Dec 20, 2024
1 parent aaf7ed8 commit b8e1c6a
Showing 1 changed file with 2 additions and 7 deletions.
@@ -10,16 +10,12 @@

import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_tensor import (
_ToFloat8ConstrFunc,
Float8Tensor,
GemmInputRole,
LinearMMConfig,
_ToFloat8ConstrFunc,
)
from torchao.float8.float8_utils import tensor_to_scale

from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
triton_hp_tensor_to_float8_dynamic,
)
@@ -94,7 +90,7 @@ class NoopFwToFloat8NoCompileBwDynamic(torch.autograd.Function):
"""
A differentiable conversion to fp8.
* forward: no-op
* backward: convert to fp8_e5m2 with tensor-wise dynamic scaling
* backward: convert to float8 with tensor-wise dynamic scaling
"""

@staticmethod
@@ -110,7 +106,6 @@ def forward(

@staticmethod
def backward(ctx, gradY):
# cast grad output to e5m2 in backward pass
fp8_tensor = triton_hp_tensor_to_float8_dynamic(
gradY,
ctx.target_dtype,
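For context, the docstring touched in this diff describes an autograd function that is a no-op in forward and casts the incoming gradient to float8 with tensor-wise dynamic scaling in backward. The sketch below is an illustrative stand-in, not the torchao implementation: the class name NoopFwToFp8BwSketch and the plain-PyTorch scale/cast are assumptions that replace the triton_hp_tensor_to_float8_dynamic kernel call seen in the diff, and the torchao class presumably returns the resulting Float8Tensor rather than round-tripping back to high precision as done here.

# Minimal sketch (assumed, not the torchao code) of the pattern described by the
# docstring above: no-op forward, gradient cast to float8 with tensor-wise
# dynamic scaling in backward. Requires a PyTorch build with float8 dtypes.
import torch


class NoopFwToFp8BwSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, target_dtype: torch.dtype):
        # Forward is a no-op; only remember which float8 dtype to use in backward.
        ctx.target_dtype = target_dtype
        return tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Tensor-wise dynamic scaling: one scale derived from the gradient's amax.
        amax = grad_output.abs().max()
        fp8_max = torch.finfo(ctx.target_dtype).max
        scale = fp8_max / torch.clamp(amax, min=1e-12)
        grad_fp8 = (grad_output * scale).clamp(-fp8_max, fp8_max).to(ctx.target_dtype)
        # For this sketch, dequantize back to the original dtype so autograd keeps
        # flowing; the class in this diff instead builds a Float8Tensor via
        # triton_hp_tensor_to_float8_dynamic. No gradient for target_dtype.
        return grad_fp8.to(grad_output.dtype) / scale, None


# Example usage:
#   x = torch.randn(8, 8, requires_grad=True)
#   y = NoopFwToFp8BwSketch.apply(x, torch.float8_e5m2)
#   y.sum().backward()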
