[cleanup][1/x] make hp_tensor_to_float8_dynamic only work with hp inputs
Summary:

`hp_tensor_to_float8_dynamic` should only work with high precision
inputs; move the logic which checks whether the input is already in
float8 up to the callsites, to make the control flow more explicit and
easier to follow.
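
For illustration, a minimal sketch of the callsite pattern this change establishes. `cast_input_sketch` and its `linear_mm_config` / `gemm_input_role` parameters are hypothetical stand-ins, not torchao code; only the imported helpers come from the files touched by this PR.

    from torchao.float8.config import e4m3_dtype
    from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
    from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic

    def cast_input_sketch(input_tensor, linear_mm_config, gemm_input_role):
        # The "is this already float8?" check now lives at the callsite ...
        if tensor_already_casted_to_fp8(input_tensor):
            return input_tensor
        # ... so hp_tensor_to_float8_dynamic only ever sees high precision inputs.
        return hp_tensor_to_float8_dynamic(
            input_tensor,
            e4m3_dtype,
            linear_mm_config,
            gemm_input_role=gemm_input_role,
        )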

Test Plan: CI

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c7801ee07e4da620a0a2d93454710b4e8e120ee1
ghstack-comment-id: 2560319845
Pull Request resolved: #1458
vkuzo committed Jan 8, 2025
1 parent 457c5b1 commit 54144ed
Showing 4 changed files with 28 additions and 22 deletions.
17 changes: 10 additions & 7 deletions torchao/float8/float8_linear.py
@@ -312,13 +312,16 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        assert self.scaling_type_input is ScalingType.DYNAMIC
-        input_fp8 = hp_tensor_to_float8_dynamic(
-            input,
-            self.config.cast_config_input.target_dtype,
-            self.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )
+        if tensor_already_casted_to_fp8(input):
+            input_fp8 = input
+        else:
+            assert self.scaling_type_input is ScalingType.DYNAMIC
+            input_fp8 = hp_tensor_to_float8_dynamic(
+                input,
+                self.config.cast_config_input.target_dtype,
+                self.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+            )
         return input_fp8
 
     def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
2 changes: 0 additions & 2 deletions torchao/float8/float8_scaling_utils.py
@@ -52,8 +52,6 @@ def hp_tensor_to_float8_dynamic(
         scaling_granularity: Defines the scaling granularity
         axiswise_dim: if axiswise granularity is used, defines the dim to scale across
     """
-    if tensor_already_casted_to_fp8(hp_tensor):
-        return hp_tensor
     scale = tensor_to_scale(
         hp_tensor,
         float8_dtype,
27 changes: 15 additions & 12 deletions torchao/float8/float8_tensor_parallel.py
@@ -9,6 +9,7 @@
 )
 
 from torchao.float8.config import ScalingType, e4m3_dtype
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
     NoopFwToFloat8BwDynamic,
     hp_tensor_to_float8_dynamic,
@@ -46,12 +47,13 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )
 
-        input_tensor = hp_tensor_to_float8_dynamic(
-            input_tensor,
-            mod.config.cast_config_input.target_dtype,
-            mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )  # DTensor(Float8Tensor)
+        if not tensor_already_casted_to_fp8(input_tensor):
+            input_tensor = hp_tensor_to_float8_dynamic(
+                input_tensor,
+                mod.config.cast_config_input.target_dtype,
+                mod.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+            )  # DTensor(Float8Tensor)
 
         # transform the input layouts to the desired layouts of ColwiseParallel
         if input_layouts != desired_input_layouts:
@@ -104,12 +106,13 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )
 
-        input_tensor = hp_tensor_to_float8_dynamic(
-            input_tensor,
-            mod.config.cast_config_input.target_dtype,
-            mod.linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )  # DTensor(Float8Tensor)
+        if not tensor_already_casted_to_fp8(input_tensor):
+            input_tensor = hp_tensor_to_float8_dynamic(
+                input_tensor,
+                mod.config.cast_config_input.target_dtype,
+                mod.linear_mm_config,
+                gemm_input_role=GemmInputRole.INPUT,
+            )  # DTensor(Float8Tensor)
 
         if input_layouts != desired_input_layouts:
             input_tensor = input_tensor.redistribute(
4 changes: 3 additions & 1 deletion torchao/float8/stateful_float8_linear.py
@@ -153,7 +153,9 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        if self.scaling_type_input is ScalingType.DELAYED:
+        if tensor_already_casted_to_fp8(input):
+            input_fp8 = input
+        elif self.scaling_type_input is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 input,
