From faa159360f334e40a78c6e8f05325f8184aedc02 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 14 Nov 2024 21:53:25 -0800
Subject: [PATCH] float8 delayed scaling: private API to fix user overriding buffers

Summary:

Context: https://github.com/pytorch/torchtitan/issues/654

If the user has delayed scaling and FSDP float8 all-gather on, there is a
subtle bug that can happen if the user calls `model.to_empty(device="cuda")`:
1. `to_empty` recreates the buffers for tracking weight amax and scale
2. (1) leaves the buffers pointed to by `Float8Linear.weight._amax_buffer`,
   etc., orphaned, because they don't participate in `to_empty`

I couldn't think of an easy and clean way to auto-fix this, since we can't
expect `torch.nn.Module` to know that our logic has multiple references to
the same buffer, so I am exposing a private API for now until we can think
of something better.

With the current fix, the user can then call
`_maybe_fixup_delayed_scaling_buffers` manually to relink the buffers to
the correct new versions.
Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 23 +++++++++++++++++++++++ torchao/float8/float8_linear.py | 13 +++++++++++++ 2 files changed, 36 insertions(+) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 50ebccf60d..58a48709b3 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -530,6 +530,29 @@ def test_inference_mode(self): with torch.inference_mode(mode=True): y = m(x) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_to_empty_delayed_scaling_with_float8_all_gather(self): + with torch.device("meta"): + m_ref = nn.Sequential(nn.Linear(32, 32)) + config = Float8LinearConfig( + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), + enable_fsdp_float8_all_gather=True, + ) + m_fp8 = convert_to_float8_training(m_ref, config=config) + + assert m_fp8[0].fp8_amax_weight is m_fp8[0].weight._amax_buffer + assert m_fp8[0].fp8_amax_history_weight is m_fp8[0].weight._amax_history_buffer + assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer + + m_fp8.to_empty(device="cuda") + m_fp8[0]._maybe_fixup_delayed_scaling_buffers() + + assert m_fp8[0].fp8_amax_weight is m_fp8[0].weight._amax_buffer + assert m_fp8[0].fp8_amax_history_weight is m_fp8[0].weight._amax_history_buffer + assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer + class TestScaledMM: @unittest.skipIf( diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 4b16b16ba6..a12e0275d8 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -644,6 +644,19 @@ def extra_repr(self): s = f'{super().extra_repr()}, cast_configs={cast_config_str}"' return s + def _maybe_fixup_delayed_scaling_buffers(self): + if ( + self.config.enable_fsdp_float8_all_gather + and 
self.config.cast_config_weight.scaling_type is ScalingType.DELAYED + ): + # in case the module weight-related buffers got overwritten by + # the user (such as when calling `model.to_empty`), we + # re-link the weight wrapper buffers to point to the correct + # location + self.weight._amax_buffer = self.fp8_amax_weight + self.weight._amax_history_buffer = self.fp8_amax_history_weight + self.weight._scale_buffer = self.fp8_scale_weight + @classmethod def from_float( cls,