diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py
index 38a8c5e87..e1acc8b0e 100644
--- a/benchmarks/float8/profile_linear_float8.py
+++ b/benchmarks/float8/profile_linear_float8.py
@@ -356,6 +356,7 @@ def main(
         ).requires_grad_()
     else:
         M, K, N = 4096, 4096, 4096
+        M, K, N = 2048, 4096, 8192
         m_ref = torch.nn.Sequential(
             torch.nn.Linear(K, N, bias=False),
         )
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index 18aebaead..db9f9f190 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -39,14 +39,14 @@ def _get_weight_scale(
     return tensor_to_scale(weight, config.cast_config_weight.target_dtype)
 
 
-def _cast_weight_to_float8_t(
+def _cast_weight_to_float8(
     weight: torch.Tensor,
     config: Float8LinearConfig,
     linear_mm_config: LinearMMConfig,
     weight_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if tensor_already_casted_to_fp8(weight):
-        return weight.t()
+        return weight
     weight_fp8 = hp_tensor_and_scale_to_float8(
         weight,
         weight_scale,
@@ -54,7 +54,7 @@ def _cast_weight_to_float8_t(
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
-    return weight_fp8.t()
+    return weight_fp8
 
 
 @torch._dynamo.allow_in_graph
@@ -103,16 +103,43 @@ def forward(
         elif c.cast_config_weight.scaling_type is ScalingType.DISABLED:
             weight_maybe_fp8_t = weight_hp_t
         else:
-            weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
-                weight_hp_t,
-                c.cast_config_weight.target_dtype,
-                linear_mm_config,
-                gemm_input_role=GemmInputRole.WEIGHT,
-                scaling_granularity=c.cast_config_weight.scaling_granularity,
-                axiswise_dim=get_maybe_axiswise_dim(
-                    0, c.cast_config_weight.scaling_granularity
-                ),
-            )
+            # non-axiswise
+            if (
+                config.cast_config_weight.scaling_granularity
+                is ScalingGranularity.TENSORWISE
+            ):
+                # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
+                # weight_scale should be saved.
+                weight_scale = _get_weight_scale(
+                    weight_hp_t, config.cast_config_weight.scaling_type, config
+                )
+
+                if config.force_recompute_fp8_weight_in_bwd:
+                    weight_maybe_fp8_t = checkpoint.checkpoint(
+                        _cast_weight_to_float8,
+                        weight_hp_t,
+                        config,
+                        linear_mm_config,
+                        weight_scale,
+                    )
+                else:
+                    weight_maybe_fp8_t = _cast_weight_to_float8(
+                        weight_hp_t,
+                        config,
+                        linear_mm_config,
+                        weight_scale,
+                    )
+            else:
+                weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
+                    weight_hp_t,
+                    c.cast_config_weight.target_dtype,
+                    linear_mm_config,
+                    gemm_input_role=GemmInputRole.WEIGHT,
+                    scaling_granularity=c.cast_config_weight.scaling_granularity,
+                    axiswise_dim=get_maybe_axiswise_dim(
+                        0, c.cast_config_weight.scaling_granularity
+                    ),
+                )
 
         # the reshapes are needed in order to make the shapes compatible with
         # torch.mm
@@ -310,7 +337,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
         # TODO(future PR): check for axiswise scaling for input, weight,
         # grad_output separately instead of together
-        if not has_any_axiswise_scaling:
+        if not has_any_axiswise_scaling and False:
             # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
             # weight_scale should be saved.
             weight_scale = _get_weight_scale(
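
The float8_linear.py hunk moves the tensorwise path into the autograd forward: the scale is computed once outside the checkpointed region, and only the high-precision-to-float8 cast is wrapped in checkpoint so the casted weight is recomputed in backward instead of being saved. Below is a minimal, self-contained sketch of that pattern, not torchao's implementation; cast_weight and maybe_checkpointed_cast are illustrative names, the real code returns a Float8Tensor via _cast_weight_to_float8 rather than dequantizing, and the float8_e4m3fn dtype assumes PyTorch >= 2.1.

import torch
import torch.utils.checkpoint as checkpoint


def cast_weight(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for _cast_weight_to_float8: quantize the
    # high-precision weight to float8_e4m3fn with a precomputed scale,
    # then dequantize so the toy matmul below can stay in float32.
    weight_fp8 = (weight * scale).to(torch.float8_e4m3fn)
    return weight_fp8.to(torch.float32) / scale


def maybe_checkpointed_cast(
    weight: torch.Tensor, scale: torch.Tensor, recompute_in_bwd: bool
) -> torch.Tensor:
    if recompute_in_bwd:
        # Mirrors force_recompute_fp8_weight_in_bwd: autograd saves only
        # `weight` and `scale`, and re-runs cast_weight during backward
        # instead of keeping the casted weight alive between fwd and bwd.
        return checkpoint.checkpoint(cast_weight, weight, scale, use_reentrant=False)
    return cast_weight(weight, scale)


if __name__ == "__main__":
    w = torch.randn(256, 512, requires_grad=True)
    x = torch.randn(32, 512)
    with torch.no_grad():
        # Tensorwise scale, computed once outside the checkpointed region,
        # as in the diff ("weight_scale should be saved").
        scale = torch.finfo(torch.float8_e4m3fn).max / w.abs().max()
    y = x @ maybe_checkpointed_cast(w, scale, recompute_in_bwd=True).t()
    y.sum().backward()  # cast_weight runs again here
    print(w.grad.shape)  # torch.Size([256, 512])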