diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index 4ba13db1fb..244b002c3e 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -152,6 +152,55 @@ def test_optim_smoke(self, optim_name, dtype, device):
 
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
+
+
+    @parametrize(
+        "optim_name",
+        ["AdamFp8", "AdamWFp8"],
+    )
+    @parametrize("dtype", [torch.float32, torch.bfloat16])
+    @parametrize("device", _DEVICES)
+    def test_optim_fp8_coat_smoke(self, optim_name, dtype, device):
+        if device == "cuda":
+            if not TORCH_VERSION_AT_LEAST_2_4:
+                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
+        model.to(device=device, dtype=dtype)
+
+        optim = getattr(low_bit_optim, optim_name)(model.parameters(), dynamic_range_expansion=True)
+
+        x = torch.randn(4, 32, device=device, dtype=dtype)
+        loss = model(x).sum()
+        loss.backward()
+        optim.step()
+        optim.zero_grad()
+
+        # test serialization. also test the case CUDA optim loads CPU state dict
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(optim.state_dict(), f.name)
+            state_dict = torch.load(f.name, map_location="cpu")
+
+        model2 = copy.deepcopy(model)
+        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), dynamic_range_expansion=True)
+        optim2.load_state_dict(state_dict)
+
+        for _ in range(2):
+            x = torch.randn(4, 32, device=device, dtype=dtype)
+
+            model(x).sum().backward()
+            optim.step()
+            optim.zero_grad()
+
+            model2(x).sum().backward()
+            optim2.step()
+            optim2.zero_grad()
+
+        for p1, p2 in zip(model.parameters(), model2.parameters()):
+            torch.testing.assert_close(p2, p1)
+
 
     # aten.slice is required for dcp.load() when world size changes i.e. 
re-sharding
     # however, it's cumbersome to test it directly, since we would need to run distributed
diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md
index 2e0f9725a4..011687f210 100644
--- a/torchao/prototype/README.md
+++ b/torchao/prototype/README.md
@@ -11,8 +11,10 @@
   - `galore/docs` - implementation notes and discussion of issues faced in kernel design.
 - [`quant_llm`](quant_llm) - FP16 x Floatx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
 - [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers).
