[cleanup][3/x] unify dynamic input and grad_output casting #1480

Open · wants to merge 4 commits into base: gh/vkuzo/14/head
Changes from all commits
2 changes: 1 addition & 1 deletion benchmarks/float8/profile_linear_float8.py
@@ -355,7 +355,7 @@ def main(
1, 2048, 4096, device=device, dtype=ref_dtype
).requires_grad_()
else:
M, K, N = 4096, 4096, 4096
M, K, N = 2048, 4096, 8192
m_ref = torch.nn.Sequential(
torch.nn.Linear(K, N, bias=False),
)
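Note on the hunk above: the fallback branch in main() switches from a square 4096×4096×4096 GEMM to a non-square M, K, N = 2048, 4096, 8192 shape, so the three dimensions are exercised separately. A minimal sketch of what that branch presumably sets up; the input construction and the .to(device, dtype) calls are assumptions, since they are not visible in the hunk:

```python
import torch

# Hypothetical reconstruction of the benchmark's fallback branch.
device = "cuda" if torch.cuda.is_available() else "cpu"
ref_dtype = torch.bfloat16

M, K, N = 2048, 4096, 8192  # was 4096, 4096, 4096 (square)
m_ref = torch.nn.Sequential(
    torch.nn.Linear(K, N, bias=False),
).to(device=device, dtype=ref_dtype)

x = torch.randn(M, K, device=device, dtype=ref_dtype).requires_grad_()
y = m_ref(x)  # shape (M, N) == (2048, 8192)
```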
70 changes: 9 additions & 61 deletions torchao/float8/float8_linear.py
@@ -15,7 +15,6 @@
from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_scaling_utils import (
NoopFwToFloat8BwDynamic,
get_maybe_axiswise_dim,
hp_tensor_to_float8_dynamic,
)
@@ -29,33 +28,6 @@
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor


def _cast_input_to_float8(
input: torch.Tensor,
scaling_type_input: ScalingType,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

if tensor_already_casted_to_fp8(input):
input_fp8 = input
else:
assert scaling_type_input is ScalingType.DYNAMIC
input_fp8 = hp_tensor_to_float8_dynamic(
input,
config.cast_config_input.target_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
return input_fp8


def _get_weight_scale(
weight: torch.Tensor,
scaling_type_weight: ScalingType,
@@ -85,21 +57,6 @@ def _cast_weight_to_float8_t(
return weight_fp8.t()


def _cast_output_to_float8_in_bw(
output: torch.Tensor,
scaling_type_grad_output,
linear_mm_config: LinearMMConfig,
config: Float8LinearConfig,
) -> torch.Tensor:
assert scaling_type_grad_output is ScalingType.DYNAMIC
output = NoopFwToFloat8BwDynamic.apply(
output,
linear_mm_config,
config.cast_config_grad_output.target_dtype,
)
return output


@torch._dynamo.allow_in_graph
class matmul_with_hp_or_float8_args(torch.autograd.Function):
"""
@@ -329,6 +286,14 @@ def __init__(self, *args, **kwargs):
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

has_any_axiswise_scaling = any(
cc.scaling_granularity is ScalingGranularity.AXISWISE
for cc in [
@@ -341,18 +306,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
]
)

input_maybe_fp8 = input
weight_maybe_fp8_t = self.weight.t()

# TODO(future PR): check for axiswise scaling for input, weight,
# grad_output separately instead of together
if not has_any_axiswise_scaling:
input_fp8 = _cast_input_to_float8(
input,
self.scaling_type_input,
self.config,
self.linear_mm_config,
)
# If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
# weight_scale should be saved.
weight_scale = _get_weight_scale(
@@ -375,25 +333,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
weight_scale,
)

input_maybe_fp8 = input_fp8
weight_maybe_fp8_t = weight_fp8_t

output = matmul_with_hp_or_float8_args.apply(
input_maybe_fp8,
input,
weight_maybe_fp8_t,
self.linear_mm_config,
self.config,
)

if not has_any_axiswise_scaling:
# Cast grad_output to float8_e5m2 during backward
output = _cast_output_to_float8_in_bw(
output,
self.scaling_type_grad_output,
self.linear_mm_config,
self.config,
)

if self.bias is not None:
output = output + self.bias.to(output.dtype)
return output
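Reading the diff alongside the PR title, the deleted `_cast_input_to_float8` and `_cast_output_to_float8_in_bw` wrappers (which called `hp_tensor_to_float8_dynamic` and `NoopFwToFloat8BwDynamic`) are no longer invoked from `Float8Linear.forward`; the autocast handling moves into `forward`, and the high-precision `input` is now handed straight to `matmul_with_hp_or_float8_args`, so the input and grad_output casts presumably happen inside that autograd function. The sketch below is an illustrative, torch-only stand-in for that pattern, not the torchao implementation: it round-trips the activation through float8_e4m3fn in forward and grad_output through float8_e5m2 in backward using naive per-tensor dynamic scaling.

```python
import torch


def _naive_dynamic_cast(t: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Per-tensor dynamic scaling: map the absolute max onto the largest
    # representable float8 value, quantize, then dequantize back to the
    # original dtype so the matmuls below can stay in high precision.
    amax = t.abs().max().clamp(min=1e-12)
    scale = torch.finfo(float8_dtype).max / amax
    return (t * scale).to(float8_dtype).to(t.dtype) / scale


class _CastInputFwGradOutputBw(torch.autograd.Function):
    # Schematic stand-in for matmul_with_hp_or_float8_args: cast the
    # activation on the way in and grad_output on the way back, without
    # the scaling configs, weight handling, or scaled_mm dispatch of the
    # real torchao code.

    @staticmethod
    def forward(ctx, input_hp, weight_hp_t):
        ctx.save_for_backward(input_hp, weight_hp_t)
        input_fp8 = _naive_dynamic_cast(input_hp, torch.float8_e4m3fn)
        return torch.mm(input_fp8, weight_hp_t)

    @staticmethod
    def backward(ctx, grad_output):
        input_hp, weight_hp_t = ctx.saved_tensors
        # This is the role NoopFwToFloat8BwDynamic played: no-op in the
        # forward, cast grad_output to float8_e5m2 in the backward.
        grad_output_fp8 = _naive_dynamic_cast(grad_output, torch.float8_e5m2)
        grad_input = torch.mm(grad_output_fp8, weight_hp_t.t())
        grad_weight_t = torch.mm(input_hp.t(), grad_output_fp8)
        return grad_input, grad_weight_t


# Tiny smoke test with hypothetical shapes.
x = torch.randn(4, 8, requires_grad=True)
w_t = torch.randn(8, 16, requires_grad=True)
_CastInputFwGradOutputBw.apply(x, w_t).sum().backward()
```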