diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index b7a344927..4b3f271e2 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -312,13 +312,16 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)

-        assert self.scaling_type_input is ScalingType.DYNAMIC
-        input_fp8 = hp_tensor_to_float8_dynamic(
-            input,
-            self.config.cast_config_input.target_dtype,
-            self.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )
+        if tensor_already_casted_to_fp8(input):
+            input_fp8 = input
+        else:
+            assert self.scaling_type_input is ScalingType.DYNAMIC
+            input_fp8 = hp_tensor_to_float8_dynamic(
+                input,
+                self.config.cast_config_input.target_dtype,
+                self.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+            )
         return input_fp8

     def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
index 3a9841e62..0c27e4f3f 100644
--- a/torchao/float8/float8_scaling_utils.py
+++ b/torchao/float8/float8_scaling_utils.py
@@ -52,8 +52,6 @@ def hp_tensor_to_float8_dynamic(
         scaling_granularity: Defines the scaling granularity
         axiswise_dim: if axiswise granularity is used, defines the dim to scale across
     """
-    if tensor_already_casted_to_fp8(hp_tensor):
-        return hp_tensor
     scale = tensor_to_scale(
         hp_tensor,
         float8_dtype,
diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py
index 37cb67c7e..9d45196cf 100644
--- a/torchao/float8/float8_tensor_parallel.py
+++ b/torchao/float8/float8_tensor_parallel.py
@@ -9,6 +9,7 @@
 )

 from torchao.float8.config import ScalingType, e4m3_dtype
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
     NoopFwToFloat8BwDynamic,
     hp_tensor_to_float8_dynamic,
@@ -46,12 +47,13 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )

-        input_tensor = hp_tensor_to_float8_dynamic(
-            input_tensor,
-            mod.config.cast_config_input.target_dtype,
-            mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )  # DTensor(Float8Tensor)
+        if not tensor_already_casted_to_fp8(input_tensor):
+            input_tensor = hp_tensor_to_float8_dynamic(
+                input_tensor,
+                mod.config.cast_config_input.target_dtype,
+                mod.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+            )  # DTensor(Float8Tensor)

         # transform the input layouts to the desired layouts of ColwiseParallel
         if input_layouts != desired_input_layouts:
@@ -104,12 +106,13 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )

-        input_tensor = hp_tensor_to_float8_dynamic(
-            input_tensor,
-            mod.config.cast_config_input.target_dtype,
-            mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )  # DTensor(Float8Tensor)
+        if not tensor_already_casted_to_fp8(input_tensor):
+            input_tensor = hp_tensor_to_float8_dynamic(
+                input_tensor,
+                mod.config.cast_config_input.target_dtype,
+                mod.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+            )  # DTensor(Float8Tensor)

         if input_layouts != desired_input_layouts:
             input_tensor = input_tensor.redistribute(
diff --git a/torchao/float8/stateful_float8_linear.py b/torchao/float8/stateful_float8_linear.py
index 94851511b..7db72b993 100644
--- a/torchao/float8/stateful_float8_linear.py
+++ b/torchao/float8/stateful_float8_linear.py
@@ -153,7 +153,9 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)

-        if self.scaling_type_input is ScalingType.DELAYED:
+        if tensor_already_casted_to_fp8(input):
+            input_fp8 = input
+        elif self.scaling_type_input is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 input,
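
For reference, the net effect of this diff is to hoist the tensor_already_casted_to_fp8 early return out of hp_tensor_to_float8_dynamic and into each caller, so the helper always performs a real cast and the "already float8" case is handled explicitly at the call sites. Below is a minimal, self-contained sketch of that caller-side pattern; the names (FakeFp8Tensor, already_fp8, to_fp8, cast_input) are simplified stand-ins for illustration, not the real torchao APIs.

    from dataclasses import dataclass


    @dataclass
    class FakeFp8Tensor:
        # Stand-in for a tensor that is already in float8 form.
        data: list


    def already_fp8(t) -> bool:
        # Analogue of tensor_already_casted_to_fp8: detect a pre-cast input.
        return isinstance(t, FakeFp8Tensor)


    def to_fp8(t: list) -> FakeFp8Tensor:
        # Analogue of hp_tensor_to_float8_dynamic after this change: it always
        # casts and no longer silently passes pre-cast inputs through.
        return FakeFp8Tensor(data=[round(x, 2) for x in t])


    def cast_input(t):
        # Caller-side check, mirroring the new control flow in
        # Float8Linear.cast_input_to_float8: skip the cast when an upstream
        # module already produced float8.
        return t if already_fp8(t) else to_fp8(t)


    print(cast_input([0.123456, 1.0]))             # high precision in -> cast
    print(cast_input(FakeFp8Tensor([0.12, 1.0])))  # already fp8 -> passed through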