diff --git a/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py
index 5d081cdbd..8581e5014 100644
--- a/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py
+++ b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py
@@ -10,16 +10,12 @@
 import torch

-from torchao.float8.config import ScalingGranularity
-from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_tensor import (
-    _ToFloat8ConstrFunc,
     Float8Tensor,
     GemmInputRole,
     LinearMMConfig,
+    _ToFloat8ConstrFunc,
 )
-from torchao.float8.float8_utils import tensor_to_scale
-
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
     triton_hp_tensor_to_float8_dynamic,
 )

@@ -94,7 +90,7 @@ class NoopFwToFloat8NoCompileBwDynamic(torch.autograd.Function):
     """
     A differentiable conversion to fp8.
     * forward: no-op
-    * backward: convert to fp8_e5m2 with tensor-wise dynamic scaling
+    * backward: convert to float8 with tensor-wise dynamic scaling
     """

     @staticmethod
@@ -110,7 +106,6 @@ def forward(

     @staticmethod
     def backward(ctx, gradY):
-        # cast grad output to e5m2 in backward pass
         fp8_tensor = triton_hp_tensor_to_float8_dynamic(
             gradY,
             ctx.target_dtype,
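
For context, the class touched by the second and third hunks implements a "no-op forward, cast-in-backward" autograd function. Below is a minimal, self-contained sketch of that pattern, assuming plain torch ops in place of the triton_hp_tensor_to_float8_dynamic kernel and a hypothetical class name NoopFwCastBwToFloat8; the tensor-wise scaling arithmetic is illustrative only, not the library's implementation.

    # Minimal sketch of the "no-op forward / cast-in-backward" pattern that
    # NoopFwToFloat8NoCompileBwDynamic implements. Plain torch ops stand in
    # for the Triton kernel used in the diff; the scaling below is illustrative.
    import torch


    class NoopFwCastBwToFloat8(torch.autograd.Function):
        """Forward: pass the tensor through unchanged.
        Backward: dynamically scale the incoming gradient and cast it to a
        float8 dtype with tensor-wise scaling (hypothetical simplification)."""

        @staticmethod
        def forward(ctx, tensor: torch.Tensor, target_dtype: torch.dtype):
            ctx.target_dtype = target_dtype
            return tensor

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            # Tensor-wise dynamic scaling: scale = fp8_max / amax(grad).
            fp8_max = torch.finfo(ctx.target_dtype).max
            amax = grad_output.abs().max().clamp(min=1e-12)
            scale = fp8_max / amax
            grad_fp8 = (grad_output * scale).to(ctx.target_dtype)
            # Dequantize back to the original dtype so autograd keeps flowing
            # in this sketch; the real code returns a Float8Tensor instead.
            grad = grad_fp8.to(grad_output.dtype) / scale
            return grad, None


    # Usage: gradients flowing through `y` are round-tripped via float8.
    x = torch.randn(16, 32, requires_grad=True)
    y = NoopFwCastBwToFloat8.apply(x, torch.float8_e5m2)
    y.sum().backward()

The forward pass stays in high precision, so only the gradient pays the quantization cost; that matches the docstring change in the second hunk, which generalizes the wording from fp8_e5m2 to "float8" since the target dtype is now carried in ctx.target_dtype.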