[cleanup][4/x] unify weight casting #1481

Open
wants to merge 6 commits into base: gh/vkuzo/15/head

Changes from all commits

106 changes: 35 additions & 71 deletions torchao/float8/float8_linear.py
@@ -28,35 +28,6 @@
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor


def _get_weight_scale(
    weight: torch.Tensor,
    scaling_type_weight: ScalingType,
    config: Float8LinearConfig,
) -> Optional[torch.Tensor]:
    if tensor_already_casted_to_fp8(weight):
        return None
    assert scaling_type_weight is ScalingType.DYNAMIC
    return tensor_to_scale(weight, config.cast_config_weight.target_dtype)


def _cast_weight_to_float8_t(
    weight: torch.Tensor,
    config: Float8LinearConfig,
    linear_mm_config: LinearMMConfig,
    weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if tensor_already_casted_to_fp8(weight):
        return weight.t()
    weight_fp8 = hp_tensor_and_scale_to_float8(
        weight,
        weight_scale,
        config.cast_config_weight.target_dtype,
        linear_mm_config,
        gemm_input_role=GemmInputRole.WEIGHT,
    )
    return weight_fp8.t()
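
For context: the two helpers above, which this cleanup folds into matmul_with_hp_or_float8_args, split tensorwise scaling into "compute the scale" and "cast using a precomputed scale", so that only the cast has to be recomputed under activation checkpointing. Below is a minimal, self-contained sketch of that amax-based scale-then-cast pattern; toy_tensorwise_scale and toy_cast_with_scale are illustrative stand-ins, not the torchao implementations of tensor_to_scale / hp_tensor_and_scale_to_float8.

# Illustrative sketch only, not torchao code: tensorwise float8 scaling split into
# "compute the scale" and "cast with a precomputed scale".
import torch

def toy_tensorwise_scale(x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # map the tensor's global absolute max onto the representable max of the fp8 dtype
    amax = x.abs().max().float()
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)

def toy_cast_with_scale(x: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # scale, saturate to the fp8 range, then convert to the fp8 dtype
    fmax = torch.finfo(float8_dtype).max
    return (x.float() * scale).clamp(min=-fmax, max=fmax).to(float8_dtype)

w = torch.randn(32, 64)
s = toy_tensorwise_scale(w, torch.float8_e4m3fn)
w_fp8 = toy_cast_with_scale(w, s, torch.float8_e4m3fn)
print(w_fp8.dtype, s.item())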


@torch._dynamo.allow_in_graph
class matmul_with_hp_or_float8_args(torch.autograd.Function):
    """
@@ -102,6 +73,40 @@ def forward(
            weight_maybe_fp8_t = weight_hp_t
        elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
            weight_maybe_fp8_t = weight_hp_t
        elif (
            config.cast_config_weight.scaling_granularity
            is ScalingGranularity.TENSORWISE
        ):
            # Special case tensorwise to allow the checkpointing of float8
            # casted weight, to prevent blowing up peak memory usage in FSDP.

            # inductor kernels for tensorwise max are faster when `weight` is
            # contiguous.
            # context: https://github.com/pytorch/pytorch/issues/144431
            weight_hp_t_t = weight_hp_t.t()
            assert weight_hp_t_t.is_contiguous()
            weight_scale = tensor_to_scale(
                weight_hp_t_t, config.cast_config_weight.target_dtype
            )

            if config.force_recompute_fp8_weight_in_bwd:
                weight_maybe_fp8_t = checkpoint.checkpoint(
                    hp_tensor_and_scale_to_float8,
                    weight_hp_t,
                    weight_scale,
                    config.cast_config_weight.target_dtype,
                    linear_mm_config,
                    GemmInputRole.WEIGHT,
                )
            else:
                weight_maybe_fp8_t = hp_tensor_and_scale_to_float8(
                    weight_hp_t,
                    weight_scale,
                    config.cast_config_weight.target_dtype,
                    linear_mm_config,
                    GemmInputRole.WEIGHT,
                )

        else:
            weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
                weight_hp_t,
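
A note on the contiguity trick in the hunk above: a tensorwise amax is layout-invariant, so computing the scale from weight_hp_t.t() (the contiguous original weight) rather than from the non-contiguous transposed view changes nothing numerically, while letting inductor emit a faster reduction kernel (pytorch/pytorch#144431). A quick self-contained check of that invariance, with illustrative shapes:

# Illustrative check, not part of the PR: a global amax does not depend on memory
# layout, so the scale can be computed from whichever layout is contiguous.
import torch

w = torch.randn(1024, 2048)   # stands in for the high-precision weight
w_t = w.t()                   # the transposed view the matmul consumes

assert w.is_contiguous() and not w_t.is_contiguous()
# identical amax either way, hence an identical tensorwise scale
assert torch.equal(w.abs().max(), w_t.abs().max())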
@@ -294,50 +299,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
            autocast_dtype = torch.get_autocast_gpu_dtype()
            input = input.to(autocast_dtype)

        has_any_axiswise_scaling = any(
            cc.scaling_granularity is ScalingGranularity.AXISWISE
            for cc in [
                self.config.cast_config_input,
                self.config.cast_config_weight,
                self.config.cast_config_grad_output,
                self.config.cast_config_input_for_grad_weight,
                self.config.cast_config_weight_for_grad_input,
                self.config.cast_config_grad_output_for_grad_weight,
            ]
        )

        weight_maybe_fp8_t = self.weight.t()

        # TODO(future PR): check for axiswise scaling for input, weight,
        # grad_output separately instead of together
        if not has_any_axiswise_scaling:
            # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
            # weight_scale should be saved.
            weight_scale = _get_weight_scale(
                self.weight, self.scaling_type_weight, self.config
            )

            if self.config.force_recompute_fp8_weight_in_bwd:
                weight_fp8_t = checkpoint.checkpoint(
                    _cast_weight_to_float8_t,
                    self.weight,
                    self.config,
                    self.linear_mm_config,
                    weight_scale,
                )
            else:
                weight_fp8_t = _cast_weight_to_float8_t(
                    self.weight,
                    self.config,
                    self.linear_mm_config,
                    weight_scale,
                )

            weight_maybe_fp8_t = weight_fp8_t

        output = matmul_with_hp_or_float8_args.apply(
            input,
            weight_maybe_fp8_t,
            self.weight.t(),
            self.linear_mm_config,
            self.config,
        )
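
Lastly, a sketch of the force_recompute_fp8_weight_in_bwd behavior that this PR moves into matmul_with_hp_or_float8_args: wrapping the cast in torch.utils.checkpoint means the float8 copy of the weight is rebuilt from the high-precision weight and the saved scale during backward, instead of being kept alive, trading a small amount of recompute for lower peak memory under FSDP. The sketch uses only core PyTorch and a stand-in quantize-dequantize function (toy_fake_cast); it is not the torchao implementation, which produces a real Float8Tensor rather than emulating the cast in high precision.

# Illustrative sketch only: recompute-the-cast-in-backward via activation checkpointing.
import torch
from torch.utils import checkpoint

def toy_fake_cast(w_hp: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # emulate a float8 round trip in high precision so the sketch stays differentiable;
    # the real kernel returns a true float8 tensor
    fmax = torch.finfo(torch.float8_e4m3fn).max
    return (w_hp * scale).clamp(-fmax, fmax) / scale

def maybe_checkpointed_cast(w_hp: torch.Tensor, scale: torch.Tensor, force_recompute: bool) -> torch.Tensor:
    if force_recompute:
        # only the high-precision weight and the scale are saved for backward;
        # the cast tensor is rematerialized when gradients flow
        return checkpoint.checkpoint(toy_fake_cast, w_hp, scale, use_reentrant=False)
    return toy_fake_cast(w_hp, scale)

w = torch.randn(256, 512, requires_grad=True)
s = torch.finfo(torch.float8_e4m3fn).max / w.detach().abs().max()
out = maybe_checkpointed_cast(w, s, force_recompute=True)
out.sum().backward()
print(w.grad.shape)

With the cast unified inside the autograd function, the has_any_axiswise_scaling gate in Float8Linear.forward is no longer needed: the weight is handed to matmul_with_hp_or_float8_args as-is via self.weight.t(), and the function itself decides, per cast config, whether and how to cast it.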