From c62fcd30b2d9b4b33f4420ae4501678e3e007ad8 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 04:55:02 -0600 Subject: [PATCH 01/49] added dynamic range expansion --- torchao/prototype/low_bit_optim/subclass_fp8.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index d95b0c2661..e745250ff7 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -20,6 +20,21 @@ def quantize_fp8(input: Tensor, block_size: int): return codes.view(shape), scale +def dynamic_range_expansion(input: Tensor, block_size: int): + + shape = input.shape + input = input.view(-1, block_size) + Rdtype = torch.finfo(DTYPE).max / torch.finfo(DTYPE).min + + min_vals = input.abs().amax(-1).clip(1e-12) + max_vals = input.abs().amin(-1).clip(1e-12) + Rx = max_vals / min_vals + k = torch.log(Rdtype) / torch.log(Rx) + + expanded_input = input.sign() * (input.abs() ** k.view(-1, 1)) + return expanded_input.view(shape), k + + # NOTE: FP8 sign bit is redundant for unsigned optim state. # we may investigate how to use it to increase range/precision for unsigned optim state. # https://arxiv.org/abs/2409.12517 uses FP8 E5M2 for 2nd Adam buffer From b11b4f6ae57d30bf50ec78f5fc7e3957c0b745fb Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 04:56:57 -0600 Subject: [PATCH 02/49] created optimstate with DRE class --- torchao/prototype/low_bit_optim/subclass_fp8.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index e745250ff7..6cce7eb68b 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -160,6 +160,23 @@ def _(func, types, args, kwargs): ) +class OptimStateFp8WithDynamicRangeExpansion(OptimStateFp8): + def __init__(self, codes, scale): + super().__init__(codes, scale) + + def dequantize(self, output_dtype=None): + + codes = super().dequantize(output_dtype) + + float_data = codes.view(-1, self.block_size) + float_data = float_data ** (1 / self.k.view(-1, 1)) + + if output_dtype is not None: + float_data = float_data.to(output_dtype) + + return float_data.view(self.codes.shape) + + if TORCH_VERSION_AT_LEAST_2_5: from torch.serialization import add_safe_globals From b887e69fe008df00da1c0a36136e81e063244d7f Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 04:58:46 -0600 Subject: [PATCH 03/49] implement copy_.default for OptimStateFp8WithDynamicRangeExpansion class with DRE --- .../prototype/low_bit_optim/subclass_fp8.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 6cce7eb68b..00422719bd 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -177,6 +177,31 @@ def dequantize(self, output_dtype=None): return float_data.view(self.codes.shape) +@OptimStateFp8WithDynamicRangeExpansion.implements(aten.copy_.default) +def _(func, types, args, kwargs): + dst = args[0] + src = args[1] + + if isinstance(dst, OptimStateFp8WithDynamicRangeExpansion) and isinstance(src, OptimStateFp8WithDynamicRangeExpansion): + assert dst.block_size == src.block_size + dst.codes.copy_(src.codes) + dst.scale.copy_(src.scale) + dst.k.copy_(src.k) + + elif isinstance(dst, 
OptimStateFp8WithDynamicRangeExpansion): + codes, k = dynamic_range_expansion(src, dst.block_size) + + codes, scale = quantize_fp8(codes, dst.block_size) + dst.codes.copy_(codes) + dst.scale.copy_(scale) + dst.k = k + else: + dst.copy_(src.dequantize()) + + return dst + + + if TORCH_VERSION_AT_LEAST_2_5: from torch.serialization import add_safe_globals From ab026057c99568509b6e03f8840f024a5dd3488a Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 04:59:13 -0600 Subject: [PATCH 04/49] implements _to_copy --- torchao/prototype/low_bit_optim/subclass_fp8.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 00422719bd..3f452b27be 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -201,6 +201,16 @@ def _(func, types, args, kwargs): return dst +@OptimStateFp8WithDynamicRangeExpansion.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + # ignore dtype + device = kwargs.get("device", None) + out = OptimStateFp8WithDynamicRangeExpansion( + args[0].codes.to(device=device), + args[0].scale.to(device=device), + ) + return return_and_correct_aliasing(func, args, kwargs, out) + if TORCH_VERSION_AT_LEAST_2_5: from torch.serialization import add_safe_globals From 43f5c08a71d1f4b5ab8ce535c1a4688821f12c22 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:06:40 -0600 Subject: [PATCH 05/49] removed implemented classes --- .../prototype/low_bit_optim/subclass_fp8.py | 58 ------------------- 1 file changed, 58 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 3f452b27be..ca3073982f 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -158,61 +158,3 @@ def _(func, types, args, kwargs): func(x.codes, *args[1:], **kwargs), func(x.scale, *args[1:], **kwargs), ) - - -class OptimStateFp8WithDynamicRangeExpansion(OptimStateFp8): - def __init__(self, codes, scale): - super().__init__(codes, scale) - - def dequantize(self, output_dtype=None): - - codes = super().dequantize(output_dtype) - - float_data = codes.view(-1, self.block_size) - float_data = float_data ** (1 / self.k.view(-1, 1)) - - if output_dtype is not None: - float_data = float_data.to(output_dtype) - - return float_data.view(self.codes.shape) - - -@OptimStateFp8WithDynamicRangeExpansion.implements(aten.copy_.default) -def _(func, types, args, kwargs): - dst = args[0] - src = args[1] - - if isinstance(dst, OptimStateFp8WithDynamicRangeExpansion) and isinstance(src, OptimStateFp8WithDynamicRangeExpansion): - assert dst.block_size == src.block_size - dst.codes.copy_(src.codes) - dst.scale.copy_(src.scale) - dst.k.copy_(src.k) - - elif isinstance(dst, OptimStateFp8WithDynamicRangeExpansion): - codes, k = dynamic_range_expansion(src, dst.block_size) - - codes, scale = quantize_fp8(codes, dst.block_size) - dst.codes.copy_(codes) - dst.scale.copy_(scale) - dst.k = k - else: - dst.copy_(src.dequantize()) - - return dst - - -@OptimStateFp8WithDynamicRangeExpansion.implements(aten._to_copy.default) -def _(func, types, args, kwargs): - # ignore dtype - device = kwargs.get("device", None) - out = OptimStateFp8WithDynamicRangeExpansion( - args[0].codes.to(device=device), - args[0].scale.to(device=device), - ) - return return_and_correct_aliasing(func, args, kwargs, out) - - -if TORCH_VERSION_AT_LEAST_2_5: 
- from torch.serialization import add_safe_globals - - add_safe_globals([OptimStateFp8]) From 7d98b1592a2f99794f4166cc78ba87c96026965d Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:07:22 -0600 Subject: [PATCH 06/49] dynamic_range_expansion -> apply_dynamic_range_expansion --- torchao/prototype/low_bit_optim/subclass_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index ca3073982f..458593c8f5 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -20,7 +20,7 @@ def quantize_fp8(input: Tensor, block_size: int): return codes.view(shape), scale -def dynamic_range_expansion(input: Tensor, block_size: int): +def apply_dynamic_range_expansion(input: Tensor, block_size: int): shape = input.shape input = input.view(-1, block_size) From e03c79cd48e622c9ce436690c6ef6a735482c623 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:10:08 -0600 Subject: [PATCH 07/49] add DRE flags to class --- torchao/prototype/low_bit_optim/subclass_fp8.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 458593c8f5..56923d3a65 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -45,7 +45,7 @@ class OptimStateFp8(TorchAOBaseTensor): def __new__(cls, codes: Tensor, scale: Tensor): return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device) - def __init__(self, codes: Tensor, scale: Tensor): + def __init__(self, codes: Tensor, scale: Tensor, dynamic_range_expansion: bool =False): """Create quantized FP8 optimizer state. 
Args @@ -61,6 +61,8 @@ def __init__(self, codes: Tensor, scale: Tensor): self.codes = codes self.scale = scale self.block_size = codes.numel() // scale.numel() + self.dynamic_range_expansion = dynamic_range_expansion + self.k = None def __tensor_flatten__(self): return self.tensor_attrs, [] From 5faa7de514c3d8cef680480a395fd330acbeb797 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:11:02 -0600 Subject: [PATCH 08/49] implementing contraction for dequantize --- torchao/prototype/low_bit_optim/subclass_fp8.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 56923d3a65..6aed8b89a7 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -79,8 +79,13 @@ def dequantize(self, output_dtype=None): float_data = self.codes.float() float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1) + if self.dynamic_range_expansion: + float_data = float_data.view(-1, self.block_size) + float_data = float_data ** (1 / self.k.view(-1, 1)) + if output_dtype is not None: float_data = float_data.to(output_dtype) + return float_data.view(self.codes.shape) @classmethod From b43c88f1e354fca9e63a462b08558a6cc7add246 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:12:02 -0600 Subject: [PATCH 09/49] copy k values as well for copy method --- torchao/prototype/low_bit_optim/subclass_fp8.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 6aed8b89a7..98a78e6ac6 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -110,12 +110,17 @@ def _(func, types, args, kwargs): assert dst.block_size == src.block_size dst.codes.copy_(src.codes) dst.scale.copy_(src.scale) + dst.k.copy_(src.k) elif isinstance(dst, OptimStateFp8): - codes, scale = quantize_fp8(src, dst.block_size) + if src.dynamic_range_expansion: + codes, k = apply_dynamic_range_expansion(src, dst.block_size) + + codes, scale = quantize_fp8(codes, dst.block_size) + dst.codes.copy_(codes) dst.scale.copy_(scale) - + dst.k.copy_(k) else: dst.copy_(src.dequantize()) From ac4162716506dae6b8d0564fb0fe73648420423f Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 04:55:02 -0600 Subject: [PATCH 10/49] added dynamic range expansion --- torchao/prototype/low_bit_optim/subclass_fp8.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index d95b0c2661..e745250ff7 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -20,6 +20,21 @@ def quantize_fp8(input: Tensor, block_size: int): return codes.view(shape), scale +def dynamic_range_expansion(input: Tensor, block_size: int): + + shape = input.shape + input = input.view(-1, block_size) + Rdtype = torch.finfo(DTYPE).max / torch.finfo(DTYPE).min + + min_vals = input.abs().amax(-1).clip(1e-12) + max_vals = input.abs().amin(-1).clip(1e-12) + Rx = max_vals / min_vals + k = torch.log(Rdtype) / torch.log(Rx) + + expanded_input = input.sign() * (input.abs() ** k.view(-1, 1)) + return expanded_input.view(shape), k + + # NOTE: FP8 sign bit is redundant for unsigned optim state. 
# we may investigate how to use it to increase range/precision for unsigned optim state. # https://arxiv.org/abs/2409.12517 uses FP8 E5M2 for 2nd Adam buffer From 79c9461041b49eaebc3251bdf4bf3817d4f00701 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 04:56:57 -0600 Subject: [PATCH 11/49] created optimstate with DRE class --- torchao/prototype/low_bit_optim/subclass_fp8.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index e745250ff7..6cce7eb68b 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -160,6 +160,23 @@ def _(func, types, args, kwargs): ) +class OptimStateFp8WithDynamicRangeExpansion(OptimStateFp8): + def __init__(self, codes, scale): + super().__init__(codes, scale) + + def dequantize(self, output_dtype=None): + + codes = super().dequantize(output_dtype) + + float_data = codes.view(-1, self.block_size) + float_data = float_data ** (1 / self.k.view(-1, 1)) + + if output_dtype is not None: + float_data = float_data.to(output_dtype) + + return float_data.view(self.codes.shape) + + if TORCH_VERSION_AT_LEAST_2_5: from torch.serialization import add_safe_globals From 47a7bb03339b085387acd3b5fb6b7439e2b79d9e Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 04:58:46 -0600 Subject: [PATCH 12/49] implement copy_.default for OptimStateFp8WithDynamicRangeExpansion class with DRE --- .../prototype/low_bit_optim/subclass_fp8.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 6cce7eb68b..00422719bd 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -177,6 +177,31 @@ def dequantize(self, output_dtype=None): return float_data.view(self.codes.shape) +@OptimStateFp8WithDynamicRangeExpansion.implements(aten.copy_.default) +def _(func, types, args, kwargs): + dst = args[0] + src = args[1] + + if isinstance(dst, OptimStateFp8WithDynamicRangeExpansion) and isinstance(src, OptimStateFp8WithDynamicRangeExpansion): + assert dst.block_size == src.block_size + dst.codes.copy_(src.codes) + dst.scale.copy_(src.scale) + dst.k.copy_(src.k) + + elif isinstance(dst, OptimStateFp8WithDynamicRangeExpansion): + codes, k = dynamic_range_expansion(src, dst.block_size) + + codes, scale = quantize_fp8(codes, dst.block_size) + dst.codes.copy_(codes) + dst.scale.copy_(scale) + dst.k = k + else: + dst.copy_(src.dequantize()) + + return dst + + + if TORCH_VERSION_AT_LEAST_2_5: from torch.serialization import add_safe_globals From a162f94699b335bcfae25380aede8b6fdee19454 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 04:59:13 -0600 Subject: [PATCH 13/49] implements _to_copy --- torchao/prototype/low_bit_optim/subclass_fp8.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 00422719bd..3f452b27be 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -201,6 +201,16 @@ def _(func, types, args, kwargs): return dst +@OptimStateFp8WithDynamicRangeExpansion.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + # ignore dtype + device = kwargs.get("device", None) + out = OptimStateFp8WithDynamicRangeExpansion( + 
args[0].codes.to(device=device), + args[0].scale.to(device=device), + ) + return return_and_correct_aliasing(func, args, kwargs, out) + if TORCH_VERSION_AT_LEAST_2_5: from torch.serialization import add_safe_globals From 5c1a3f495c6241750cf3cc7ff4c497fc814a4265 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:06:40 -0600 Subject: [PATCH 14/49] removed implemented classes --- .../prototype/low_bit_optim/subclass_fp8.py | 58 ------------------- 1 file changed, 58 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 3f452b27be..ca3073982f 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -158,61 +158,3 @@ def _(func, types, args, kwargs): func(x.codes, *args[1:], **kwargs), func(x.scale, *args[1:], **kwargs), ) - - -class OptimStateFp8WithDynamicRangeExpansion(OptimStateFp8): - def __init__(self, codes, scale): - super().__init__(codes, scale) - - def dequantize(self, output_dtype=None): - - codes = super().dequantize(output_dtype) - - float_data = codes.view(-1, self.block_size) - float_data = float_data ** (1 / self.k.view(-1, 1)) - - if output_dtype is not None: - float_data = float_data.to(output_dtype) - - return float_data.view(self.codes.shape) - - -@OptimStateFp8WithDynamicRangeExpansion.implements(aten.copy_.default) -def _(func, types, args, kwargs): - dst = args[0] - src = args[1] - - if isinstance(dst, OptimStateFp8WithDynamicRangeExpansion) and isinstance(src, OptimStateFp8WithDynamicRangeExpansion): - assert dst.block_size == src.block_size - dst.codes.copy_(src.codes) - dst.scale.copy_(src.scale) - dst.k.copy_(src.k) - - elif isinstance(dst, OptimStateFp8WithDynamicRangeExpansion): - codes, k = dynamic_range_expansion(src, dst.block_size) - - codes, scale = quantize_fp8(codes, dst.block_size) - dst.codes.copy_(codes) - dst.scale.copy_(scale) - dst.k = k - else: - dst.copy_(src.dequantize()) - - return dst - - -@OptimStateFp8WithDynamicRangeExpansion.implements(aten._to_copy.default) -def _(func, types, args, kwargs): - # ignore dtype - device = kwargs.get("device", None) - out = OptimStateFp8WithDynamicRangeExpansion( - args[0].codes.to(device=device), - args[0].scale.to(device=device), - ) - return return_and_correct_aliasing(func, args, kwargs, out) - - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimStateFp8]) From 1458d658f752b7a2e3f33ef4916377fe805f1105 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:07:22 -0600 Subject: [PATCH 15/49] dynamic_range_expansion -> apply_dynamic_range_expansion --- torchao/prototype/low_bit_optim/subclass_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index ca3073982f..458593c8f5 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -20,7 +20,7 @@ def quantize_fp8(input: Tensor, block_size: int): return codes.view(shape), scale -def dynamic_range_expansion(input: Tensor, block_size: int): +def apply_dynamic_range_expansion(input: Tensor, block_size: int): shape = input.shape input = input.view(-1, block_size) From 42fbb09118d98289653da9425acebfa466be674d Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:10:08 -0600 Subject: [PATCH 16/49] add DRE flags to class --- torchao/prototype/low_bit_optim/subclass_fp8.py 
| 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 458593c8f5..56923d3a65 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -45,7 +45,7 @@ class OptimStateFp8(TorchAOBaseTensor): def __new__(cls, codes: Tensor, scale: Tensor): return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device) - def __init__(self, codes: Tensor, scale: Tensor): + def __init__(self, codes: Tensor, scale: Tensor, dynamic_range_expansion: bool =False): """Create quantized FP8 optimizer state. Args @@ -61,6 +61,8 @@ def __init__(self, codes: Tensor, scale: Tensor): self.codes = codes self.scale = scale self.block_size = codes.numel() // scale.numel() + self.dynamic_range_expansion = dynamic_range_expansion + self.k = None def __tensor_flatten__(self): return self.tensor_attrs, [] From 9d1c00c17bbbb2df140f7c6e0a2a67edf45a2a90 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:11:02 -0600 Subject: [PATCH 17/49] implementing contraction for dequantize --- torchao/prototype/low_bit_optim/subclass_fp8.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 56923d3a65..6aed8b89a7 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -79,8 +79,13 @@ def dequantize(self, output_dtype=None): float_data = self.codes.float() float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1) + if self.dynamic_range_expansion: + float_data = float_data.view(-1, self.block_size) + float_data = float_data ** (1 / self.k.view(-1, 1)) + if output_dtype is not None: float_data = float_data.to(output_dtype) + return float_data.view(self.codes.shape) @classmethod From 7bc6ea4e060478e0b7deac641a248f35370dabc1 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:12:02 -0600 Subject: [PATCH 18/49] copy k values as well for copy method --- torchao/prototype/low_bit_optim/subclass_fp8.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 6aed8b89a7..98a78e6ac6 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -110,12 +110,17 @@ def _(func, types, args, kwargs): assert dst.block_size == src.block_size dst.codes.copy_(src.codes) dst.scale.copy_(src.scale) + dst.k.copy_(src.k) elif isinstance(dst, OptimStateFp8): - codes, scale = quantize_fp8(src, dst.block_size) + if src.dynamic_range_expansion: + codes, k = apply_dynamic_range_expansion(src, dst.block_size) + + codes, scale = quantize_fp8(codes, dst.block_size) + dst.codes.copy_(codes) dst.scale.copy_(scale) - + dst.k.copy_(k) else: dst.copy_(src.dequantize()) From 70937c87b49c647aa32f25372be08fdcf328729f Mon Sep 17 00:00:00 2001 From: Mustafa Date: Fri, 8 Nov 2024 13:22:47 -0600 Subject: [PATCH 19/49] combine range_expansion into quantize_fp8 function --- torchao/prototype/low_bit_optim/subclass_fp8.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 98a78e6ac6..b42756d0cd 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ 
b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -11,13 +11,23 @@ DTYPE = torch.float8_e4m3fn -def quantize_fp8(input: Tensor, block_size: int): +def quantize_fp8(input: Tensor, block_size: int, apply_range_expansion: bool): + shape = input.shape input = input.view(-1, block_size) + k = aten.ones(input.shape[0], dtype=DTYPE, device=input.device) + + if apply_range_expansion: + Rdtype = torch.finfo(DTYPE).max / torch.finfo(DTYPE).min # calculate the range of the dtype + Rx = input.abs().amax(-1).clip(1e-12) / input.abs().amin(-1).clip(1e-12) # range of input max and min + k = torch.log(Rdtype) / torch.log(Rx) # k calcuated + input = input.sign() * (input.abs() ** k.view(-1, 1)) + scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max input = input / scale.view(-1, 1) codes = input.to(DTYPE).view(-1) - return codes.view(shape), scale + + return codes.view(shape), scale, k def apply_dynamic_range_expansion(input: Tensor, block_size: int): From 3583de77628ed592082b20129e81902cbdcaa7fb Mon Sep 17 00:00:00 2001 From: Mustafa Date: Fri, 8 Nov 2024 13:27:08 -0600 Subject: [PATCH 20/49] passing apply_range_expansion to quantize_fp8 --- torchao/prototype/low_bit_optim/subclass_fp8.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index b42756d0cd..04fbf06aa1 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -123,10 +123,8 @@ def _(func, types, args, kwargs): dst.k.copy_(src.k) elif isinstance(dst, OptimStateFp8): - if src.dynamic_range_expansion: - codes, k = apply_dynamic_range_expansion(src, dst.block_size) - codes, scale = quantize_fp8(codes, dst.block_size) + codes, scale, k = quantize_fp8(codes, dst.block_size, False) dst.codes.copy_(codes) dst.scale.copy_(scale) From c75489334e1e14eff3f7224e2f280a868154ce3e Mon Sep 17 00:00:00 2001 From: Mustafa Date: Fri, 8 Nov 2024 13:34:22 -0600 Subject: [PATCH 21/49] remove apply_dynamic_range_expansion method --- torchao/prototype/low_bit_optim/subclass_fp8.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 04fbf06aa1..df5f32a9bb 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -30,21 +30,6 @@ def quantize_fp8(input: Tensor, block_size: int, apply_range_expansion: bool): return codes.view(shape), scale, k -def apply_dynamic_range_expansion(input: Tensor, block_size: int): - - shape = input.shape - input = input.view(-1, block_size) - Rdtype = torch.finfo(DTYPE).max / torch.finfo(DTYPE).min - - min_vals = input.abs().amax(-1).clip(1e-12) - max_vals = input.abs().amin(-1).clip(1e-12) - Rx = max_vals / min_vals - k = torch.log(Rdtype) / torch.log(Rx) - - expanded_input = input.sign() * (input.abs() ** k.view(-1, 1)) - return expanded_input.view(shape), k - - # NOTE: FP8 sign bit is redundant for unsigned optim state. # we may investigate how to use it to increase range/precision for unsigned optim state. 
# https://arxiv.org/abs/2409.12517 uses FP8 E5M2 for 2nd Adam buffer From c47b987f4daa8ef77f00c81394de51d0bd88c919 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Fri, 8 Nov 2024 13:36:26 -0600 Subject: [PATCH 22/49] pass destination's dynamic range expasnsion variable to quantize fp8 --- torchao/prototype/low_bit_optim/subclass_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index df5f32a9bb..195fe88726 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -109,7 +109,7 @@ def _(func, types, args, kwargs): elif isinstance(dst, OptimStateFp8): - codes, scale, k = quantize_fp8(codes, dst.block_size, False) + codes, scale, k = quantize_fp8(codes, dst.block_size, dst.dynamic_range_expansion) dst.codes.copy_(codes) dst.scale.copy_(scale) From 7a754ce890773a94e9baf8fd0aa6d589af317575 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Fri, 8 Nov 2024 15:39:11 -0600 Subject: [PATCH 23/49] change type annotation to optional --- torchao/prototype/low_bit_optim/subclass_fp8.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 195fe88726..29a48e1099 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +from typing import Optional from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor @@ -40,7 +41,7 @@ class OptimStateFp8(TorchAOBaseTensor): def __new__(cls, codes: Tensor, scale: Tensor): return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device) - def __init__(self, codes: Tensor, scale: Tensor, dynamic_range_expansion: bool =False): + def __init__(self, codes: Tensor, scale: Tensor, k: Optional[Tensor] =None): """Create quantized FP8 optimizer state. 
Args @@ -55,9 +56,8 @@ def __init__(self, codes: Tensor, scale: Tensor, dynamic_range_expansion: bool = assert scale.ndim == 1 self.codes = codes self.scale = scale + self.k = k self.block_size = codes.numel() // scale.numel() - self.dynamic_range_expansion = dynamic_range_expansion - self.k = None def __tensor_flatten__(self): return self.tensor_attrs, [] From 4d37d864865cad9b4d63297395e549d956c27d73 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Sat, 9 Nov 2024 01:52:27 -0600 Subject: [PATCH 24/49] k is none when dynamic range expansion is False --- torchao/prototype/low_bit_optim/subclass_fp8.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 29a48e1099..37b300215e 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -12,13 +12,13 @@ DTYPE = torch.float8_e4m3fn -def quantize_fp8(input: Tensor, block_size: int, apply_range_expansion: bool): +def quantize_fp8(input: Tensor, block_size: int, dynamic_range_expansion: bool): shape = input.shape input = input.view(-1, block_size) - k = aten.ones(input.shape[0], dtype=DTYPE, device=input.device) + k = aten.ones(input.shape[0], dtype=DTYPE, device=input.device) if dynamic_range_expansion else None - if apply_range_expansion: + if dynamic_range_expansion: Rdtype = torch.finfo(DTYPE).max / torch.finfo(DTYPE).min # calculate the range of the dtype Rx = input.abs().amax(-1).clip(1e-12) / input.abs().amin(-1).clip(1e-12) # range of input max and min k = torch.log(Rdtype) / torch.log(Rx) # k calcuated From 2d1834a06103411f923578af54639a3d15888bb2 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Sat, 9 Nov 2024 01:59:21 -0600 Subject: [PATCH 25/49] referencing paper for calculation of dynamic range expansion --- torchao/prototype/low_bit_optim/subclass_fp8.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 37b300215e..b798f76a13 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -19,9 +19,12 @@ def quantize_fp8(input: Tensor, block_size: int, dynamic_range_expansion: bool): k = aten.ones(input.shape[0], dtype=DTYPE, device=input.device) if dynamic_range_expansion else None if dynamic_range_expansion: + # NOTE: the calculation is from the paper https://arxiv.org/abs/2410.19313 + # The idea is to align optimizer state distributions more closely + # with the FP8 representation range, reducing the quantization error. 
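        # Worked example with hypothetical numbers: taking the dtype range as
        # finfo(DTYPE).max over the smallest positive normal (about 448 / 2**-6,
        # i.e. roughly 2.9e4 for float8_e4m3fn), a block whose magnitudes span
        # 1e-6..1e-2 has Rx = 1e4, so k = ln(2.9e4) / ln(1e4) is about 1.11, a mild
        # power-law expansion that dequantize() later undoes by raising to 1/k.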
Rdtype = torch.finfo(DTYPE).max / torch.finfo(DTYPE).min # calculate the range of the dtype Rx = input.abs().amax(-1).clip(1e-12) / input.abs().amin(-1).clip(1e-12) # range of input max and min - k = torch.log(Rdtype) / torch.log(Rx) # k calcuated + k = torch.log(Rdtype) / torch.log(Rx)# calculating optimal value k dynamically input = input.sign() * (input.abs() ** k.view(-1, 1)) scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max From 3d0d5d6fc73e2952b6c7f716e094e999a5bd7d65 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Sat, 9 Nov 2024 02:01:22 -0600 Subject: [PATCH 26/49] replaced condition check using variable k --- torchao/prototype/low_bit_optim/subclass_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index b798f76a13..657168f4ba 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -77,7 +77,7 @@ def dequantize(self, output_dtype=None): float_data = self.codes.float() float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1) - if self.dynamic_range_expansion: + if self.k: float_data = float_data.view(-1, self.block_size) float_data = float_data ** (1 / self.k.view(-1, 1)) From c413ac4955c931661466a4438740ad25a4d35a4d Mon Sep 17 00:00:00 2001 From: Mustafa Date: Sat, 9 Nov 2024 02:04:01 -0600 Subject: [PATCH 27/49] added parameter dynamic_range_expansion --- torchao/prototype/low_bit_optim/subclass_fp8.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 657168f4ba..206933227e 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -87,10 +87,12 @@ def dequantize(self, output_dtype=None): return float_data.view(self.codes.shape) @classmethod - def zeros(cls, shape, block_size: int = 256, device=None): + def zeros(cls, shape, block_size: int = 256, device=None, dynamic_range_expansion: bool = False): + codes = torch.zeros(shape, dtype=DTYPE, device=device) scale = torch.zeros(codes.numel() // block_size, device=device) - return cls(codes, scale) + k = torch.ones(codes.numel() // block_size, device=device) if dynamic_range_expansion else None + return cls(codes, scale, k) def __repr__(self): return ( From c3f5d2986c4d72df6adcbad3c991d5f2eac96265 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Sat, 9 Nov 2024 02:12:11 -0600 Subject: [PATCH 28/49] pass bool condition for quantizing src tensor --- torchao/prototype/low_bit_optim/subclass_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 206933227e..b3b112697b 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -114,7 +114,7 @@ def _(func, types, args, kwargs): elif isinstance(dst, OptimStateFp8): - codes, scale, k = quantize_fp8(codes, dst.block_size, dst.dynamic_range_expansion) + codes, scale, k = quantize_fp8(src, dst.block_size, True if dst.k is not None else False) dst.codes.copy_(codes) dst.scale.copy_(scale) From 1ec93358382c85d92dce3ab48dc438c77a38d334 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Sat, 9 Nov 2024 02:16:45 -0600 Subject: [PATCH 29/49] readded the torchversion safe_global exports --- torchao/prototype/low_bit_optim/subclass_fp8.py | 6 
++++++ 1 file changed, 6 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index b3b112697b..1935094522 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -168,3 +168,9 @@ def _(func, types, args, kwargs): func(x.codes, *args[1:], **kwargs), func(x.scale, *args[1:], **kwargs), ) + + +if TORCH_VERSION_AT_LEAST_2_5: + from torch.serialization import add_safe_globals + + add_safe_globals([OptimStateFp8]) \ No newline at end of file From 122530e6946a5bf16ffe94c0768476843a9c038a Mon Sep 17 00:00:00 2001 From: Mustafa Date: Sat, 9 Nov 2024 03:07:06 -0600 Subject: [PATCH 30/49] initialize k to none and later assign value if dynamic range expansion is true --- torchao/prototype/low_bit_optim/subclass_fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 1935094522..b4837fcf97 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -16,12 +16,13 @@ def quantize_fp8(input: Tensor, block_size: int, dynamic_range_expansion: bool): shape = input.shape input = input.view(-1, block_size) - k = aten.ones(input.shape[0], dtype=DTYPE, device=input.device) if dynamic_range_expansion else None + k = None if dynamic_range_expansion: # NOTE: the calculation is from the paper https://arxiv.org/abs/2410.19313 # The idea is to align optimizer state distributions more closely # with the FP8 representation range, reducing the quantization error. + k = aten.ones(input.shape[0], dtype=DTYPE, device=input.device) Rdtype = torch.finfo(DTYPE).max / torch.finfo(DTYPE).min # calculate the range of the dtype Rx = input.abs().amax(-1).clip(1e-12) / input.abs().amin(-1).clip(1e-12) # range of input max and min k = torch.log(Rdtype) / torch.log(Rx)# calculating optimal value k dynamically From 77e137145ca88a388de7be8638936bc5e7ef8926 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Sat, 9 Nov 2024 03:09:53 -0600 Subject: [PATCH 31/49] conditional statement by checking if k is None instead of directly applying condition on k --- torchao/prototype/low_bit_optim/subclass_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index b4837fcf97..08dd23b630 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -78,7 +78,7 @@ def dequantize(self, output_dtype=None): float_data = self.codes.float() float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1) - if self.k: + if self.k is not None: float_data = float_data.view(-1, self.block_size) float_data = float_data ** (1 / self.k.view(-1, 1)) From 366743cb12031a91a9d574fa418b6541b9b22b5b Mon Sep 17 00:00:00 2001 From: Mustafa Date: Sat, 9 Nov 2024 03:12:33 -0600 Subject: [PATCH 32/49] checking if k is available in dst to copy it --- torchao/prototype/low_bit_optim/subclass_fp8.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 08dd23b630..0e883e8e21 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -119,7 +119,9 @@ def _(func, types, args, kwargs): dst.codes.copy_(codes) 
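        # Note: quantize_fp8 returns k=None when dynamic range expansion is disabled,
        # so the per-block expansion exponent is only copied into destinations that
        # were allocated with a k buffer (see the guard below).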
dst.scale.copy_(scale) - dst.k.copy_(k) + + if dst.k is not None: + dst.k.copy_(k) else: dst.copy_(src.dequantize()) From 38951aed3b0175169710575df1d740d53a3055f2 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 00:36:53 -0600 Subject: [PATCH 33/49] matching parameters counts with constructor of optimStateFp8 --- torchao/prototype/low_bit_optim/subclass_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 0e883e8e21..affc618513 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -42,7 +42,7 @@ class OptimStateFp8(TorchAOBaseTensor): tensor_attrs = ["codes", "scale"] @staticmethod - def __new__(cls, codes: Tensor, scale: Tensor): + def __new__(cls, codes: Tensor, scale: Tensor, k: Optional[Tensor] =None): return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device) def __init__(self, codes: Tensor, scale: Tensor, k: Optional[Tensor] =None): From 4b3fb6bbc3b16a2a834dcd97e3879b92709c20ab Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 00:39:19 -0600 Subject: [PATCH 34/49] copy to k tensor only if k is not None --- torchao/prototype/low_bit_optim/subclass_fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index affc618513..9a734ad199 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -111,7 +111,8 @@ def _(func, types, args, kwargs): assert dst.block_size == src.block_size dst.codes.copy_(src.codes) dst.scale.copy_(src.scale) - dst.k.copy_(src.k) + if dst.k is not None: + dst.k.copy_(src.k) elif isinstance(dst, OptimStateFp8): From 7185b0084dbed19b8d646e0341691daafeadf8ab Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 00:53:00 -0600 Subject: [PATCH 35/49] passing k tensor if values are available --- torchao/prototype/low_bit_optim/subclass_fp8.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 9a734ad199..081e9e3c12 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -136,6 +136,7 @@ def _(func, types, args, kwargs): out = OptimStateFp8( args[0].codes.to(device=device), args[0].scale.to(device=device), + args[0].k.to(device=device) if args[0].k is not None else None ) return return_and_correct_aliasing(func, args, kwargs, out) @@ -150,7 +151,7 @@ def _(func, types, args, kwargs): @OptimStateFp8.implements(aten.view.default) def _(func, types, args, kwargs): x, shape = args - return OptimStateFp8(x.codes.view(shape), x.scale) + return OptimStateFp8(x.codes.view(shape), x.scale, x.k) # this is needed for DTensor.full_tensor() @@ -171,6 +172,7 @@ def _(func, types, args, kwargs): return OptimStateFp8( func(x.codes, *args[1:], **kwargs), func(x.scale, *args[1:], **kwargs), + func(x.k, *args[1:], **kwargs) if x.k else None ) From 0d7edae961eee436622822faf5744c1833b1c7d8 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 01:38:51 -0600 Subject: [PATCH 36/49] providing dynamic range expansion to the adamfloat8 class --- torchao/prototype/low_bit_optim/adam.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git 
a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 1c3718972b..fa2d36a06f 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -253,6 +253,7 @@ def __init__( *, block_size=256, bf16_stochastic_round=False, + dynamic_range_expansion=False, ) -> None: super().__init__( params, @@ -265,11 +266,28 @@ def __init__( bf16_stochastic_round=bf16_stochastic_round, is_adamw=False, ) + self.dynamic_range_expansion = dynamic_range_expansion @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimStateFp8.zeros(p.shape, block_size, p.device) + def _subclass_zeros(p: Tensor, signed: bool, block_size: int, dynamic_range_expansion: bool): + return OptimStateFp8.zeros(p.shape, block_size, p.device, dynamic_range_expansion) + def _new_buffer(self, p: Tensor, signed: bool): + if p.numel() >= 4096 and p.numel() % self.block_size == 0: + if isinstance(p, DTensor): + out = DTensor.from_local( + local_tensor=self._subclass_zeros( + p.to_local(), signed, self.block_size, self.dynamic_range_expansion + ), + device_mesh=p.device_mesh, + placements=p.placements, + run_check=False, + ) + else: + out = self._subclass_zeros(p, signed, self.block_size, self.dynamic_range_expansion) + else: + out = torch.zeros_like(p) + return out class AdamW8bit(_AdamBase): def __init__( From 58ff635b67b20e23f2f70ca91e7d6b14104b2d7b Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 09:52:23 -0600 Subject: [PATCH 37/49] change of _subclass_zeros from static method to normal class method --- torchao/prototype/low_bit_optim/adam.py | 54 +++++++------------------ 1 file changed, 15 insertions(+), 39 deletions(-) diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index fa2d36a06f..d10f15712b 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -51,8 +51,7 @@ def __setstate__(self, state): group.setdefault("amsgrad", False) # bring your own function to create zero-filled subclass - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + def _subclass_zeros(self, p: Tensor, signed: bool): raise NotImplementedError # follow bitsandbytes, only quantize tensors >= 4096 values @@ -62,14 +61,14 @@ def _new_buffer(self, p: Tensor, signed: bool): if isinstance(p, DTensor): out = DTensor.from_local( local_tensor=self._subclass_zeros( - p.to_local(), signed, self.block_size + p.to_local(), signed ), device_mesh=p.device_mesh, placements=p.placements, run_check=False, ) else: - out = self._subclass_zeros(p, signed, self.block_size) + out = self._subclass_zeros(p, signed) else: out = torch.zeros_like(p) return out @@ -206,9 +205,8 @@ def __init__( is_adamw=False, ) - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimState8bit.zeros(p.shape, signed, block_size, p.device) + def _subclass_zeros(self, p: Tensor, signed: bool): + return OptimState8bit.zeros(p.shape, signed, self.block_size, p.device) class Adam4bit(_AdamBase): @@ -236,9 +234,8 @@ def __init__( is_adamw=False, ) - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimState4bit.zeros(p.shape, signed, block_size, p.device) + def _subclass_zeros(self, p: Tensor, signed: bool): + return OptimState8bit.zeros(p.shape, signed, self.block_size, p.device) class AdamFp8(_AdamBase): @@ -268,26 +265,8 @@ def __init__( ) self.dynamic_range_expansion = dynamic_range_expansion - 
@staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int, dynamic_range_expansion: bool): - return OptimStateFp8.zeros(p.shape, block_size, p.device, dynamic_range_expansion) - - def _new_buffer(self, p: Tensor, signed: bool): - if p.numel() >= 4096 and p.numel() % self.block_size == 0: - if isinstance(p, DTensor): - out = DTensor.from_local( - local_tensor=self._subclass_zeros( - p.to_local(), signed, self.block_size, self.dynamic_range_expansion - ), - device_mesh=p.device_mesh, - placements=p.placements, - run_check=False, - ) - else: - out = self._subclass_zeros(p, signed, self.block_size, self.dynamic_range_expansion) - else: - out = torch.zeros_like(p) - return out + def _subclass_zeros(self, p: Tensor, signed: bool): + return OptimStateFp8.zeros(p.shape, self.block_size, p.device, self.dynamic_range_expansion) class AdamW8bit(_AdamBase): def __init__( @@ -314,9 +293,8 @@ def __init__( is_adamw=True, ) - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimState8bit.zeros(p.shape, signed, block_size, p.device) + def _subclass_zeros(self, p: Tensor, signed: bool): + return OptimState8bit.zeros(p.shape, signed, self.block_size, p.device) class AdamW4bit(_AdamBase): @@ -344,9 +322,8 @@ def __init__( is_adamw=True, ) - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimState4bit.zeros(p.shape, signed, block_size, p.device) + def _subclass_zeros(self, p: Tensor, signed: bool): + return OptimState4bit.zeros(p.shape, signed, self.block_size, p.device) class AdamWFp8(_AdamBase): @@ -374,9 +351,8 @@ def __init__( is_adamw=True, ) - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimStateFp8.zeros(p.shape, block_size, p.device) + def _subclass_zeros(self, p: Tensor, signed: bool): + return OptimStateFp8.zeros(p.shape, self.block_size, p.device) class _AdamW(_AdamBase): From 6c536a92721764332c164a74e96bd2a35188d531 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 09:57:26 -0600 Subject: [PATCH 38/49] added dynamic range expansion to adamwfp8 --- torchao/prototype/low_bit_optim/adam.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index d10f15712b..d03969511f 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -338,6 +338,7 @@ def __init__( *, block_size=256, bf16_stochastic_round=False, + dynamic_range_expansion=False, ) -> None: super().__init__( params, @@ -350,9 +351,10 @@ def __init__( bf16_stochastic_round=bf16_stochastic_round, is_adamw=True, ) - + self.dynamic_range_expansion = dynamic_range_expansion + def _subclass_zeros(self, p: Tensor, signed: bool): - return OptimStateFp8.zeros(p.shape, self.block_size, p.device) + return OptimStateFp8.zeros(p.shape, self.block_size, p.device, self.dynamic_range_expansion) class _AdamW(_AdamBase): From 767ccabac452c259c635579881caab0946f922bf Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 11:04:24 -0600 Subject: [PATCH 39/49] adding smoke test for additional parameters for float8 optimizers --- test/prototype/test_low_bit_optim.py | 51 ++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index f0b608b47d..e651df86a7 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -150,6 +150,57 @@ 
def test_optim_smoke(self, optim_name, dtype, device): for p1, p2 in zip(model.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1) + + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_4, reason="FP8 CUDA requires PyTorch >= 2.4" + ) + @pytest.mark.skipif( + torch.cuda.get_device_capability() < (8, 9), reason="FP8 CUDA requires compute capability >= 8.9" + ) + @parametrize( + "optim_name", + ["AdamFp8", "AdamWFp8"], + ) + @parametrize("dtype", [torch.float32, torch.bfloat16]) + @parametrize("device", _DEVICES) + @parametrize("optim_addon",["dynamic_range_expansion"]) + def test_optim_addons(self, optim_name, dtype, device, optim_addon): + + model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)) + model.to(device=device, dtype=dtype) + + optim_params = {optim_addon: True} + optim = getattr(low_bit_optim, optim_name)(model.parameters(), **optim_params) + + x = torch.randn(4, 32, device=device, dtype=dtype) + loss = model(x).sum() + loss.backward() + optim.step() + optim.zero_grad() + + # test serialization. also test the case CUDA optim loads CPU state dict + with tempfile.NamedTemporaryFile() as f: + torch.save(optim.state_dict(), f.name) + state_dict = torch.load(f.name, map_location="cpu") + + model2 = copy.deepcopy(model) + optim2 = getattr(low_bit_optim, optim_name)(model2.parameters()) + optim2.load_state_dict(state_dict) + + for _ in range(2): + x = torch.randn(4, 32, device=device, dtype=dtype) + + model(x).sum().backward() + optim.step() + optim.zero_grad() + + model2(x).sum().backward() + optim2.step() + optim2.zero_grad() + + for p1, p2 in zip(model.parameters(), model2.parameters()): + torch.testing.assert_close(p2, p1) + @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available") @pytest.mark.skipif( From 8fa5e3dec55218e828299bd34ecbed912eff6ae5 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 21:16:15 -0600 Subject: [PATCH 40/49] added new line --- torchao/prototype/low_bit_optim/subclass_fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 081e9e3c12..c55517726d 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -179,4 +179,5 @@ def _(func, types, args, kwargs): if TORCH_VERSION_AT_LEAST_2_5: from torch.serialization import add_safe_globals - add_safe_globals([OptimStateFp8]) \ No newline at end of file + add_safe_globals([OptimStateFp8]) + \ No newline at end of file From f34bfdd6a938698a2dc9699276722d304641ae1b Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 21:17:14 -0600 Subject: [PATCH 41/49] remove newline --- torchao/prototype/low_bit_optim/subclass_fp8.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index c55517726d..081e9e3c12 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -179,5 +179,4 @@ def _(func, types, args, kwargs): if TORCH_VERSION_AT_LEAST_2_5: from torch.serialization import add_safe_globals - add_safe_globals([OptimStateFp8]) - \ No newline at end of file + add_safe_globals([OptimStateFp8]) \ No newline at end of file From 41598a014ca76cc27e7ee7038ef5886b66138dd3 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 21:23:08 -0600 Subject: [PATCH 42/49] removed optim_addon parameter --- 
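A minimal usage sketch of the flag these test patches exercise (it mirrors the test body above; per the tests it requires PyTorch >= 2.4 and, on CUDA, compute capability >= 8.9; the import path and the exact call here are illustrative rather than a settled API):

    import torch
    import torch.nn as nn
    from torchao.prototype import low_bit_optim

    # The 32x256 and 256x32 weights have 8192 elements each, so they clear the
    # >= 4096-element threshold in _new_buffer and get blockwise OptimStateFp8
    # state, including a per-block k when dynamic_range_expansion=True.
    model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)).cuda()
    optim = low_bit_optim.AdamWFp8(model.parameters(), dynamic_range_expansion=True)

    x = torch.randn(4, 32, device="cuda")
    model(x).sum().backward()
    optim.step()       # per the patches above: stored state is dequantized (1/k
                       # contraction applied), updated, and re-quantized with a
                       # freshly computed per-block k
    optim.zero_grad()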
test/prototype/test_low_bit_optim.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index e651df86a7..cc1d0f6826 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -163,14 +163,12 @@ def test_optim_smoke(self, optim_name, dtype, device): ) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) - @parametrize("optim_addon",["dynamic_range_expansion"]) - def test_optim_addons(self, optim_name, dtype, device, optim_addon): + def test_optim_addons(self, optim_name, dtype, device): model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)) model.to(device=device, dtype=dtype) - optim_params = {optim_addon: True} - optim = getattr(low_bit_optim, optim_name)(model.parameters(), **optim_params) + optim = getattr(low_bit_optim, optim_name)(model.parameters(), dynamic_range_expansion=True) x = torch.randn(4, 32, device=device, dtype=dtype) loss = model(x).sum() From c189dc7e063fdd8ec0afaa4560e588282c1729a6 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 21:23:39 -0600 Subject: [PATCH 43/49] rename test_optim_addon to test_optim_fp8_coat_smoke --- test/prototype/test_low_bit_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index cc1d0f6826..14874f0da7 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -163,7 +163,7 @@ def test_optim_smoke(self, optim_name, dtype, device): ) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) - def test_optim_addons(self, optim_name, dtype, device): + def test_optim_fp8_coat_smoke(self, optim_name, dtype, device): model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)) model.to(device=device, dtype=dtype) From 6bb49ea3d55bb7f7152387e7d3091637e50491b5 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 21:56:32 -0600 Subject: [PATCH 44/49] code formatting --- torchao/prototype/low_bit_optim/adam.py | 15 +++--- .../prototype/low_bit_optim/subclass_fp8.py | 52 ++++++++++++------- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index d03969511f..a84e0f85d7 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -60,9 +60,7 @@ def _new_buffer(self, p: Tensor, signed: bool): if p.numel() >= 4096 and p.numel() % self.block_size == 0: if isinstance(p, DTensor): out = DTensor.from_local( - local_tensor=self._subclass_zeros( - p.to_local(), signed - ), + local_tensor=self._subclass_zeros(p.to_local(), signed), device_mesh=p.device_mesh, placements=p.placements, run_check=False, @@ -266,7 +264,10 @@ def __init__( self.dynamic_range_expansion = dynamic_range_expansion def _subclass_zeros(self, p: Tensor, signed: bool): - return OptimStateFp8.zeros(p.shape, self.block_size, p.device, self.dynamic_range_expansion) + return OptimStateFp8.zeros( + p.shape, self.block_size, p.device, self.dynamic_range_expansion + ) + class AdamW8bit(_AdamBase): def __init__( @@ -352,9 +353,11 @@ def __init__( is_adamw=True, ) self.dynamic_range_expansion = dynamic_range_expansion - + def _subclass_zeros(self, p: Tensor, signed: bool): - return OptimStateFp8.zeros(p.shape, self.block_size, p.device, self.dynamic_range_expansion) + return 
OptimStateFp8.zeros( + p.shape, self.block_size, p.device, self.dynamic_range_expansion + ) class _AdamW(_AdamBase): diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 081e9e3c12..88828162fe 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -1,6 +1,7 @@ +from typing import Optional + import torch from torch import Tensor -from typing import Optional from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor @@ -13,19 +14,22 @@ def quantize_fp8(input: Tensor, block_size: int, dynamic_range_expansion: bool): - shape = input.shape input = input.view(-1, block_size) k = None if dynamic_range_expansion: # NOTE: the calculation is from the paper https://arxiv.org/abs/2410.19313 - # The idea is to align optimizer state distributions more closely + # The idea is to align optimizer state distributions more closely # with the FP8 representation range, reducing the quantization error. k = aten.ones(input.shape[0], dtype=DTYPE, device=input.device) - Rdtype = torch.finfo(DTYPE).max / torch.finfo(DTYPE).min # calculate the range of the dtype - Rx = input.abs().amax(-1).clip(1e-12) / input.abs().amin(-1).clip(1e-12) # range of input max and min - k = torch.log(Rdtype) / torch.log(Rx)# calculating optimal value k dynamically + Rdtype = ( + torch.finfo(DTYPE).max / torch.finfo(DTYPE).min + ) # calculate the range of the dtype + Rx = input.abs().amax(-1).clip(1e-12) / input.abs().amin(-1).clip( + 1e-12 + ) # range of input max and min + k = torch.log(Rdtype) / torch.log(Rx) # calculating optimal value k dynamically input = input.sign() * (input.abs() ** k.view(-1, 1)) scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max @@ -42,10 +46,10 @@ class OptimStateFp8(TorchAOBaseTensor): tensor_attrs = ["codes", "scale"] @staticmethod - def __new__(cls, codes: Tensor, scale: Tensor, k: Optional[Tensor] =None): + def __new__(cls, codes: Tensor, scale: Tensor, k: Optional[Tensor] = None): return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device) - def __init__(self, codes: Tensor, scale: Tensor, k: Optional[Tensor] =None): + def __init__(self, codes: Tensor, scale: Tensor, k: Optional[Tensor] = None): """Create quantized FP8 optimizer state. 
Args @@ -80,19 +84,28 @@ def dequantize(self, output_dtype=None): if self.k is not None: float_data = float_data.view(-1, self.block_size) - float_data = float_data ** (1 / self.k.view(-1, 1)) - + float_data = float_data ** (1 / self.k.view(-1, 1)) + if output_dtype is not None: float_data = float_data.to(output_dtype) return float_data.view(self.codes.shape) @classmethod - def zeros(cls, shape, block_size: int = 256, device=None, dynamic_range_expansion: bool = False): - + def zeros( + cls, + shape, + block_size: int = 256, + device=None, + dynamic_range_expansion: bool = False, + ): codes = torch.zeros(shape, dtype=DTYPE, device=device) scale = torch.zeros(codes.numel() // block_size, device=device) - k = torch.ones(codes.numel() // block_size, device=device) if dynamic_range_expansion else None + k = ( + torch.ones(codes.numel() // block_size, device=device) + if dynamic_range_expansion + else None + ) return cls(codes, scale, k) def __repr__(self): @@ -115,12 +128,13 @@ def _(func, types, args, kwargs): dst.k.copy_(src.k) elif isinstance(dst, OptimStateFp8): + codes, scale, k = quantize_fp8( + src, dst.block_size, True if dst.k is not None else False + ) - codes, scale, k = quantize_fp8(src, dst.block_size, True if dst.k is not None else False) - dst.codes.copy_(codes) dst.scale.copy_(scale) - + if dst.k is not None: dst.k.copy_(k) else: @@ -136,7 +150,7 @@ def _(func, types, args, kwargs): out = OptimStateFp8( args[0].codes.to(device=device), args[0].scale.to(device=device), - args[0].k.to(device=device) if args[0].k is not None else None + args[0].k.to(device=device) if args[0].k is not None else None, ) return return_and_correct_aliasing(func, args, kwargs, out) @@ -172,11 +186,11 @@ def _(func, types, args, kwargs): return OptimStateFp8( func(x.codes, *args[1:], **kwargs), func(x.scale, *args[1:], **kwargs), - func(x.k, *args[1:], **kwargs) if x.k else None + func(x.k, *args[1:], **kwargs) if x.k else None, ) if TORCH_VERSION_AT_LEAST_2_5: from torch.serialization import add_safe_globals - add_safe_globals([OptimStateFp8]) \ No newline at end of file + add_safe_globals([OptimStateFp8]) From b1aea26551fa7ee4a05b1f6ed0f852765023105c Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 22:14:43 -0600 Subject: [PATCH 45/49] Moved device compatibility check for FP8 optimizer tests from pytest skip marker to within the function.
--- test/prototype/test_low_bit_optim.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index b462967281..900d8c8249 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -149,12 +149,7 @@ def test_optim_smoke(self, optim_name, dtype, device): for p1, p2 in zip(model.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1) - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_4, reason="FP8 CUDA requires PyTorch >= 2.4" - ) - @pytest.mark.skipif( - torch.cuda.get_device_capability() < (8, 9), reason="FP8 CUDA requires compute capability >= 8.9" - ) + @parametrize( "optim_name", ["AdamFp8", "AdamWFp8"], ) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) def test_optim_fp8_coat_smoke(self, optim_name, dtype, device): + if not TORCH_VERSION_AT_LEAST_2_4: + pytest.skip("FP8 CUDA requires PyTorch >= 2.4") + if torch.cuda.get_device_capability() < (8, 9): + pytest.skip("FP8 CUDA requires compute capability >= 8.9") + model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)) model.to(device=device, dtype=dtype) From 92ca7b274787e6356f4e4d8e822b40f6a8fa4ac8 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 22:40:02 -0600 Subject: [PATCH 46/49] formatting for `ruff check F,I` --- torchao/prototype/low_bit_optim/subclass_fp8.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 88c92535d1..4835aaaa90 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -1,4 +1,6 @@ import math +from typing import Optional + import torch from torch import Tensor from typing import Optional From 861423d9c8d7d8410a4f3c72b4aedc57f05d2b33 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 22:41:19 -0600 Subject: [PATCH 47/49] removing duplicate import --- torchao/prototype/low_bit_optim/subclass_fp8.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 4835aaaa90..ef6a13d328 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -3,7 +3,6 @@ import torch from torch import Tensor -from typing import Optional from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor From 7661b6171cbb5b2c19320962ab45c795f5be15cf Mon Sep 17 00:00:00 2001 From: Mustafa Date: Tue, 12 Nov 2024 23:53:10 -0600 Subject: [PATCH 48/49] checking if device is cuda before calling device capability --- test/prototype/test_low_bit_optim.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 900d8c8249..b392ca032e 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -157,11 +157,11 @@ def test_optim_smoke(self, optim_name, dtype, device): @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) def test_optim_fp8_coat_smoke(self, optim_name, dtype, device): - - if not TORCH_VERSION_AT_LEAST_2_4: - pytest.skip("FP8 CUDA requires PyTorch >= 2.4") - if torch.cuda.get_device_capability() < (8, 9): - pytest.skip("FP8 CUDA
requires compute capability >= 8.9") + if device == "cuda": + if not TORCH_VERSION_AT_LEAST_2_4: + pytest.skip("FP8 CUDA requires PyTorch >= 2.4") + if torch.cuda.get_device_capability() < (8, 9): + pytest.skip("FP8 CUDA requires compute capability >= 8.9") model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)) model.to(device=device, dtype=dtype) From e1fa683d4756023eec686a5f55283e459d5b88b7 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 13 Nov 2024 02:20:50 -0600 Subject: [PATCH 49/49] Updating Readme with dynamic range Expansion and Reference to Paper --- torchao/prototype/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md index 2e0f9725a4..011687f210 100644 --- a/torchao/prototype/README.md +++ b/torchao/prototype/README.md @@ -11,8 +11,10 @@ - `galore/docs` - implementation notes and discussion of issues faced in kernel design. - [`quant_llm`](quant_llm) - FP16 x Floatx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112) - [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers). + * `dynamic_range_expansion` - implementing additional heuristic _expand & compress_ method before quantizing and after dequantizing of Optimizer states for float8 optimizers. [COAT](https://arxiv.org/abs/2410.19313) - [`spinquant`](spinquant) - re-implementation of [SpinQuant](https://arxiv.org/abs/2405.16406) + #### Roadmap - `hqq`, `awq`, `marlin`,`QuaRot`, and other well-researched methodologies for quantized fine-tuning and inference.