diff --git a/ruff.toml b/ruff.toml
index 773497eb5c..40a0680ae4 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -13,4 +13,7 @@ include = [
     "test/dtypes/test_affine_quantized_float.py",
     "torchao/quantization/weight_tensor_linear_activation_quantization.py",
     "torchao/dtypes/**/*.py",
+    "torchao/prototype/low_bit_optim/**.py",
+    "test/prototype/low_bit_optim/**.py",
+
 ]
diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index 39f97896bf..f0b608b47d 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -19,7 +19,11 @@
     quantize_4bit_with_qmap,
     _fp32_to_bf16_sr,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_6,
+)
 
 try:
     import bitsandbytes as bnb
@@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile):
         x_rep = x.view(-1, 1).repeat(1, 100_000)
 
         if compile:
-            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
+            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
+                x_rep
+            )
         else:
             x_rep_bf16 = _fp32_to_bf16_sr(x_rep)
 
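
A note on what the hunk above asserts: stochastic rounding is unbiased, so averaging 100,000 stochastically rounded copies of one fp32 value recovers that value far more closely than a single deterministic bf16 cast would. A minimal standalone sketch of the same property, using the import path from this test file:

    import torch

    from torchao.prototype.low_bit_optim.quant_utils import _fp32_to_bf16_sr

    # many copies of one fp32 value; each copy rounds up or down independently
    x = torch.full((100_000,), 0.1)
    x_bf16 = _fp32_to_bf16_sr(x)

    # E[_fp32_to_bf16_sr(x)] == x, so the sample mean converges to 0.1,
    # while a plain .bfloat16() cast is off by a fixed rounding error
    print(x_bf16.float().mean())
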
@@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile):
 
 
 class TestOptim(TestCase):
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
-    @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
+    @parametrize(
+        "optim_name",
+        ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
+    )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
     def test_optim_smoke(self, optim_name, dtype, device):
@@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device):
             torch.testing.assert_close(p2, p1)
 
     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(),
+        reason="bitsandbytes 8-bit Adam only works for CUDA",
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)
 
         # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
 
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size
+        )
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name):
 
     # this will not run in CI because we can't install lpmm
     @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
     def test_optim_4bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)
 
         # lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        model1[0].requires_grad_(False)  # make sure it can work in the presence of non-trainable params
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        model1[0].requires_grad_(
+            False
+        )  # make sure it can work in the presence of non-trainable params
         model2 = copy.deepcopy(model1)
 
         optim1 = torch.optim.AdamW(model1.parameters())
         optim2 = low_bit_optim.CPUOffloadOptimizer(
-            model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
+            model2.parameters(),
+            torch.optim.AdamW,
+            offload_gradients=offload_grad,
         )
 
         for _ in range(2):
@@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     def test_optim_cpu_offload_save_load(self):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        optim1 = low_bit_optim.CPUOffloadOptimizer(
+            model1.parameters(), torch.optim.AdamW
+        )
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -253,7 +293,9 @@ def test_optim_cpu_offload_save_load(self):
 
         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW
+        )
         optim2.load_state_dict(state_dict)
 
         for _ in range(2):
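
For readers skimming the two offload tests around this point: CPUOffloadOptimizer wraps a regular optimizer class, keeps the optimizer states on CPU (and optionally the gradients, via offload_gradients=True), and its state dict saves and loads like any other optimizer's. A minimal sketch distilled from these tests (CUDA required, as the skip markers say):

    import torch
    from torch import nn

    from torchao.prototype import low_bit_optim

    model = nn.Linear(32, 128, device="cuda")
    optim = low_bit_optim.CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW)

    model(torch.randn(4, 32, device="cuda")).sum().backward()
    optim.step()
    optim.zero_grad()

    state_dict = optim.state_dict()    # optimizer states live on CPU
    optim.load_state_dict(state_dict)  # resume, as in test_optim_cpu_offload_save_load
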
@@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self):
     def test_optim_bf16_stochastic_round_correctness(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         torch.manual_seed(2024)
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1).bfloat16()
 
         # small LR so that weight update is small
         # when bf16_stochastic_round=False, the test will fail after 1 iteration
         optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
-        optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
+        optim2 = low_bit_optim._AdamW(
+            model2.parameters(), lr=1e-5, bf16_stochastic_round=True
+        )
 
         # overfit on this sample
         x = torch.randn(4, 32, device=device)
@@ -299,7 +345,9 @@ def test_optim_bf16_stochastic_round_correctness(self):
             optim2.step()
             optim2.zero_grad()
 
-            torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
+            torch.testing.assert_close(
+                loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}"
+            )
 
 
 class TestFSDP2(FSDPTest):
@@ -307,7 +355,9 @@ class TestFSDP2(FSDPTest):
     def world_size(self) -> int:
         return 2
 
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
+    )
     @skip_if_lt_x_gpu(2)
     def test_fsdp2(self):
         optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
@@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls):
             base_loss.backward()
             for param in base_model.parameters():
                 if param.grad is not None:
-                    torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
+                    torch.distributed.all_reduce(
+                        param.grad, op=torch.distributed.ReduceOp.AVG
+                    )
             base_optim.step()
             self.assertEqual(fsdp_loss, base_loss)
diff --git a/torchao/prototype/low_bit_optim/__init__.py b/torchao/prototype/low_bit_optim/__init__.py
index 4ad75d4abf..9a22507b4c 100644
--- a/torchao/prototype/low_bit_optim/__init__.py
+++ b/torchao/prototype/low_bit_optim/__init__.py
@@ -1,2 +1,13 @@
 from .adam import Adam4bit, Adam8bit, AdamFp8, AdamW4bit, AdamW8bit, AdamWFp8, _AdamW
 from .cpu_offload import CPUOffloadOptimizer
+
+__all__ = [
+    "Adam4bit",
+    "Adam8bit",
+    "AdamFp8",
+    "AdamW4bit",
+    "AdamW8bit",
+    "AdamWFp8",
+    "_AdamW",
+    "CPUOffloadOptimizer",
+]
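
With __all__ now spelling out the public surface, these optimizers read as drop-in replacements for their torch.optim counterparts; only the import changes. A small usage sketch (the hyperparameters are illustrative, not part of this patch):

    import torch
    from torch import nn

    from torchao.prototype.low_bit_optim import AdamW8bit

    model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).cuda()
    # same constructor surface as torch.optim.AdamW for the common arguments
    optim = AdamW8bit(model.parameters(), lr=3e-4, weight_decay=0.01)
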
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index 19e0640334..1c3718972b 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -2,18 +2,28 @@
 
 import torch
 from torch import Tensor
-from torch.optim import Optimizer
 from torch.distributed._tensor import DTensor
+from torch.optim import Optimizer
 
-from .subclass_8bit import OptimState8bit
+from .quant_utils import _fp32_to_bf16_sr
 from .subclass_4bit import OptimState4bit
+from .subclass_8bit import OptimState8bit
 from .subclass_fp8 import OptimStateFp8
-from .quant_utils import _fp32_to_bf16_sr
 
 
 class _AdamBase(Optimizer):
     def __init__(
-        self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, bf16_stochastic_round, is_adamw
+        self,
+        params,
+        lr,
+        betas,
+        eps,
+        weight_decay,
+        amsgrad,
+        *,
+        block_size,
+        bf16_stochastic_round,
+        is_adamw,
     ) -> None:
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -23,7 +33,13 @@ def __init__(
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-        defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+        defaults = dict(
+            lr=torch.tensor(lr),
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
+        )
         super().__init__(params, defaults)
         self.block_size = block_size
         self.bf16_stochastic_round = bf16_stochastic_round
@@ -45,7 +61,9 @@ def _new_buffer(self, p: Tensor, signed: bool):
         if p.numel() >= 4096 and p.numel() % self.block_size == 0:
             if isinstance(p, DTensor):
                 out = DTensor.from_local(
-                    local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
+                    local_tensor=self._subclass_zeros(
+                        p.to_local(), signed, self.block_size
+                    ),
                     device_mesh=p.device_mesh,
                     placements=p.placements,
                     run_check=False,
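
A reading note on _new_buffer above: a parameter only receives a quantized state buffer when p.numel() >= 4096 and its numel divides evenly into block_size blocks; smaller or misaligned parameters keep a plain fp32 buffer, and DTensor parameters get their local shard quantized and re-wrapped with the same mesh and placements. Roughly, assuming the 8-bit block size of 256 used in the bitsandbytes comparison test earlier:

    import torch

    p_small = torch.randn(1000)  # numel < 4096             -> plain fp32 state buffer
    p_odd = torch.randn(5000)    # 5000 % 256 != 0          -> plain fp32 state buffer
    p_big = torch.randn(4096)    # >= 4096, 4096 % 256 == 0 -> quantized OptimState8bit
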
diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py
index 6a8671082c..c69932aa4c 100644
--- a/torchao/prototype/low_bit_optim/cpu_offload.py
+++ b/torchao/prototype/low_bit_optim/cpu_offload.py
@@ -25,7 +25,11 @@ def __init__(
             kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
         """
         # default to fused CPU AdamW
-        if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and "fused" not in kwargs:
+        if (
+            optimizer_class is torch.optim.AdamW
+            and TORCH_VERSION_AT_LEAST_2_4
+            and "fused" not in kwargs
+        ):
             kwargs.update(fused=True)
 
         param_groups = list(params)
@@ -77,7 +81,9 @@ def backward_hook(p_cuda):
                     self.param_cuda2cpu_map[p_cuda] = p_cpu
 
                 p_cuda.register_post_accumulate_grad_hook(backward_hook)
-                self.optim_dict[p_cuda] = optimizer_class([{"params": p_cpu, **param_group}], **kwargs)
+                self.optim_dict[p_cuda] = optimizer_class(
+                    [{"params": p_cpu, **param_group}], **kwargs
+                )
 
     @torch.no_grad()
     def step(self, closure=None):
diff --git a/torchao/prototype/low_bit_optim/quant_utils.py b/torchao/prototype/low_bit_optim/quant_utils.py
index 556a2f290c..628c8a742e 100644
--- a/torchao/prototype/low_bit_optim/quant_utils.py
+++ b/torchao/prototype/low_bit_optim/quant_utils.py
@@ -122,14 +122,17 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
     #
     # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
-    rand_16bit = torch.randint(0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32)
+    rand_16bit = torch.randint(
+        0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
+    )
     x_f32_bits = x_f32.view(torch.int32)
-    x_fraction = x_f32_bits & 0xFFFF # lower 16 bits
-    x_bf16_towards_zero = x_f32_bits & 0xFFFF0000 # upper 16 bits
+    x_fraction = x_f32_bits & 0xFFFF  # lower 16 bits
+    x_bf16_towards_zero = x_f32_bits & 0xFFFF0000  # upper 16 bits
 
     x_f32_bits = torch.where(
-        rand_16bit < x_fraction, # this is True with the probability of p_fraction
-        x_bf16_towards_zero + 0x10000, # this might overflow, which will result in UB due to signed integer
+        rand_16bit < x_fraction,  # this is True with the probability of p_fraction
+        x_bf16_towards_zero
+        + 0x10000,  # this might overflow, which will result in UB due to signed integer
         x_bf16_towards_zero,
     )
     # alternative, slightly faster
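
To make the bit arithmetic above concrete: the low 16 bits of an fp32 pattern are exactly what truncation to bf16 discards, and comparing them against a uniform 16-bit draw rounds the value up with probability x_fraction / 2^16. A worked check for one value, mirroring the lines in this hunk:

    import torch

    x = torch.tensor([1.0 + 2**-8])    # fp32 bit pattern 0x3F808000
    bits = x.view(torch.int32)
    fraction = (bits & 0xFFFF).item()  # 0x8000 = 32768, the bits bf16 would drop

    # probability that _fp32_to_bf16_sr rounds this value up to 1 + 2**-7
    print(fraction / 2**16)  # 0.5: the value sits exactly halfway between bf16 neighbors
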
diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py
index 759d816a6e..e493b978fe 100644
--- a/torchao/prototype/low_bit_optim/subclass_4bit.py
+++ b/torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -3,10 +3,19 @@
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
-from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4
-from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+    TorchAOBaseTensor,
+)
+from .quant_utils import (
+    create_dynamic_map,
+    dequant_with_qmap,
+    quantize_4bit_with_qmap,
+    scale_tensor,
+)
 
 aten = torch.ops.aten
 c10d_functional = torch.ops.c10d_functional
@@ -55,8 +64,12 @@ def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed, self._shape]
 
     @classmethod
-    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
-        return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
+    ):
+        return cls(
+            *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
+        )
 
     def dequantize(self, output_dtype=None):
         codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1)  # unpack
@@ -85,6 +98,7 @@ def __repr__(self):
 # in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
 # dtype is the same but device is different. thus, we must override .to() method instead.
 if not TORCH_VERSION_AT_LEAST_2_4:
+
     def _to(self, *args, **kwargs):
         # ignore other args/kwargs
         device = kwargs.pop("device", None)
@@ -158,16 +172,20 @@ def _(func, types, args, kwargs):
     if len(shape) == 1 and shape[0] == -1:
         return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
 
-    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+    raise ValueError(
+        f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]"
+    )
 
 
 # this is needed for DTensor.full_tensor()
-@OptimState4bit.implements([
-    c10d_functional.all_gather_into_tensor.default,
-    _c10d_functional.all_gather_into_tensor.default,
-    c10d_functional.wait_tensor.default,
-    _c10d_functional.wait_tensor.default,
-])
+@OptimState4bit.implements(
+    [
+        c10d_functional.all_gather_into_tensor.default,
+        _c10d_functional.all_gather_into_tensor.default,
+        c10d_functional.wait_tensor.default,
+        _c10d_functional.wait_tensor.default,
+    ]
+)
 def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimState4bit):
@@ -181,3 +199,9 @@ def _(func, types, args, kwargs):
 
     # assume tensors from all ranks have the same signedness
     return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    from torch.serialization import add_safe_globals
+
+    add_safe_globals([OptimState4bit])
diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py
index f5374a3480..d23d159645 100644
--- a/torchao/prototype/low_bit_optim/subclass_8bit.py
+++ b/torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -1,10 +1,19 @@
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
-from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4
-from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+    TorchAOBaseTensor,
+)
+from .quant_utils import (
+    create_dynamic_map,
+    dequant_with_qmap,
+    quantize_8bit_with_qmap,
+    scale_tensor,
+)
 
 aten = torch.ops.aten
 c10d_functional = torch.ops.c10d_functional
@@ -46,8 +55,12 @@ def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed]
 
     @classmethod
-    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
-        return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
+    ):
+        return cls(
+            *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
+        )
 
     def dequantize(self, output_dtype=None):
         float_data = dequant_with_qmap(self.codes, self.qmap, self.scale)
@@ -72,6 +85,7 @@ def __repr__(self):
 # in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
 # dtype is the same but device is different. thus, we must override .to() method instead.
 if not TORCH_VERSION_AT_LEAST_2_4:
+
     def _to(self, *args, **kwargs):
         # ignore other args/kwargs
         device = kwargs.pop("device", None)
@@ -136,12 +150,14 @@ def _(func, types, args, kwargs):
 
 
 # this is needed for DTensor.full_tensor()
-@OptimState8bit.implements([
-    c10d_functional.all_gather_into_tensor.default,
-    _c10d_functional.all_gather_into_tensor.default,
-    c10d_functional.wait_tensor.default,
-    _c10d_functional.wait_tensor.default,
-])
+@OptimState8bit.implements(
+    [
+        c10d_functional.all_gather_into_tensor.default,
+        _c10d_functional.all_gather_into_tensor.default,
+        c10d_functional.wait_tensor.default,
+        _c10d_functional.wait_tensor.default,
+    ]
+)
 def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimState8bit):
@@ -154,3 +170,9 @@ def _(func, types, args, kwargs):
         x.qmap.clone(),
         x.signed,
     )
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    from torch.serialization import add_safe_globals
+
+    add_safe_globals([OptimState8bit])
diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py
index eabe8b5051..d95b0c2661 100644
--- a/torchao/prototype/low_bit_optim/subclass_fp8.py
+++ b/torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -1,8 +1,8 @@
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
-from torchao.utils import TorchAOBaseTensor
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
 
 aten = torch.ops.aten
 c10d_functional = torch.ops.c10d_functional
@@ -51,8 +51,12 @@ def __tensor_flatten__(self):
         return self.tensor_attrs, []
 
     @classmethod
-    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
-        return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
+    ):
+        return cls(
+            *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
+        )
 
     def dequantize(self, output_dtype=None):
         float_data = self.codes.float()
@@ -121,12 +125,14 @@ def _(func, types, args, kwargs):
 
 
 # this is needed for DTensor.full_tensor()
-@OptimStateFp8.implements([
-    c10d_functional.all_gather_into_tensor.default,
-    _c10d_functional.all_gather_into_tensor.default,
-    c10d_functional.wait_tensor.default,
-    _c10d_functional.wait_tensor.default,
-])
+@OptimStateFp8.implements(
+    [
+        c10d_functional.all_gather_into_tensor.default,
+        _c10d_functional.all_gather_into_tensor.default,
+        c10d_functional.wait_tensor.default,
+        _c10d_functional.wait_tensor.default,
+    ]
+)
 def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimStateFp8):
@@ -137,3 +143,9 @@ def _(func, types, args, kwargs):
         func(x.codes, *args[1:], **kwargs),
         func(x.scale, *args[1:], **kwargs),
     )
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    from torch.serialization import add_safe_globals
+
+    add_safe_globals([OptimStateFp8])
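
The three add_safe_globals registrations added at the bottom of these files share one purpose: on PyTorch builds that provide torch.serialization.add_safe_globals, a checkpoint whose optimizer state contains these tensor subclasses can be loaded under the restricted weights_only unpickler without the caller allowlisting anything themselves. A hedged sketch (the checkpoint filename is illustrative):

    import torch
    import torchao.prototype.low_bit_optim  # noqa: F401  # importing runs the registrations

    # a state dict saved from e.g. AdamW8bit may hold OptimState8bit tensors;
    # with the subclass registered as a safe global, this load can succeed
    state_dict = torch.load("optim_state.pt", weights_only=True)
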