+  * `dynamic_range_expansion` - optional heuristic for the float8 optimizers that expands the optimizer states before quantization and compresses them back after dequantization (_expand & compress_), following [COAT](https://arxiv.org/abs/2410.19313).
 - [`spinquant`](spinquant) - re-implementation of [SpinQuant](https://arxiv.org/abs/2405.16406)
 
+
 #### Roadmap
 
 - `hqq`, `awq`, `marlin`,`QuaRot`, and other well-researched methodologies for quantized fine-tuning and inference.
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index 9cad9777bf..210c1fca3c 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -51,8 +51,7 @@ def __setstate__(self, state):
             group.setdefault("amsgrad", False)
 
     # bring your own function to create zero-filled subclass
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
+    def _subclass_zeros(self, p: Tensor, signed: bool):
         raise NotImplementedError
 
     def _new_buffer(self, p: Tensor, signed: bool):
@@ -216,9 +215,8 @@ def __init__(
             is_adamw=False,
         )
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState8bit.zeros(p.shape, signed, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimState8bit.zeros(p.shape, signed, self.block_size, p.device)
 
 
 class Adam4bit(_AdamBase):
@@ -246,9 +244,8 @@ def __init__(
             is_adamw=False,
         )
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimState4bit.zeros(p.shape, signed, self.block_size, p.device)
 
 
 class AdamFp8(_AdamBase):
@@ -263,6 +260,7 @@ def __init__(
         *,
         block_size=256,
         bf16_stochastic_round=False,
+        dynamic_range_expansion=False,
     ) -> None:
         super().__init__(
             params,
@@ -275,10 +273,12 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        self.dynamic_range_expansion = dynamic_range_expansion
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimStateFp8.zeros(p.shape, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimStateFp8.zeros(
+            p.shape, self.block_size, p.device, self.dynamic_range_expansion
+        )
 
 
 class AdamW8bit(_AdamBase):
@@ -306,9 +306,8 @@ def __init__(
             is_adamw=True,
         )
 
-    @staticmethod
-    def 
_subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState8bit.zeros(p.shape, signed, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimState8bit.zeros(p.shape, signed, self.block_size, p.device)
 
 
 class AdamW4bit(_AdamBase):
@@ -336,9 +335,8 @@ def __init__(
             is_adamw=True,
         )
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimState4bit.zeros(p.shape, signed, self.block_size, p.device)
 
 
 class AdamWFp8(_AdamBase):
@@ -353,6 +351,7 @@ def __init__(
         *,
         block_size=256,
         bf16_stochastic_round=False,
+        dynamic_range_expansion=False,
     ) -> None:
         super().__init__(
             params,
@@ -365,10 +364,12 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        self.dynamic_range_expansion = dynamic_range_expansion
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimStateFp8.zeros(p.shape, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimStateFp8.zeros(
+            p.shape, self.block_size, p.device, self.dynamic_range_expansion
+        )
 
 
 class _AdamW(_AdamBase):
diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py
index b5c8af6c83..ef6a13d328 100644
--- a/torchao/prototype/low_bit_optim/subclass_fp8.py
+++ b/torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 import torch
 from torch import Tensor
@@ -13,13 +14,30 @@
 DTYPE = torch.float8_e4m3fn
 
 
-def quantize_fp8(input: Tensor, block_size: int):
+def quantize_fp8(input: Tensor, block_size: int, dynamic_range_expansion: bool = False):
     shape = input.shape
     input = input.view(-1, block_size)
+    k = None
+
+    if dynamic_range_expansion:
+        # NOTE: the calculation is from the paper https://arxiv.org/abs/2410.19313
+        # The idea is to align 
optimizer state distributions more closely
+        # with the FP8 representation range, reducing the quantization error.
+        finfo = torch.finfo(DTYPE)
+        # finfo.min is the most negative value (-finfo.max), so the dtype range is
+        # taken against the smallest positive normal value instead.
+        Rdtype = finfo.max / finfo.smallest_normal
+        abs_input = input.abs()
+        Rx = abs_input.amax(-1).clip(1e-12) / abs_input.amin(-1).clip(1e-12)
+        # Rx == 1 (constant-magnitude block) makes log(Rx) zero: use k = 1 there.
+        k = torch.where(Rx > 1, math.log(Rdtype) / torch.log(Rx), 1.0)
+        input = input.sign() * (abs_input ** k.view(-1, 1))
+
     scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max
     input = input / scale.view(-1, 1)
     codes = input.to(DTYPE).view(-1)
-    return codes.view(shape), scale
+
+    return codes.view(shape), scale, k
 
 
 # NOTE: FP8 sign bit is redundant for unsigned optim state.
@@ -29,10 +47,10 @@ class OptimStateFp8(TorchAOBaseTensor):
     tensor_attrs = ["codes", "scale"]
 
     @staticmethod
-    def __new__(cls, codes: Tensor, scale: Tensor):
+    def __new__(cls, codes: Tensor, scale: Tensor, k: Optional[Tensor] = None):
         return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device)
 
-    def __init__(self, codes: Tensor, scale: Tensor):
+    def __init__(self, codes: Tensor, scale: Tensor, k: Optional[Tensor] = None):
         """Create quantized FP8 optimizer state.
 
         Args
@@ -47,6 +65,7 @@ def __init__(self, codes: Tensor, scale: Tensor):
         assert scale.ndim == 1
         self.codes = codes
         self.scale = scale
+        self.k = k
         self.block_size = codes.numel() // scale.numel()
 
     def __tensor_flatten__(self):
@@ -64,15 +83,31 @@ def dequantize(self, output_dtype=None):
         float_data = self.codes.float()
         float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1)
 
+        if self.k is not None:
+            # invert sign(x) * |x| ** k; a negative base to a fractional power is NaN
+            float_data = float_data.sign() * float_data.abs() ** (1 / self.k.view(-1, 1))
+
         if output_dtype is not None:
             float_data = float_data.to(output_dtype)
+
         return float_data.view(self.codes.shape)
 
     @classmethod
-    def zeros(cls, shape, block_size: int = 256, device=None):
+    def zeros(
+        cls,
+        shape,
+        block_size: int = 256,
+        device=None,
+        dynamic_range_expansion: bool = False,
+    ):
         codes = torch.zeros(shape, dtype=DTYPE, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
-        return cls(codes, scale)
+        k = (
+            torch.ones(codes.numel() // block_size, device=device)
+            if dynamic_range_expansion
+            else None
+        )
+        return cls(codes, scale, k)
 
     def __repr__(self):
         return (
@@ -90,12 +125,19 @@ def _(func, types, args, kwargs):
         assert dst.block_size == src.block_size
         dst.codes.copy_(src.codes)
         dst.scale.copy_(src.scale)
+        if dst.k is not None:
+            dst.k.copy_(src.k)
 
     elif isinstance(dst, OptimStateFp8):
-        codes, scale = quantize_fp8(src, dst.block_size)
+        codes, scale, k = quantize_fp8(
+            src, dst.block_size, dst.k is not None
+        )
+
         dst.codes.copy_(codes)
         dst.scale.copy_(scale)
+        if dst.k is not None:
+            dst.k.copy_(k)
 
     else:
         dst.copy_(src.dequantize())
 
@@ -109,6 +151,7 @@ def _(func, types, args, kwargs):
     out = OptimStateFp8(
         args[0].codes.to(device=device),
         args[0].scale.to(device=device),
+        args[0].k.to(device=device) if args[0].k is not None else None,
     )
     return return_and_correct_aliasing(func, args, kwargs, out)
 
@@ -123,7 +166,7 @@ def _(func, types, args, kwargs):
 @OptimStateFp8.implements(aten.view.default)
 def _(func, types, args, kwargs):
     x, shape = args
-    return OptimStateFp8(x.codes.view(shape), x.scale)
+    return OptimStateFp8(x.codes.view(shape), x.scale, x.k)
 
 
 @OptimStateFp8.implements(
@@ -146,6 +189,7 @@ def _(func, types, args, kwargs):
     return OptimStateFp8(
         func(x.codes, *args[1:], **kwargs),
         func(x.scale, *args[1:], **kwargs),
+        func(x.k, *args[1:], **kwargs) if x.k is not None else None,
     )
 
 