diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 50ebccf60d..58a48709b3 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -530,6 +530,29 @@ def test_inference_mode(self): with torch.inference_mode(mode=True): y = m(x) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_to_empty_delayed_scaling_with_float8_all_gather(self): + with torch.device("meta"): + m_ref = nn.Sequential(nn.Linear(32, 32)) + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), + enable_fsdp_float8_all_gather=True, + ) + m_fp8 = convert_to_float8_training(m_ref, config=config) + + assert m_fp8[0].fp8_amax_weight is m_fp8[0].weight._amax_buffer + assert m_fp8[0].fp8_amax_history_weight is m_fp8[0].weight._amax_history_buffer + assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer + + m_fp8.to_empty(device="cuda") + m_fp8[0]._maybe_fixup_delayed_scaling_buffers() + + assert m_fp8[0].fp8_amax_weight is m_fp8[0].weight._amax_buffer + assert m_fp8[0].fp8_amax_history_weight is m_fp8[0].weight._amax_history_buffer + assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer + class TestScaledMM: @unittest.skipIf( diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 4b16b16ba6..a12e0275d8 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -644,6 +644,19 @@ def extra_repr(self): s = f'{super().extra_repr()}, cast_configs={cast_config_str}"' return s + def _maybe_fixup_delayed_scaling_buffers(self): + if ( + self.config.enable_fsdp_float8_all_gather + and self.config.cast_config_weight.scaling_type is ScalingType.DELAYED + ): + # in case the module weight-related buffers got overwritten by + # the user (such as when calling `model.to_empty`), we + # re-link the weight wrapper buffers to point to the correct + # location + self.weight._amax_buffer = self.fp8_amax_weight + self.weight._amax_history_buffer = self.fp8_amax_history_weight + self.weight._scale_buffer = self.fp8_scale_weight + @classmethod def from_float( cls,