[cleanup][3/x] unify dynamic input and grad_output casting
Summary:

As titled, removes redundant logic for casting (input|grad_output) with dynamic scaling.
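
For context, here is a minimal simulated sketch of the pattern this converges on. It is not the torchao implementation: `_UnifiedFloat8Matmul` and `_to_float8_dynamic_sim` are hypothetical names, and the float8 cast is round-tripped back to high precision so the snippet runs without float8 gemm hardware. The point is that a single `torch.autograd.Function` owns both dynamic casts, the input in `forward` and the grad_output in `backward`, which is what makes separate `_cast_input_to_float8` / `_cast_output_to_float8_in_bw` helpers removable; in this PR that role is played by `matmul_with_hp_or_float8_args` (see the diff below).

```python
# Simulated sketch only; requires a PyTorch build with torch.float8 dtypes (>= 2.1).
import torch


def _to_float8_dynamic_sim(t: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
    # Per-tensor dynamic scaling: derive the scale from the current amax,
    # cast to float8, then cast straight back (simulation only).
    scale = torch.finfo(fp8_dtype).max / t.abs().amax().clamp(min=1e-12)
    return (t * scale).to(fp8_dtype).to(t.dtype) / scale


class _UnifiedFloat8Matmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_hp, weight_t_hp):
        ctx.save_for_backward(input_hp, weight_t_hp)
        # dynamic cast of the input lives here, not in the module's forward
        input_fp8 = _to_float8_dynamic_sim(input_hp, torch.float8_e4m3fn)
        return input_fp8 @ weight_t_hp

    @staticmethod
    def backward(ctx, grad_output):
        input_hp, weight_t_hp = ctx.saved_tensors
        # dynamic cast of grad_output lives here, not in a separate
        # noop-forward / cast-in-backward wrapper applied by the module
        grad_output_fp8 = _to_float8_dynamic_sim(grad_output, torch.float8_e5m2)
        grad_input = grad_output_fp8 @ weight_t_hp.t()
        grad_weight_t = input_hp.t() @ grad_output_fp8
        return grad_input, grad_weight_t


if __name__ == "__main__":
    x = torch.randn(16, 32, requires_grad=True)
    w_t = torch.randn(32, 64, requires_grad=True)
    _UnifiedFloat8Matmul.apply(x, w_t).sum().backward()
```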

Test Plan:

```
./test/float8/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 74536a831b0f7425f342b44e6be6f2136a235db0
ghstack-comment-id: 2568319054
Pull Request resolved: #1480
vkuzo committed Jan 8, 2025
1 parent 7c71469 commit f99bd4b
Showing 2 changed files with 10 additions and 62 deletions.
benchmarks/float8/profile_linear_float8.py (1 addition, 1 deletion)

```diff
@@ -355,7 +355,7 @@ def main(
             1, 2048, 4096, device=device, dtype=ref_dtype
         ).requires_grad_()
     else:
-        M, K, N = 4096, 4096, 4096
+        M, K, N = 2048, 4096, 8192
         m_ref = torch.nn.Sequential(
             torch.nn.Linear(K, N, bias=False),
         )
```
torchao/float8/float8_linear.py (9 additions, 61 deletions)

```diff
@@ -15,7 +15,6 @@
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8BwDynamic,
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_dynamic,
 )
@@ -29,33 +28,6 @@
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
 
-def _cast_input_to_float8(
-    input: torch.Tensor,
-    scaling_type_input: ScalingType,
-    config: Float8LinearConfig,
-    linear_mm_config: LinearMMConfig,
-) -> torch.Tensor:
-    # Duplicate the autocast logic for F.linear, so that the output
-    # of our module has the right original precision
-    if torch.is_autocast_enabled():
-        # For now, hardcode to GPU's autocast dtype
-        # if we need CPU support in the future, we can add it
-        autocast_dtype = torch.get_autocast_gpu_dtype()
-        input = input.to(autocast_dtype)
-
-    if tensor_already_casted_to_fp8(input):
-        input_fp8 = input
-    else:
-        assert scaling_type_input is ScalingType.DYNAMIC
-        input_fp8 = hp_tensor_to_float8_dynamic(
-            input,
-            config.cast_config_input.target_dtype,
-            linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )
-    return input_fp8
-
-
 def _get_weight_scale(
     weight: torch.Tensor,
     scaling_type_weight: ScalingType,
@@ -85,21 +57,6 @@ def _cast_weight_to_float8_t(
     return weight_fp8.t()
 
 
-def _cast_output_to_float8_in_bw(
-    output: torch.Tensor,
-    scaling_type_grad_output,
-    linear_mm_config: LinearMMConfig,
-    config: Float8LinearConfig,
-) -> torch.Tensor:
-    assert scaling_type_grad_output is ScalingType.DYNAMIC
-    output = NoopFwToFloat8BwDynamic.apply(
-        output,
-        linear_mm_config,
-        config.cast_config_grad_output.target_dtype,
-    )
-    return output
-
-
 @torch._dynamo.allow_in_graph
 class matmul_with_hp_or_float8_args(torch.autograd.Function):
     """
@@ -329,6 +286,14 @@ def __init__(self, *args, **kwargs):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # Duplicate the autocast logic for F.linear, so that the output
+        # of our module has the right original precision
+        if torch.is_autocast_enabled():
+            # For now, hardcode to GPU's autocast dtype
+            # if we need CPU support in the future, we can add it
+            autocast_dtype = torch.get_autocast_gpu_dtype()
+            input = input.to(autocast_dtype)
+
         has_any_axiswise_scaling = any(
             cc.scaling_granularity is ScalingGranularity.AXISWISE
             for cc in [
@@ -341,18 +306,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             ]
         )
 
-        input_maybe_fp8 = input
         weight_maybe_fp8_t = self.weight.t()
 
         # TODO(future PR): check for axiswise scaling for input, weight,
         # grad_output separately instead of together
         if not has_any_axiswise_scaling:
-            input_fp8 = _cast_input_to_float8(
-                input,
-                self.scaling_type_input,
-                self.config,
-                self.linear_mm_config,
-            )
             # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
             # weight_scale should be saved.
             weight_scale = _get_weight_scale(
@@ -375,25 +333,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 weight_scale,
             )
 
-            input_maybe_fp8 = input_fp8
             weight_maybe_fp8_t = weight_fp8_t
 
         output = matmul_with_hp_or_float8_args.apply(
-            input_maybe_fp8,
+            input,
             weight_maybe_fp8_t,
             self.linear_mm_config,
             self.config,
         )
 
-        if not has_any_axiswise_scaling:
-            # Cast grad_output to float8_e5m2 during backward
-            output = _cast_output_to_float8_in_bw(
-                output,
-                self.scaling_type_grad_output,
-                self.linear_mm_config,
-                self.config,
-            )
-
         if self.bias is not None:
             output = output + self.bias.to(output.dtype)
         return output
```
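
As a hedged usage sketch of the end state (`convert_to_float8_training` is torchao's standard conversion entry point, but the shapes are taken from the benchmark above and an actual run needs a GPU with float8 matmul support): `Float8Linear.forward` now only applies the autocast dtype itself and passes the high-precision input to `matmul_with_hp_or_float8_args`, which performs the dynamic float8 casts for the input in forward and for grad_output in backward.

```python
# Sketch, not part of this PR: assumes torchao's convert_to_float8_training
# entry point and a GPU with float8 matmul support (e.g. H100 / SM89+).
import torch
from torchao.float8 import convert_to_float8_training

m = torch.nn.Sequential(torch.nn.Linear(4096, 8192, bias=False)).cuda()
convert_to_float8_training(m)  # swaps nn.Linear for Float8Linear

x = torch.randn(2048, 4096, device="cuda", requires_grad=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
    y = m(x)  # input: fp32 -> bf16 (autocast) -> float8, inside the matmul op
y.sum().backward()  # grad_output is cast to float8 dynamically in backward
```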
