Skip to content

Commit

Permalink
[cleanup][4/x] unify weight casting
Browse files Browse the repository at this point in the history
Summary:

Not ready for review yet, performance regression because tensorwise
abs+max and weight casting is happening twice between fwd and bwd.
Limitation of something in PT2 stack?

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 95d7c478dcff2d6b1203dae2855a2894d8b1e3d0
ghstack-comment-id: 2568319095
Pull Request resolved: #1481
  • Loading branch information
vkuzo committed Jan 8, 2025
1 parent 97f5131 commit 02f163b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 14 deletions.
1 change: 1 addition & 0 deletions benchmarks/float8/profile_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def main(
).requires_grad_()
else:
M, K, N = 4096, 4096, 4096
M, K, N = 2048, 4096, 8192
m_ref = torch.nn.Sequential(
torch.nn.Linear(K, N, bias=False),
)
Expand Down
53 changes: 39 additions & 14 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ def _get_weight_scale(
return tensor_to_scale(weight, config.cast_config_weight.target_dtype)


def _cast_weight_to_float8_t(
def _cast_weight_to_float8(
weight: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if tensor_already_casted_to_fp8(weight):
return weight.t()
return weight
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
config.cast_config_weight.target_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8.t()
return weight_fp8


@torch._dynamo.allow_in_graph
Expand Down Expand Up @@ -104,16 +104,41 @@ def forward(
elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
weight_maybe_fp8_t = weight_hp_t
else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
weight_hp_t,
c.cast_config_weight.target_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight.scaling_granularity,
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_weight.scaling_granularity
),
)

# non-axiswise
if config.cast_config_weight.scaling_granularity is ScalingGranularity.TENSORWISE:
# If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
# weight_scale should be saved.
weight_scale = _get_weight_scale(
weight_hp_t, config.cast_config_weight.scaling_type, config
)

if config.force_recompute_fp8_weight_in_bwd:
weight_maybe_fp8_t = checkpoint.checkpoint(
_cast_weight_to_float8,
weight_hp_t,
config,
linear_mm_config,
weight_scale,
)
else:
weight_maybe_fp8_t = _cast_weight_to_float8(
weight_hp_t,
config,
linear_mm_config,
weight_scale,
)
else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
weight_hp_t,
c.cast_config_weight.target_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight.scaling_granularity,
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_weight.scaling_granularity
),
)

# the reshapes are needed in order to make the shapes compatible with
# torch.mm
Expand Down Expand Up @@ -311,7 +336,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

# TODO(future PR): check for axiswise scaling for input, weight,
# grad_output separately instead of together
if not has_any_axiswise_scaling:
if not has_any_axiswise_scaling and False:
# If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
# weight_scale should be saved.
weight_scale = _get_weight_scale(
Expand Down

0 comments on commit 02f163b

Please sign in to comment.