From 6bf0f5c83e6c316e57671b4a395b01e268d8f0ea Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Wed, 18 Sep 2024 19:29:21 -0700 Subject: [PATCH 01/41] [float8] improve eager numerics for dynamic scales --- test/float8/test_base.py | 34 ++++++++++++++++++++++++++++++++- torchao/float8/float8_tensor.py | 5 ++++- torchao/float8/float8_utils.py | 5 ++++- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 2a875c44d6..ac82553b7a 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -15,6 +15,9 @@ import torch import torch.nn as nn +from torchao.float8.float8_scaling_utils import ( + hp_tensor_to_float8_dynamic, +) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -53,7 +56,7 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: - assert torch.all(a._data == b._data).item(), "scales are not identical" + assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" return True @@ -604,6 +607,35 @@ def test_small_amax_float16(self, float8_dtype): x = torch.tensor([target_amax], dtype=torch.float16, device="cuda") scale = tensor_to_scale(x, float8_dtype) assert not torch.any(torch.isinf(scale)) + + @pytest.mark.parametrize( + "dtype", + [ + torch.float32, + torch.bfloat16, + torch.float16, + ], + ) + def test_dynamic_scale_parity(self, dtype: torch.dtype): + scaling_type_weight = ScalingType.DYNAMIC + torch.manual_seed(42) + hp_tensor = torch.randn(768, 32, device="cuda", dtype=dtype) + float8_config = Float8LinearConfig( + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + ) + float8_eager = hp_tensor_to_float8_dynamic( + hp_tensor, + torch.float8_e4m3fn, + float8_config, + gemm_input_role=GemmInputRole.WEIGHT, + ) + float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( + hp_tensor, + torch.float8_e4m3fn, + float8_config, + gemm_input_role=GemmInputRole.WEIGHT, + ) + assert bitwise_identical(float8_eager, float8_compile) class TestFloat8LinearUtils(unittest.TestCase): diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 63110101a5..81a69a25e2 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -163,7 +163,10 @@ def forward( DTensor Invariant: DTensor must always be the outer most tensor subclass """ - tensor_scaled = tensor * scale + # Required by scaled_mm, scale is always float32. + # Cast tensor to float32 to improve numerics and + # get on-par with torch.compile. + tensor_scaled = tensor.to(torch.float32) * scale bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) if isinstance(bits_fp8, DTensor): diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 54613e5b05..cfa1fa43cc 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -42,6 +42,9 @@ def amax_to_scale( float8_dtype: The float8 dtype. orig_dtype: The original dtype of the tensor. 
""" + # Preserve precision in amax-to-scale conversion + # and ensure on-par numerics with torch.compile + amax = amax.to(torch.float64) if float8_dtype in FP8_TYPES: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) else: @@ -99,7 +102,7 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: - amax = torch.max(torch.abs(x)) + amax = torch.linalg.vector_norm(x, ord=float("inf")) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will From 553687f37a3c49bb9a60283e492bc4e32e8390f5 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Wed, 18 Sep 2024 19:43:22 -0700 Subject: [PATCH 02/41] leave torch.linalg.vector_norm for another PR Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index cfa1fa43cc..de2f0bc500 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -102,7 +102,7 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: - amax = torch.linalg.vector_norm(x, ord=float("inf")) + amax = torch.max(torch.abs(x)) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will From 19a592d493f7e91706ec061d2f3140701ceb44f0 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Wed, 18 Sep 2024 20:00:27 -0700 Subject: [PATCH 03/41] cuda Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index ac82553b7a..5607efbfe6 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -608,6 +608,10 @@ def test_small_amax_float16(self, float8_dtype): scale = tensor_to_scale(x, float8_dtype) assert not torch.any(torch.isinf(scale)) + @unittest.skipIf( + not is_cuda_8_9, + "CUDA not available", + ) @pytest.mark.parametrize( "dtype", [ From 218290eb00927649755b409bfee494e3131068d3 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Wed, 18 Sep 2024 20:25:48 -0700 Subject: [PATCH 04/41] remove _data and investigate Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 4 ++-- torchao/float8/float8_tensor.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 5607efbfe6..fc0a8bd12d 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -56,7 +56,7 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: - assert torch.all(a._scale == b._scale).item(), "scales are not identical" + assert torch.all(a._data == b._data).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" return True @@ -639,7 +639,7 @@ def test_dynamic_scale_parity(self, dtype: torch.dtype): float8_config, gemm_input_role=GemmInputRole.WEIGHT, ) - assert bitwise_identical(float8_eager, float8_compile) + assert torch.equal(float8_eager._scale, float8_compile._scale) class TestFloat8LinearUtils(unittest.TestCase): diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 81a69a25e2..8bb26f04ec 100644 --- a/torchao/float8/float8_tensor.py +++ 
b/torchao/float8/float8_tensor.py @@ -166,7 +166,7 @@ def forward( # Required by scaled_mm, scale is always float32. # Cast tensor to float32 to improve numerics and # get on-par with torch.compile. - tensor_scaled = tensor.to(torch.float32) * scale + tensor_scaled = tensor * scale bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) if isinstance(bits_fp8, DTensor): From 24ec9140249c169418783de3f7021acb8936a5d6 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Wed, 18 Sep 2024 20:26:41 -0700 Subject: [PATCH 05/41] remove _data comment Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_tensor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 8bb26f04ec..63110101a5 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -163,9 +163,6 @@ def forward( DTensor Invariant: DTensor must always be the outer most tensor subclass """ - # Required by scaled_mm, scale is always float32. - # Cast tensor to float32 to improve numerics and - # get on-par with torch.compile. tensor_scaled = tensor * scale bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) From c099486b0b1bbc427cc169af0a6e7b0a28629ea5 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Sat, 21 Sep 2024 13:02:03 -0700 Subject: [PATCH 06/41] upcast to float32 is enough Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 2 +- torchao/float8/float8_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index fc0a8bd12d..36271fdf8a 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -622,7 +622,7 @@ def test_small_amax_float16(self, float8_dtype): ) def test_dynamic_scale_parity(self, dtype: torch.dtype): scaling_type_weight = ScalingType.DYNAMIC - torch.manual_seed(42) + torch.manual_seed(0) hp_tensor = torch.randn(768, 32, device="cuda", dtype=dtype) float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index de2f0bc500..53e6bf59c2 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -44,7 +44,7 @@ def amax_to_scale( """ # Preserve precision in amax-to-scale conversion # and ensure on-par numerics with torch.compile - amax = amax.to(torch.float64) + amax = amax.to(torch.float32) if float8_dtype in FP8_TYPES: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) else: @@ -55,7 +55,7 @@ def amax_to_scale( # to care about this for float32/bfloat16. if orig_dtype is torch.float16: res = torch.clamp(res, max=torch.finfo(torch.float16).max) - return res.to(torch.float32) + return res @torch.no_grad() From b93ffc8f615ff59f036cc1add63d8a0a8ff6302a Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Sat, 21 Sep 2024 13:22:08 -0700 Subject: [PATCH 07/41] explain why float32 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 53e6bf59c2..1b8c216677 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -42,8 +42,7 @@ def amax_to_scale( float8_dtype: The float8 dtype. orig_dtype: The original dtype of the tensor. 
""" - # Preserve precision in amax-to-scale conversion - # and ensure on-par numerics with torch.compile + # _scaled_mm requires float32 scale amax = amax.to(torch.float32) if float8_dtype in FP8_TYPES: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) From ebff416e8199afc979c6d910ae4e752c68dbc5e6 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Sat, 21 Sep 2024 14:33:17 -0700 Subject: [PATCH 08/41] _data parity Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 1 + torchao/float8/float8_tensor.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 36271fdf8a..d7d39a960c 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -640,6 +640,7 @@ def test_dynamic_scale_parity(self, dtype: torch.dtype): gemm_input_role=GemmInputRole.WEIGHT, ) assert torch.equal(float8_eager._scale, float8_compile._scale) + assert torch.equal(float8_eager._data, float8_compile._data) class TestFloat8LinearUtils(unittest.TestCase): diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 63110101a5..a584166107 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -163,7 +163,8 @@ def forward( DTensor Invariant: DTensor must always be the outer most tensor subclass """ - tensor_scaled = tensor * scale + # scale is float32 thus upcasting tensor to match + tensor_scaled = tensor.to(torch.float32) * scale bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) if isinstance(bits_fp8, DTensor): From 8978ab2dc299a57f2f5e65f727899f0f86e1892f Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Sat, 21 Sep 2024 15:17:45 -0700 Subject: [PATCH 09/41] handle sm8.9 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index d7d39a960c..9586967cee 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -640,7 +640,7 @@ def test_dynamic_scale_parity(self, dtype: torch.dtype): gemm_input_role=GemmInputRole.WEIGHT, ) assert torch.equal(float8_eager._scale, float8_compile._scale) - assert torch.equal(float8_eager._data, float8_compile._data) + assert torch.testing.assert_close(float8_eager._data, float8_compile._data) class TestFloat8LinearUtils(unittest.TestCase): From f17dc121e275b9368178c2180b2b03d9c45e4222 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Sat, 21 Sep 2024 23:45:19 -0700 Subject: [PATCH 10/41] fix transformer unit test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_fsdp2/fsdp2_common.py | 5 +---- torchao/float8/float8_utils.py | 4 ++-- torchao/float8/fsdp_utils.py | 6 +++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/test/float8/test_fsdp2/fsdp2_common.py b/test/float8/test_fsdp2/fsdp2_common.py index 333206ba41..0a0a12eb46 100644 --- a/test/float8/test_fsdp2/fsdp2_common.py +++ b/test/float8/test_fsdp2/fsdp2_common.py @@ -48,10 +48,7 @@ def check_parity_no_mp( ): precompute_float8_dynamic_scale_for_fsdp(model) - if compile_transformer_block: - test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4) - else: - test_cls.assertEqual(losses[0], losses[1]) + test_cls.assertEqual(losses[0], losses[1]) def check_parity_bf16_mp( diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 1b8c216677..03636b7370 100644 --- a/torchao/float8/float8_utils.py +++ 
b/torchao/float8/float8_utils.py @@ -43,7 +43,7 @@ def amax_to_scale( orig_dtype: The original dtype of the tensor. """ # _scaled_mm requires float32 scale - amax = amax.to(torch.float32) + amax = amax.to(torch.float64) if float8_dtype in FP8_TYPES: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) else: @@ -54,7 +54,7 @@ def amax_to_scale( # to care about this for float32/bfloat16. if orig_dtype is torch.float16: res = torch.clamp(res, max=torch.finfo(torch.float16).max) - return res + return res.to(torch.float32) @torch.no_grad() diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 7ec60c795b..19386d932b 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -59,7 +59,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: return # inf-norm is equivalent to max(abs(w)) - max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial + max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float64) # Partial amax_tensor = torch.stack(max_weights) # Partial # clamp is dispatched through DTensor # it will issue a single all-reduce @@ -67,9 +67,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate if amax_tensor.dtype is torch.float16: scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) - local_scale_tensor = scale_tensor.to_local() + local_scale_tensor = scale_tensor.to_local().to(torch.float32) for i, float8_linear in enumerate(float8_linears): - float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32) + float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i] # FSDP pads its local tensor on dim-0. The subclass should be preserved such From 511c751ef2b53a01b8fabb91a97d1e5926c3919e Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Thu, 26 Sep 2024 10:57:23 -0700 Subject: [PATCH 11/41] print if error Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 9586967cee..1a81662424 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -623,7 +623,7 @@ def test_small_amax_float16(self, float8_dtype): def test_dynamic_scale_parity(self, dtype: torch.dtype): scaling_type_weight = ScalingType.DYNAMIC torch.manual_seed(0) - hp_tensor = torch.randn(768, 32, device="cuda", dtype=dtype) + hp_tensor = torch.randn(32, 32, device="cuda", dtype=dtype) float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) @@ -639,8 +639,9 @@ def test_dynamic_scale_parity(self, dtype: torch.dtype): float8_config, gemm_input_role=GemmInputRole.WEIGHT, ) + torch.set_printoptions(precision=10, threshold=2000) assert torch.equal(float8_eager._scale, float8_compile._scale) - assert torch.testing.assert_close(float8_eager._data, float8_compile._data) + assert torch.equal(float8_eager._data, float8_compile._data), f"{float8_eager._data=} vs {float8_compile._data=}" class TestFloat8LinearUtils(unittest.TestCase): From 9becda113166ff74a4a2b6ce354eb033fef368ec Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 20 Sep 2024 12:27:11 -0400 Subject: [PATCH 12/41] Add tutorial for trainable tensor subclass (#908) Summary: The new tutorial provides an example of how to implement a trainable tensor subclass that wraps quantized data. 
This extends the existing `MyDTypeTensor` with a few necessary steps to ensure proper gradient updates, namely: 1. Define a differentiable constructor 2. Define backward pass for ops of interest (e.g. torch.nn.functional.linear) 3. Handle special ops used by the optimizer (e.g. aten.add, aten.add_) Test Plan: python tutorials/developer_api_guide/my_trainable_tensor_subclass.py --- tutorials/developer_api_guide/__init__.py | 0 .../my_dtype_tensor_subclass.py | 80 +++---- .../my_trainable_tensor_subclass.py | 200 ++++++++++++++++++ 3 files changed, 244 insertions(+), 36 deletions(-) create mode 100644 tutorials/developer_api_guide/__init__.py create mode 100644 tutorials/developer_api_guide/my_trainable_tensor_subclass.py diff --git a/tutorials/developer_api_guide/__init__.py b/tutorials/developer_api_guide/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py index 5044393803..6baa6dfc64 100644 --- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py @@ -77,6 +77,7 @@ def __new__( layout_tensor: MyDTypeLayout, shape: torch.Size, dtype: Optional[torch.dtype] = None, + requires_grad: bool = False, ): kwargs = {} kwargs["device"] = layout_tensor.device @@ -86,7 +87,7 @@ def __new__( else layout_tensor.layout ) kwargs["dtype"] = dtype - kwargs["requires_grad"] = False + kwargs["requires_grad"] = requires_grad return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( @@ -94,6 +95,7 @@ def __init__( layout_tensor: MyDTypeLayout, shape: torch.Size, dtype: Optional[torch.dtype] = None, + requires_grad: bool = False, ): self.layout_tensor = layout_tensor @@ -108,7 +110,7 @@ def __tensor_flatten__(self): The first one contains any tensor fields such as int_data and scale as keys to a dictionary The second one contains all other non tensor type fields as values of a list """ - return ["layout_tensor"], [self.shape, self.dtype] + return ["layout_tensor"], [self.shape, self.dtype, self.requires_grad] @classmethod def __tensor_unflatten__( @@ -120,11 +122,12 @@ def __tensor_unflatten__( tensor_attributes contains all other non tensor type fields """ layout_tensor = tensor_data_dict["layout_tensor"] - shape, dtype = tensor_attributes + shape, dtype, requires_grad = tensor_attributes return cls( layout_tensor, shape if outer_size is None else outer_size, dtype=dtype, + requires_grad=requires_grad, ) """classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype @@ -330,37 +333,42 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ######## # Test # ######## -from torchao.utils import benchmark_model - -m = M() -example_inputs = (100 * torch.randn(1024, 1024),) -NUM_WARMUPS = 10 -NUM_RUNS = 100 - -for _ in range(NUM_WARMUPS): - m(*example_inputs) -print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs)) - -compiled = torch.compile(m, mode="max-autotune") -for _ in range(NUM_WARMUPS): - compiled(*example_inputs) -print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs)) - -# convert weights to quantized weights -m.linear.weight = torch.nn.Parameter( - to_my_dtype(m.linear.weight), requires_grad=False -) -for _ in range(NUM_WARMUPS): - m(*example_inputs) - -print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs)) - -m = torch.compile(m, mode="max-autotune") - -for 
_ in range(NUM_WARMUPS): - m(*example_inputs) - -# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op -# we plan to add custom op example in the future and that will help us to get speedup -print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs)) +def test(): + from torchao.utils import benchmark_model + + m = M() + example_inputs = (100 * torch.randn(1024, 1024),) + NUM_WARMUPS = 10 + NUM_RUNS = 100 + + for _ in range(NUM_WARMUPS): + m(*example_inputs) + print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs)) + + compiled = torch.compile(m, mode="max-autotune") + for _ in range(NUM_WARMUPS): + compiled(*example_inputs) + print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs)) + + # convert weights to quantized weights + m.linear.weight = torch.nn.Parameter( + to_my_dtype(m.linear.weight), requires_grad=False + ) + + for _ in range(NUM_WARMUPS): + m(*example_inputs) + + print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs)) + + m = torch.compile(m, mode="max-autotune") + + for _ in range(NUM_WARMUPS): + m(*example_inputs) + + # NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op + # we plan to add custom op example in the future and that will help us to get speedup + print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs)) + +if __name__ == "__main__": + test() diff --git a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py new file mode 100644 index 0000000000..a3fc0af8d5 --- /dev/null +++ b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py @@ -0,0 +1,200 @@ +""" +This is an example for a tensor subclass representing a simple dtype +that can be used in training. + +We extend our previous example of `MyDTypeTensor` with a few extra steps +needed to ensure proper gradient updates during training: + + 1. Define a differentiable constructor + 2. Define backward pass for ops of interest (e.g. torch.nn.functional.linear) + 3. Handle special ops used by the optimizer (e.g. aten.add, aten.add_) +""" + +import torch + +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType +from torchao.dtypes.utils import LayoutType, PlainLayoutType +from my_dtype_tensor_subclass import MyDTypeLayout, MyDTypeTensor + +aten = torch.ops.aten + + +############################## +# Tensor Subclass Definition # +############################## + +class MyTrainableDTypeTensor(MyDTypeTensor): + """ + Example tensor subclass that extends `MyDTypeTensor` to support training. + """ + + @classmethod + def _quantize( + cls, + input_float: torch.Tensor, + layout_type: LayoutType, + ) -> MyDTypeLayout: + """ + Convert from a floating point tensor (fp32/fp16/bf16) to the desired dtype. 
+ """ + mapping_type = MappingType.SYMMETRIC + block_size = input_float.shape + dtype = torch.int16 + scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) + int_data = (input_float / scale).to(torch.int8) + layout_tensor_ctr = cls.get_layout_tensor_constructor(type(layout_type)) + return layout_tensor_ctr(int_data, scale, layout_type) + + @classmethod + def from_float( + cls, + input_float: torch.Tensor, + layout_type: LayoutType = PlainLayoutType(), + ) -> "MyTrainableDTypeTensor": + """ + Main entry point for creating a `MyTrainableDTypeTensor`. + + This instantiates the tensor subclass in a differentiable constructor + to ensure gradients are passed to the tensor subclass properly during training. + """ + return _ToMyTrainableDTypeTensor.apply(input_float, layout_type) + +class _ToMyTrainableDTypeTensor(torch.autograd.Function): + """ + Differentiable constructor for `MyTrainableDTypeTensor`. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + input_float: torch.Tensor, + layout_type: LayoutType, + ) -> "MyTrainableDTypeTensor": + layout_tensor = MyTrainableDTypeTensor._quantize(input_float, layout_type) + return MyTrainableDTypeTensor( + layout_tensor, + input_float.shape, + requires_grad=True, + ) + + @staticmethod + def backward(ctx, gy): + return gy, None + +to_my_trainable_dtype = MyTrainableDTypeTensor.from_float + + +##################################################### +# torch functional and aten operator implementation # +##################################################### + +implements = MyTrainableDTypeTensor.implements + +class _QuantizedLinearOp(torch.autograd.Function): + """ + Forward and backward definition for linear with quantized weights. + Weights are dequantized during both the forward and the backward passes. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + input_tensor: torch.Tensor, + weight_tensor: torch.Tensor, + ) -> torch.Tensor: + assert isinstance(weight_tensor, MyTrainableDTypeTensor) + ctx.save_for_backward(input_tensor, weight_tensor) + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor) + + @staticmethod + def backward(ctx, grad_output): + input_tensor, weight_tensor = ctx.saved_tensors + grad_input = torch.matmul(grad_output, weight_tensor.dequantize()) + grad_weight = torch.matmul( + grad_output.view(-1, weight_tensor.shape[0]).T, + input_tensor.view(-1, weight_tensor.shape[1]), + ) + return grad_input, grad_weight + +@implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + """ + Handle the linear op with quantized weights. + For simplicity, we run both the forward and backward passes entirely in float. + """ + assert isinstance(args[1], MyTrainableDTypeTensor) + if len(args) > 2 and args[2] is not None: + raise NotImplementedError("linear bias not yet supported") + return _QuantizedLinearOp.apply(args[0], args[1]) + +@implements(aten.add_.Tensor) +def _(func, types, args, kwargs): + """ + Handle the in-place add op, called by the optimizer to update + the quantized weight during training. 
+ """ + assert len(args) == 2 + assert isinstance(args[0], MyTrainableDTypeTensor) + assert args[0].layout_tensor.int_data.dtype == torch.int8 + float0 = args[0].dequantize() + float1 = args[1].dequantize() if isinstance(args[1], MyTrainableDTypeTensor) else args[1] + new_value = torch.add(float0, float1, **kwargs) + new_layout_tensor = MyTrainableDTypeTensor._quantize( + new_value, + args[0].layout_tensor.get_layout_type(), + ) + args[0].layout_tensor = new_layout_tensor + return return_and_correct_aliasing(func, args, kwargs, args[0]) + +@implements(aten.add.Tensor) +def _(func, types, args, kwargs): + """Handle the add op, called by the optimizer during training.""" + assert len(args) == 2 + assert not isinstance(args[0], MyTrainableDTypeTensor) + assert isinstance(args[1], MyTrainableDTypeTensor) + out = torch.add(args[0], args[1].dequantize(), **kwargs) + return return_and_correct_aliasing(func, args, kwargs, out) + + +######## +# Test # +######## + +class M(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.linear = torch.nn.Linear(512, 1024, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + +def main(): + m = M().cuda() + NUM_TRAIN_STEPS = 10 + VERBOSE = True + + # Convert weights to quantized weights + m.linear.weight = torch.nn.Parameter( + to_my_trainable_dtype(m.linear.weight), requires_grad=True, + ) + + # Dummy training loop + optimizer = torch.optim.SGD(m.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5) + loss_fn = torch.nn.CrossEntropyLoss() + for i in range(NUM_TRAIN_STEPS): + example_inputs = (torch.randn(512).cuda(),) + target = torch.randn(1024).cuda() + output = m(*example_inputs) + loss = loss_fn(output, target) + loss.backward() + if VERBOSE: + weight = m.linear.weight.layout_tensor.int_data.flatten()[:3] + weight_grad = m.linear.weight.grad.flatten()[:3] + print(" * step %s: weight grad = %s, weight value = %s" % (i, weight_grad, weight)) + optimizer.step() + optimizer.zero_grad() + +if __name__ == "__main__": + main() From e4fdca992a72ec1657983defe85315d79cc7f088 Mon Sep 17 00:00:00 2001 From: Vaishnavi Gupta Date: Fri, 20 Sep 2024 11:32:05 -0700 Subject: [PATCH 13/41] Introducing 1-bit quantization for Llama in torchchat (#910) Differential Revision: D63052325 Pull Request resolved: https://github.com/pytorch/ao/pull/911 --- .../benchmarks/benchmark_bitpacking.cpp | 179 +++++++++++++++++- .../aarch64/benchmarks/benchmark_linear.cpp | 6 + .../kernels/cpu/aarch64/bitpacking/bitpack.h | 82 ++++++-- .../kernels/cpu/aarch64/bitpacking/uint1.h | 142 ++++++++++++++ .../kernels/cpu/aarch64/bitpacking/uint5.h | 2 +- .../cpu/aarch64/tests/test_bitpacking.cpp | 114 +++++++++++ .../kernels/cpu/aarch64/tests/test_utils.h | 2 +- 7 files changed, 505 insertions(+), 22 deletions(-) create mode 100644 torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp index 88bf8988a8..16096a6c4d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -16,6 +17,128 @@ namespace { +// Benchmark utility to compare variants of uint1 packing +void pack_uint1_values( + uint8_t* packed, + uint8_t* unpacked, + int packed_size, + int 
unpacked_size, + int variant) { + constexpr int nbit = 1; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(packed_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + switch (variant) { + case 8: + for (int i = 0; i < unpacked_size; i += 8) { + torchao::bitpacking::internal::pack_8_uint1_values( + packed + ((i * nbit) / bitsPerByte), unpacked + i); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); + torchao::bitpacking::internal::vec_pack_64_uint1_values( + packed + ((i * nbit) / bitsPerByte), + unpacked0, + unpacked1, + unpacked2, + unpacked3); + } + break; + case 128: + for (int i = 0; i < unpacked_size; i += 128) { + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64); + torchao::bitpacking::internal::vec_pack_128_uint1_values( + packed + ((i * nbit) / bitsPerByte), + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7); + } + break; + } +} + +// Benchmark utility to compare variants of uint1 packing +void unpack_uint1_values( + uint8_t* unpacked, + uint8_t* packed, + int unpacked_size, + int packed_size, + int variant) { + constexpr int nbit = 1; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(packed_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + switch (variant) { + case 8: + for (int i = 0; i < unpacked_size; i += 8) { + torchao::bitpacking::internal::unpack_8_uint1_values( + unpacked + i, packed + ((i * nbit) / bitsPerByte)); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + torchao::bitpacking::internal::vec_unpack_64_uint1_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + packed + ((i * nbit) / bitsPerByte)); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); + } + break; + case 128: + for (int i = 0; i < unpacked_size; i += 128) { + torchao::bitpacking::internal::vec_unpack_128_uint1_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7, + packed + ((i * nbit) / bitsPerByte)); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i + 64, unpacked4, unpacked5, unpacked6, unpacked7); + } + break; + } +} + // Benchmark utility to compare variants of uint2 packing void pack_uint2_values( uint8_t* packed, @@ -470,6 +593,44 @@ void unpack_uint5_values( } // namespace +static void benchmark_pack_uint1_values(benchmark::State& state) { + int unpacked_size = state.range(0); + int variant = state.range(1); + int nbit = 1; + + assert(unpacked_size % 8 == 0); + int packed_size = (unpacked_size / 8) * nbit; + + auto packed = std::vector(packed_size, 0); + auto unpacked = 
torchao::get_random_lowbit_vector(unpacked_size, nbit); + + for (auto _ : state) { + pack_uint1_values( + packed.data(), unpacked.data(), packed_size, unpacked_size, variant); + } +} + +static void benchmark_unpack_uint1_values(benchmark::State& state) { + int unpacked_size = state.range(0); + int variant = state.range(1); + int nbit = 1; + + assert(unpacked_size % 8 == 0); + int packed_size = (unpacked_size / 8) * nbit; + + auto packed = torchao::get_random_lowbit_vector(packed_size, 8); + auto unpacked = std::vector(unpacked_size, 0); + + for (auto _ : state) { + unpack_uint1_values( + unpacked.data(), + packed.data(), + unpacked.size(), + packed.size(), + variant); + } +} + static void benchmark_pack_uint2_values(benchmark::State& state) { int unpacked_size = state.range(0); int variant = state.range(1); @@ -478,8 +639,8 @@ static void benchmark_pack_uint2_values(benchmark::State& state) { assert(unpacked_size % 8 == 0); int packed_size = (unpacked_size / 8) * nbit; - auto packed = std::vector(unpacked_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + auto packed = std::vector(packed_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); for (auto _ : state) { pack_uint2_values( @@ -516,8 +677,8 @@ static void benchmark_pack_uint3_values(benchmark::State& state) { assert(unpacked_size % 8 == 0); int packed_size = (unpacked_size / 8) * nbit; - auto packed = std::vector(unpacked_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + auto packed = std::vector(packed_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); for (auto _ : state) { pack_uint3_values( @@ -554,8 +715,8 @@ static void benchmark_pack_uint4_values(benchmark::State& state) { assert(unpacked_size % 8 == 0); int packed_size = (unpacked_size / 8) * nbit; - auto packed = std::vector(unpacked_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + auto packed = std::vector(packed_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); for (auto _ : state) { pack_uint4_values( @@ -592,8 +753,8 @@ static void benchmark_pack_uint5_values(benchmark::State& state) { assert(unpacked_size % 8 == 0); int packed_size = (unpacked_size / 8) * nbit; - auto packed = std::vector(unpacked_size, 0); - auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + auto packed = std::vector(packed_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit); for (auto _ : state) { pack_uint5_values( @@ -622,6 +783,8 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) { } } +BENCHMARK(benchmark_pack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_unpack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}}); BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}}); BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}}); BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}}); diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp index 0d21bc5e5b..02a8d7ac98 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp @@ -228,6 +228,8 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( false>) \ ->ArgsProduct(BENCHMARK_PARAMS) 
+BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( + 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( 2); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( @@ -236,6 +238,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT 4); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( 5); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( + 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 2); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( @@ -244,6 +248,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT 4); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( 5); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( + 1); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( 2); BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h index 37db7926ac..ae5a716a54 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h @@ -7,6 +7,7 @@ #pragma once #include #include +#include #include #include #include @@ -72,10 +73,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( const int8x16_t& unpacked0, const int8x16_t& unpacked1) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); // Shift unpacked values to nonnegative range @@ -84,6 +85,19 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift)); switch (nbit) { + case 1: + uint8_t buffer1[32]; + vst1q_u8(buffer1, shifted0); + vst1q_u8(buffer1 + 16, shifted1); + + torchao::bitpacking::internal::pack_8_uint1_values(packed, buffer1); + torchao::bitpacking::internal::pack_8_uint1_values( + packed + 1, buffer1 + 8); + torchao::bitpacking::internal::pack_8_uint1_values( + packed + 2, buffer1 + 16); + torchao::bitpacking::internal::pack_8_uint1_values( + packed + 3, buffer1 + 24); + break; case 2: torchao::bitpacking::internal::vec_pack_32_uint2_values( packed, @@ -132,16 +146,28 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( int8x16_t& unpacked1, uint8_t* packed) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); uint8x16_t shifted0; uint8x16_t shifted1; switch (nbit) { + case 1: + uint8_t buffer1[32]; + torchao::bitpacking::internal::unpack_8_uint1_values(buffer1, packed); + torchao::bitpacking::internal::unpack_8_uint1_values( + buffer1 + 8, packed + 1); + torchao::bitpacking::internal::unpack_8_uint1_values( + buffer1 + 16, packed + 2); + torchao::bitpacking::internal::unpack_8_uint1_values( + buffer1 + 24, packed + 3); + shifted0 = vld1q_u8(buffer1); + shifted1 = vld1q_u8(buffer1 + 16); + break; case 2: uint8x8_t shifted0_low; uint8x8_t shifted0_high; @@ -197,10 +223,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( 
const int8x16_t& unpacked2, const int8x16_t& unpacked3) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); // Shift unpacked values to nonnegative range @@ -211,6 +237,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift)); switch (nbit) { + case 1: + torchao::bitpacking::internal::vec_pack_64_uint1_values( + packed, shifted0, shifted1, shifted2, shifted3); + break; case 2: torchao::bitpacking::internal::vec_pack_64_uint2_values( packed, shifted0, shifted1, shifted2, shifted3); @@ -242,10 +272,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( int8x16_t& unpacked3, uint8_t* packed) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); uint8x16_t shifted0; @@ -254,6 +284,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( uint8x16_t shifted3; switch (nbit) { + case 1: + torchao::bitpacking::internal::vec_unpack_64_uint1_values( + shifted0, shifted1, shifted2, shifted3, packed); + break; case 2: torchao::bitpacking::internal::vec_unpack_64_uint2_values( shifted0, shifted1, shifted2, shifted3, packed); @@ -296,10 +330,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( const int8x16_t& unpacked6, const int8x16_t& unpacked7) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); // Shift unpacked values to nonnegative range @@ -314,6 +348,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift)); switch (nbit) { + case 1: + torchao::bitpacking::internal::vec_pack_128_uint1_values( + packed, + shifted0, + shifted1, + shifted2, + shifted3, + shifted4, + shifted5, + shifted6, + shifted7); + break; case 2: torchao::bitpacking::internal::vec_pack_64_uint2_values( packed, shifted0, shifted1, shifted2, shifted3); @@ -371,10 +417,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( int8x16_t& unpacked7, uint8_t* packed) { static_assert(nbit < 8); - static_assert(nbit >= 2); + static_assert(nbit >= 1); // Currently supported values - static_assert(nbit >= 2); + static_assert(nbit >= 1); static_assert(nbit <= 5); uint8x16_t shifted0; @@ -387,6 +433,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( uint8x16_t shifted7; switch (nbit) { + case 1: + torchao::bitpacking::internal::vec_unpack_128_uint1_values( + shifted0, + shifted1, + shifted2, + shifted3, + shifted4, + shifted5, + shifted6, + shifted7, + packed); + break; case 2: torchao::bitpacking::internal::vec_unpack_64_uint2_values( shifted0, shifted1, shifted2, shifted3, packed); diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h new file mode 100644 index 0000000000..0a16c7398a --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h @@ -0,0 +1,142 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once +#include +#include + +// This file contains bitpacking and unpacking methods for uint1. +// These are not inteded to be used outside of bitpacking directory. +// See bitpack.h for the interface. + +namespace torchao { +namespace bitpacking { +namespace internal { + +TORCHAO_ALWAYS_INLINE inline void pack_8_uint1_values( + uint8_t* packed, + const uint8_t* unpacked) { + // Input is 8 bytes + // Output is 1 bytes + packed[0] = 0; + for (int i = 0; i < 8; i++) { + packed[0] |= (unpacked[i] << (7 - i)); + } +} + +TORCHAO_ALWAYS_INLINE inline void unpack_8_uint1_values( + uint8_t* unpacked, + const uint8_t* packed) { + // Unpacks data packed by pack_8_uint1_values + // + // Input is 8 bits = 1 byte + // Output is 8 bytes + for (int i = 0; i < 8; i++) { + unpacked[i] = (packed[0] >> (7 - i)) & 1; + } +} + +// This function is a vectorized version of pack_8_uint1_values +// To understand it, please see pack_8_uint1_values first. +// +// Input is 64 bytes +// Output is 64 bits = 8 bytes +TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint1_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3) { + uint8x16_t vec_packed; + uint8x8_t vec_packed_low; + uint8x8_t vec_packed_high; + vec_packed = vshlq_n_u8(unpacked0, 3); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked1, 2)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked2, 1)); + vec_packed = vorrq_u8(vec_packed, unpacked3); + + vec_packed_low = vget_low_u8(vec_packed); + vec_packed_high = vget_high_u8(vec_packed); + + vst1_u8(packed, vsli_n_u8(vec_packed_low, vec_packed_high, 4)); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint1_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + const uint8_t* packed) { + uint8x8_t vec_packed; + vec_packed = vld1_u8(packed); + + uint8x8_t vec_packed_low; + uint8x8_t vec_packed_high; + vec_packed_low = vand_u8(vec_packed, vdup_n_u8(0xF)); + vec_packed_high = vshr_n_u8(vec_packed, 4); + + uint8x16_t combined = vcombine_u8(vec_packed_low, vec_packed_high); + unpacked0 = vshrq_n_u8(vandq_u8(combined, vdupq_n_u8(8)), 3); + unpacked1 = vshrq_n_u8(vandq_u8(combined, vdupq_n_u8(4)), 2); + unpacked2 = vshrq_n_u8(vandq_u8(combined, vdupq_n_u8(2)), 1); + unpacked3 = vandq_u8(combined, vdupq_n_u8(1)); +} + +// This function is a vectorized version of pack_8_uint1_values +// To understand it, please see `pack_8_uint1_values` first. 
+// +// Input is 128 bytes +// Output is 128 bytes * 1 bit/8bits = 16 bytes +TORCHAO_ALWAYS_INLINE inline void vec_pack_128_uint1_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3, + const uint8x16_t& unpacked4, + const uint8x16_t& unpacked5, + const uint8x16_t& unpacked6, + const uint8x16_t& unpacked7) { + uint8x16_t vec_packed; + + vec_packed = vshlq_n_u8(unpacked0, 7); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked1, 6)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked2, 5)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked3, 4)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked4, 3)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked5, 2)); + vec_packed = vorrq_u8(vec_packed, vshlq_n_u8(unpacked6, 1)); + vec_packed = vorrq_u8(vec_packed, unpacked7); + + vst1q_u8(packed, vec_packed); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_uint1_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + uint8x16_t& unpacked4, + uint8x16_t& unpacked5, + uint8x16_t& unpacked6, + uint8x16_t& unpacked7, + const uint8_t* packed) { + uint8x16_t vec_packed; + vec_packed = vld1q_u8(packed); + + unpacked0 = vandq_u8(vshrq_n_u8(vec_packed, 7), vdupq_n_u8(1)); + unpacked1 = vandq_u8(vshrq_n_u8(vec_packed, 6), vdupq_n_u8(1)); + unpacked2 = vandq_u8(vshrq_n_u8(vec_packed, 5), vdupq_n_u8(1)); + unpacked3 = vandq_u8(vshrq_n_u8(vec_packed, 4), vdupq_n_u8(1)); + unpacked4 = vandq_u8(vshrq_n_u8(vec_packed, 3), vdupq_n_u8(1)); + unpacked5 = vandq_u8(vshrq_n_u8(vec_packed, 2), vdupq_n_u8(1)); + unpacked6 = vandq_u8(vshrq_n_u8(vec_packed, 1), vdupq_n_u8(1)); + unpacked7 = vandq_u8(vec_packed, vdupq_n_u8(1)); +} + +} // namespace internal +} // namespace bitpacking +} // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h index 0c6bd8f221..0e8e101ea6 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h @@ -19,7 +19,7 @@ namespace internal { TORCHAO_ALWAYS_INLINE inline void pack_8_uint5_values( uint8_t* packed, const uint8_t* unpacked) { - // Given 8 unpacked uint3 values: 0abcd, 1efgh, 2ijkl, 3mnop, 4qrst, 5uvwx, + // Given 8 unpacked uint5 values: 0abcd, 1efgh, 2ijkl, 3mnop, 4qrst, 5uvwx, // 6yzAB, 7CDEF, this function packs them as: // b4: 7|6|5|4|3|2|1|0 (upper bits for all values) // b3210_0: efgh|abcd (lower 4 bits for first 2 values) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp index 4c53ec28d6..581c3b3e37 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -14,6 +15,116 @@ #include #include +TEST(test_bitpacking_8_uint1_values, PackUnpackAreSame) { + int unpacked_bytes = 8; + int packed_bytes = 1; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 1); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::bitpacking::internal::pack_8_uint1_values( + packed.data(), input.data()); + torchao::bitpacking::internal::unpack_8_uint1_values( + unpacked.data(), packed.data()); + for (int i = 0; i < 
unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint1_values, PackUnpackAreSame) { + int unpacked_bytes = 64; + int packed_bytes = 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 1); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + + torchao::bitpacking::internal::vec_pack_64_uint1_values( + packed.data(), input0, input1, input2, input3); + + torchao::bitpacking::internal::vec_unpack_64_uint1_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} + +TEST(test_bitpacking_128_uint1_values, PackUnpackAreSame) { + int unpacked_bytes = 128; + int packed_bytes = 16; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 1); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + uint8x16_t input4; + uint8x16_t input5; + uint8x16_t input6; + uint8x16_t input7; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_load_64_uint8_values( + input4, input5, input6, input7, input.data() + 64); + torchao::bitpacking::internal::vec_pack_128_uint1_values( + packed.data(), + input0, + input1, + input2, + input3, + input4, + input5, + input6, + input7); + torchao::bitpacking::internal::vec_unpack_128_uint1_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7, + packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + EXPECT_EQ(input4[i], unpacked4[i]); + EXPECT_EQ(input5[i], unpacked5[i]); + EXPECT_EQ(input6[i], unpacked6[i]); + EXPECT_EQ(input7[i], unpacked7[i]); + } +} + TEST(test_bitpacking_4_uint2_values, PackUnpackAreSame) { int unpacked_bytes = 4; int packed_bytes = 1; @@ -534,16 +645,19 @@ void test_bitpacking_128_lowbit_values() { test_bitpacking_128_lowbit_values(); \ } +TEST_BITPACKING_32_LOWBIT_VALUES(1); TEST_BITPACKING_32_LOWBIT_VALUES(2); TEST_BITPACKING_32_LOWBIT_VALUES(3); TEST_BITPACKING_32_LOWBIT_VALUES(4); TEST_BITPACKING_32_LOWBIT_VALUES(5); +TEST_BITPACKING_64_LOWBIT_VALUES(1); TEST_BITPACKING_64_LOWBIT_VALUES(2); TEST_BITPACKING_64_LOWBIT_VALUES(3); TEST_BITPACKING_64_LOWBIT_VALUES(4); TEST_BITPACKING_64_LOWBIT_VALUES(5); +TEST_BITPACKING_128_LOWBIT_VALUES(1); TEST_BITPACKING_128_LOWBIT_VALUES(2); TEST_BITPACKING_128_LOWBIT_VALUES(3); TEST_BITPACKING_128_LOWBIT_VALUES(4); diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index 4e5083d9ef..b9b03c7771 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ 
-25,7 +25,7 @@ get_random_vector(int size, float min = -1.0, float max = 1.0) { } inline std::vector get_random_lowbit_vector(int size, int nbit) { - assert(nbit >= 2); + assert(nbit >= 1); assert(nbit <= 8); int min = 0; From 0cd4d37077e2dbe978efa61a0e747dbdda9ab2c7 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 20 Sep 2024 13:17:23 -0700 Subject: [PATCH 14/41] Rename Floating point to fp8 (#909) --- torchao/dtypes/affine_quantized_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 3c1c4b52f3..ecc8aa10d7 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1360,7 +1360,7 @@ def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): return out.view(*act.shape[:-1], out_dim).to(act.dtype) -def _linear_fp_act_fp8_weight_check( +def _linear_fp8_act_fp8_weight_check( input_tensor: Union[torch.Tensor, AffineQuantizedTensor], weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], bias: Optional[torch.Tensor], @@ -1384,7 +1384,7 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): return input_scale -def _linear_fp_act_fp8_weight_impl( +def _linear_fp8_act_fp8_weight_impl( input_tensor: AffineQuantizedTensor, weight_tensor: AffineQuantizedTensor, bias: Optional[torch.Tensor], @@ -1473,7 +1473,7 @@ def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), - (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), + (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl), From 014558d07082e30fb6713ed40ab6cc4319e004aa Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Mon, 23 Sep 2024 09:59:27 -0700 Subject: [PATCH 15/41] [float8] fix typo in bitwise_identical unit test (#918) Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 1a81662424..784eca593e 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -56,7 +56,7 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: - assert torch.all(a._data == b._data).item(), "scales are not identical" + assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" return True From 3267402c15e5a066f9961a57bbe91cfafec65d8a Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 23 Sep 2024 13:28:31 -0700 Subject: [PATCH 16/41] Adding example for quantized tensor + tensor parallelism (#785) * [WIP] Adding example for quantized tensor + tensor parallelism Summary: This PR adds an example of how quantized tensor subclass can work with DTensor: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md End goal is to rewrite 
https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py with normal llama2 implementation and show case with DTensor + AffineQuantizedTensor + torch.compile we can get on par performance with the custom tensor parallel implementation Test Plan: torchrun --standalone --nnodes=1 --nproc-per-node=4 tutorials/developer_api_guide/tensor_parallel.py Reviewers: Subscribers: Tasks: Tags: * tensor parallel file * Use DTensor.from instead of distribute_tensor * implementing aten.slice.Tensor (WIP) * working * some shape fix and use more quant primitive ops * Add rowwise test * make rowwise sharding work * compile still not working yet * fake tensor didn't pick up shape changes from transpose * backend='eager' * change transpose to non-inplace op * add error message * works now with torch nightly * remove print * ruff * Clean up * Fix device id --------- Co-authored-by: Ke Wen --- .../my_dtype_tensor_subclass.py | 132 +++++++++--- .../my_trainable_tensor_subclass.py | 6 +- .../developer_api_guide/tensor_parallel.py | 191 ++++++++++++++++++ 3 files changed, 294 insertions(+), 35 deletions(-) create mode 100644 tutorials/developer_api_guide/tensor_parallel.py diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py index 6baa6dfc64..03b0d31590 100644 --- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py @@ -15,7 +15,12 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType +from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + quantize_affine, + dequantize_affine, +) from torchao.dtypes.utils import ( LayoutType, PlainLayoutType, @@ -24,6 +29,32 @@ aten = torch.ops.aten +# TODO: move to torchao/utils.py +def fill_defaults(args, n, defaults_tail): + """ + __torch_dispatch__ doesn't guarantee the number of arguments you are + passed (e.g., defaulted arguments are not passed); but usually it is + convenient to pad out the arguments list with defaults. This function + helps you do that. 
+ Args: + args: the list of positional arguments passed to __torch_dispatch__ + n: the number of arguments you are expecting to get + defaults_tail: default values for the arguments, starting from the + end of the list + Example: + >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) + [1, 2, 3, 4, 5] + >>> fill_defaults([1, 2, 3], 5, [None, None, None]) + [1, 2, 3, None, None]] + """ + if n - len(defaults_tail) > len(args): + raise RuntimeError("not enough defaults to fill arguments") + r = list(args) + for i in range(len(args), n): + r.append(defaults_tail[i - n + len(defaults_tail)]) + return r + + ############################### # Base Layout Tensor Subclass # ############################### @@ -140,10 +171,10 @@ def from_float( layout_type: LayoutType = PlainLayoutType(), ): mapping_type = MappingType.SYMMETRIC - block_size = input_float.shape + block_size = (1, input_float.shape[-1]) dtype = torch.int16 - scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) - int_data = (input_float / scale).to(torch.int8) + scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype) + int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype) layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) layout_tensor = layout_tensor_ctr(int_data, scale, layout_type) return cls(layout_tensor, input_float.shape) @@ -160,7 +191,14 @@ def dequantize(self, output_dtype=None): if output_dtype is None: output_dtype = torch.get_default_dtype() int_data, scale = self.layout_tensor.get_plain() - return int_data.to(output_dtype) * scale + transposed = False + block_size = (1, int_data.shape[-1]) + if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed: + transposed = True + res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype) + if transposed: + res = res.t() + return res def __repr__(self): return ( @@ -203,6 +241,7 @@ def __new__( cls, int_data: torch.Tensor, scale: torch.Tensor, + transposed: bool, layout_type: LayoutType, ): kwargs = {} @@ -219,22 +258,24 @@ def __init__( self, int_data: torch.Tensor, scale: torch.Tensor, + transposed: bool, layout_type: LayoutType, ): self.int_data = int_data self.scale = scale + self.transposed = transposed self.layout_type = layout_type def __tensor_flatten__(self): - return ["int_data", "scale"], [self.layout_type] + return ["int_data", "scale"], [self.transposed, self.layout_type] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"] - layout_type, = tensor_attributes - return cls(int_data, scale, layout_type) + transposed, layout_type, = tensor_attributes + return cls(int_data, scale, transposed, layout_type) @classmethod def from_plain( @@ -247,12 +288,13 @@ def from_plain( extra metadata for packing etc. 
""" assert isinstance(layout_type, PlainLayoutType) - return cls(int_data, scale, layout_type) + return cls(int_data, scale, False, layout_type) def _apply_fn_to_data(self, fn): return self.__class__( fn(self.int_data), fn(self.scale), + self.transposed, self.layout_type, ) @@ -265,8 +307,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) + # Tensor parallel support START + elif func in [aten._to_copy.default, aten.clone.default]: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + elif func is aten.split.Tensor: + int_data_list = func(args[0].int_data, *args[1:], **kwargs) + scale_list = func(args[0].scale, *args[1:], **kwargs) + out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)] + return out + elif func is aten.empty_like.default: + int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs) + return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type) + elif func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + ) + elif dim == 1: + return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type) + else: + raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") + elif func is aten.t.default: + return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type)) + + # Tensor parallel support END + raise NotImplementedError( - f"MyDTypeLayout dispatch: attempting to run {func}, this is not supported" + f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported" ) ##################################################### @@ -315,15 +385,6 @@ def _(func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) - -class M(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.linear = torch.nn.Linear(1024, 1024) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) - ##################### # Factory functions # ##################### @@ -333,42 +394,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ######## # Test # ######## - -def test(): +def main(): from torchao.utils import benchmark_model - + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(1024, 128) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + m = M() - example_inputs = (100 * torch.randn(1024, 1024),) + example_inputs = (100 * torch.randn(512, 1024),) NUM_WARMUPS = 10 NUM_RUNS = 100 - + for _ in range(NUM_WARMUPS): m(*example_inputs) print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs)) - + compiled = torch.compile(m, mode="max-autotune") for _ in range(NUM_WARMUPS): compiled(*example_inputs) print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs)) - + # convert weights to quantized weights m.linear.weight = torch.nn.Parameter( to_my_dtype(m.linear.weight), requires_grad=False ) - + for 
_ in range(NUM_WARMUPS): m(*example_inputs) - + print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs)) - + m = torch.compile(m, mode="max-autotune") - + for _ in range(NUM_WARMUPS): m(*example_inputs) - + # NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op # we plan to add custom op example in the future and that will help us to get speedup print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs)) if __name__ == "__main__": - test() + main() diff --git a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py index a3fc0af8d5..b702ac4f91 100644 --- a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py @@ -61,7 +61,7 @@ def from_float( return _ToMyTrainableDTypeTensor.apply(input_float, layout_type) class _ToMyTrainableDTypeTensor(torch.autograd.Function): - """ + """ Differentiable constructor for `MyTrainableDTypeTensor`. """ @@ -163,8 +163,8 @@ def _(func, types, args, kwargs): ######## class M(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self) -> None: + super().__init__() self.linear = torch.nn.Linear(512, 1024, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py new file mode 100644 index 0000000000..a94d84fe05 --- /dev/null +++ b/tutorials/developer_api_guide/tensor_parallel.py @@ -0,0 +1,191 @@ +import os +import torch +import torch.distributed as dist +from torch.distributed import DeviceMesh +from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.utils._python_dispatch import return_and_correct_aliasing +from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults + +# a tensor subclass that supports tensor parallelism with DTensor +class MyDTypeTensorTP(MyDTypeTensor): + pass + +implements = MyDTypeTensorTP.implements + +aten = torch.ops.aten + +@implements([aten._to_copy.default, aten.clone.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + +@implements([aten.split.Tensor]) +def _(func, types, args, kwargs): + layout_tensor_list = func(args[0].layout_tensor, *args[1:], **kwargs) + out = [MyDTypeTensorTP(layout_tensor, layout_tensor.shape) for layout_tensor in layout_tensor_list] + return out + +@implements([aten.empty_like.default]) +def _(func, types, args, kwargs): + empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs) + return MyDTypeTensorTP(empty_like_layout_tensor, empty_like_layout_tensor.shape) + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + if end >= self.shape[dim]: + end = self.shape[dim] + shape = list(self.shape) + shape[dim] = end - start + return self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), shape, self.dtype) + +# this is needed for DTensor.from_local() and for flattening tensor +@implements(aten.view.default) +def _(func, types, args, kwargs): + x, shape = args + + if tuple(x.shape) == tuple(shape): + return x.__class__(x.layout_tensor, x.shape, x.dtype) + + if len(shape) == 1 and shape[0] == -1: + return x.__class__(x.layout_tensor, 
(x.numel(),), x.dtype) + + raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") + +@implements(aten.t.default) +def _(func, types, args, kwargs): + tensor = args[0] + shape = tensor.shape[::-1] + new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype) + return return_and_correct_aliasing(func, args, kwargs, new) + +@implements(aten.addmm.default) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[1], + args[2], + args[0], + ) + weight_tensor = weight_tensor.dequantize() + return aten.addmm(input_tensor, weight_tensor, bias) + +@implements(aten.mm.default) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + None + ) + weight_tensor = weight_tensor.dequantize() + return aten.mm(input_tensor, weight_tensor) + + +class M(torch.nn.Module): + def __init__(self, in_features, out_features, **kwargs) -> None: + super().__init__(**kwargs) + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + +to_my_dtype_tp = MyDTypeTensorTP.from_float + +def quantize(m: torch.nn.Module) -> torch.nn.Module: + """ + Quantize the model + """ + m.linear.weight = torch.nn.Parameter( + to_my_dtype_tp(m.linear.weight), requires_grad=False + ) + return m + +def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in column-wise fashion + """ + # Column-wise is wrt to A^T, so for A it is row-wise. + # Number of rows per rank + orig_weight = m.linear.weight + n_local_rows = orig_weight.size(0) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + +def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in row-wise fashion + """ + # Row-wise is wrt to A^T, so for A it is column-wise. 
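+    # Concretely: with W = mesh.size() and K = in_features, each rank keeps the
+    # slice orig_weight[:, rank*(K//W) : (rank+1)*(K//W)], and Shard(1) records
+    # that the full weight is the concatenation of these local shards along dim 1.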
+ # Number of rows per rank + orig_weight = m.linear.weight + n_local_cols = orig_weight.size(1) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)]) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + +######## +# Test # +######## +def main(): + # To make sure different ranks create the same module + torch.manual_seed(5) + + # Get rank and device + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + + # Original model + proj_up = M(1024, 2048).to(device) + proj_dn = M(2048, 1024).to(device) + example_input = 100 * torch.randn(128, 1024, device=device) + y = proj_dn(proj_up(example_input)) + + # Quantize the model + up_quant = quantize(proj_up) + dn_quant = quantize(proj_dn) + y_q = dn_quant(up_quant(example_input)) + print("Quantization works!") + + # Create a device mesh + dist.init_process_group(backend="nccl") + mesh = dist.init_device_mesh("cuda", (world_size,)) + + # Shard the models + up_dist = colwise_shard(up_quant, mesh) + dn_dist = rowwise_shard(dn_quant, mesh) + + # We need to turn inputs into DTensor form as well -- just a format change + input_dtensor = DTensor.from_local( + example_input, mesh, [Replicate()] + ) + + y_d = dn_dist(up_dist(input_dtensor)) + print("Distributed result:", y_d) + print("Distributed works!") + + up_compiled = torch.compile(up_dist) + y_up = up_compiled(input_dtensor) + dn_compiled = torch.compile(dn_dist) + y_dn = dn_compiled(y_up) + print("compiled result:", y_dn) + print("torch.compile works!") + + dist.destroy_process_group() + +if __name__ == "__main__": + main() From 1e07effbb8a6972d635255e63bae5d68ef12fddb Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 23 Sep 2024 18:04:25 -0700 Subject: [PATCH 17/41] rename cuda mode -> gpu mode (#925) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b28ff522ba..2388433e97 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # torchao: PyTorch Architecture Optimization -[![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/cudamode) +[![](https://dcbadge.vercel.app/api/server/gpumode?style=flat)](https://discord.gg/gpumode) [Introduction](#introduction) | [Inference](#inference) | [Training](#training) | [Composability](#composability) | [Custom Kernels](#custom-kernels) | [Alpha Features](#alpha-features) | [Installation](#installation) | [Integrations](#integrations) | [Videos](#videos) | [License](#license) From ebdeed04b30016a5e2f0cd53ac921346bd93dd8f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 24 Sep 2024 09:09:08 -0700 Subject: [PATCH 18/41] Add workaround to recover the perf for quantized vit in torch.compile (#926) Add temporary workaround to recover the perf for quantized vit under torch.compile Summary: Recently we found a perf drop in quantized vit due to https://github.com/pytorch/ao/issues/898#issuecomment-2364540055 This PR add a temp fix until we figure out the longer term fix. 
I think ideally we should figure out why the tensor subclass check failed in torch.compile (https://github.com/pytorch/pytorch/blob/e4d294221b140fdbb49a64f297bc60c9fcc2f80e/torch/nn/modules/activation.py#L1286) and fix that Test Plan: python tutorials/quantize_vit/run_vit_b_quant.py Reviewers: Subscribers: Tasks: Tags: --- tutorials/quantize_vit/run_vit_b_quant.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index 5c30762099..06113bcd68 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -36,6 +36,9 @@ if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) +# temporary workaround to recover the perf with quantized model under torch.compile +torch.backends.mha.set_fastpath_enabled(False) + model = torch.compile(model, mode='max-autotune') # Must run with no_grad when optimizing for inference From 09ffa227adc775a74b19029e9e9226d89b08c142 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 24 Sep 2024 10:12:07 -0700 Subject: [PATCH 19/41] clean up device checks in float8 unit test files (#923) Summary: While working on rowwise scaling I noticed that some of the CUDA device capability checks we had in the test files did not make sense, cleaning this up. Test Plan: tests pass on my H100 CI, it should skip less tests now since CI only has CUDA capability 8, 9 Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 23 ----------------------- test/float8/test_compile.py | 3 ++- 2 files changed, 2 insertions(+), 24 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 784eca593e..60d8ffa57d 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -234,15 +234,6 @@ def test_linear( linear_dtype: torch.dtype, linear_bias: bool, ): - if not emulate: - if not torch.cuda.is_available(): - warnings.warn("CUDA not available") - pytest.skip() - elif torch.cuda.get_device_capability() < (9, 0): - warnings.warn( - f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" - ) - pytest.skip() x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) @@ -290,16 +281,6 @@ def test_autocast_outputs( emulate: bool, linear_dtype: torch.dtype, ): - if not emulate: - if not torch.cuda.is_available(): - warnings.warn("CUDA not available") - pytest.skip() - elif torch.cuda.get_device_capability() < (9, 0): - warnings.warn( - f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" - ) - pytest.skip() - m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) config = Float8LinearConfig( cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), @@ -337,10 +318,6 @@ def test_autocast_outputs( @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): - emulate = ( - not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0) - ) - m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) config = Float8LinearConfig(emulate=emulate) m = Float8Linear.from_float(copy.deepcopy(m), config) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 8a0458bec3..bae62bf77d 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -224,7 +224,8 @@ def forward(self, x): return x_hp return x_fp8 - 
@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available") + # TODO(future): figure out why the test below fails on CUDA capability 8.9 + @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with capability 9.0 or greater not available") def test_float8_with_graph_break_in_the_middle(self): """Test that having Float8Tensor object at the boundary of a subgraph""" cnts = CompileCounterWithBackend("inductor") From 0b8dd85d6738bb9454b3e8493ac16a568aaa0d38 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 04:41:48 +0800 Subject: [PATCH 20/41] [low-bit optim] Change 8-bit and FP8 optim block size from 2048 to 256 to match new bnb v0.44 (#927) --- test/prototype/test_low_bit_optim.py | 6 +++++- torchao/prototype/low_bit_optim/README.md | 2 +- torchao/prototype/low_bit_optim/adam.py | 8 ++++---- torchao/prototype/low_bit_optim/subclass_8bit.py | 2 +- torchao/prototype/low_bit_optim/subclass_fp8.py | 2 +- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index ccf925a3fd..496fa3659f 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -3,6 +3,7 @@ import pytest import torch +from packaging.version import Version from torch import nn from torch.testing._internal.common_utils import ( TestCase, @@ -105,8 +106,11 @@ def test_optim_8bit_correctness(self, optim_name): model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) model2 = copy.deepcopy(model1) + # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0 + block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048 + optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) - optim2 = getattr(low_bit_optim, optim_name)(model2.parameters()) + optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size) for _ in range(2): x = torch.randn(4, 32, device=device) diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index 7781386bdd..64cb536ac1 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -19,7 +19,7 @@ model = ... optim = Adam8bit(model.parameters()) ``` -To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers. +To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 256 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers. **Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand. 
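For reference, a minimal sketch of how the `block_size` knob described in the README is passed, assuming the `Adam8bit` constructor shown in the `adam.py` hunk that follows; the import path, model, and sizes are illustrative:

```python
import torch
from torchao.prototype.low_bit_optim import Adam8bit

model = torch.nn.Linear(1024, 128).cuda()
# With this patch the default block_size is 256 (matching bnb >= 0.44);
# it can still be overridden explicitly for a different granularity.
optim = Adam8bit(model.parameters(), lr=1e-3, block_size=256)

for _ in range(2):
    loss = model(torch.randn(8, 1024, device="cuda")).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
```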
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 7f0d47854b..6c3c6996b9 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -161,7 +161,7 @@ def __init__( weight_decay=0, amsgrad=False, *, - block_size=2048, + block_size=256, ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False) @@ -199,7 +199,7 @@ def __init__( weight_decay=0, amsgrad=False, *, - block_size=2048, + block_size=256, ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False) @@ -218,7 +218,7 @@ def __init__( weight_decay=1e-2, amsgrad=False, *, - block_size=2048, + block_size=256, ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True) @@ -256,7 +256,7 @@ def __init__( weight_decay=1e-2, amsgrad=False, *, - block_size=2048, + block_size=256, ) -> None: super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True) diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 865498a57e..9c6e641e6d 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -53,7 +53,7 @@ def dequantize(self, output_dtype=None): return dequant_with_qmap(self.codes, self.qmap, self.scale).to(dtype) @classmethod - def zeros(cls, shape, signed: bool = True, block_size: int = 2048, device=None): + def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None): codes = torch.zeros(shape, dtype=torch.uint8, device=device) scale = torch.zeros(codes.numel() // block_size, device=device) qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 805c516f4e..146023c9f5 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -60,7 +60,7 @@ def dequantize(self, output_dtype=None): return float_data.view(self.codes.shape).to(dtype) @classmethod - def zeros(cls, shape, block_size: int = 2048, device=None): + def zeros(cls, shape, block_size: int = 256, device=None): codes = torch.zeros(shape, dtype=DTYPE, device=device) scale = torch.zeros(codes.numel() // block_size, device=device) return cls(codes, scale) From 87faf04a6a8b3ba632bbbf6ba5a5d2c98a8a15eb Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 24 Sep 2024 15:50:46 -0700 Subject: [PATCH 21/41] Float8 autoquant weight only (#866) --- test/integration/test_integration.py | 11 +++++++++- test/kernel/test_autotuner.py | 20 ++++++++++++++++++ torchao/dtypes/affine_quantized_tensor.py | 2 +- torchao/kernel/intmm.py | 7 ++++++- torchao/quantization/autoquant.py | 25 ++++++++++++++++++++++- 5 files changed, 61 insertions(+), 4 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 6a5ea8ef9d..8e047985c5 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -72,7 +72,7 @@ AQInt8WeightOnlyQuantizedLinearWeight2, AQInt8WeightOnlyQuantizedLinearWeight3, AutoQuantizableLinearWeight, - + AQFloat8WeightOnlyQuantizedLinearWeight, ) from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os @@ -98,6 +98,7 @@ COMMON_DTYPES = [torch.float32, torch.float16, 
torch.bfloat16] COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def _int8wo_api(mod): if TORCH_VERSION_AT_LEAST_2_4: @@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype ) + @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") + @unittest.skipIf(not is_H100, "Need H100 to run") + def test_aq_float8_weight_only_quant_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype + ) + @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py index 82fb117363..4ed0974172 100644 --- a/test/kernel/test_autotuner.py +++ b/test/kernel/test_autotuner.py @@ -16,6 +16,7 @@ logging.basicConfig(level=logging.INFO) +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) class TestQuantFlow(unittest.TestCase): @@ -49,6 +50,25 @@ def test_int_mm(self, device, dtype): assert out32_2.dtype == out32_1.dtype torch.testing.assert_allclose(out32_1, out32_2) + @parameterized.expand( + [ + ("cuda", torch.bfloat16), + ("cuda", torch.float16), + ] + ) + @unittest.skipIf(not is_H100, "Needs H100") + def test_int_mm_float8(self, device, dtype): + from torchao.kernel import intmm + + dtype = torch.bfloat16 + m, k, n = (128, 64, 16) + x = torch.randn(m, k, dtype=dtype, device=device) + w = torch.randn(n, k, dtype=dtype, device=device).t() + x_float8 = x.to(dtype=torch.float8_e4m3fn) + w_float8 = w.to(dtype=torch.float8_e4m3fn) + out32_1 = intmm.safe_int_mm(x_float8, w_float8) + assert out32_1.dtype == torch.int32 + @parameterized.expand( [ ("cuda", torch.bfloat16), diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index ecc8aa10d7..e00576263f 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -335,8 +335,8 @@ def from_hp_to_floatx( input_float: torch.Tensor, block_size: Tuple[int, ...], target_dtype: torch.dtype, - scale_dtype: Optional[torch.dtype], layout_type: LayoutType, + scale_dtype: Optional[torch.dtype] = None, ): if target_dtype in FP8_TYPES: diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 3005cb16a9..81e7b19b15 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -69,7 +69,12 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: input = ( input.contiguous() ) # (it seems the transpose makes cublas check the above j constraint on i) - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + try: + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + except Exception: + # fallback path, would run on H100 for float8 dtypes + # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' + return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32) else: def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: """ diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 
39482caf84..089add1d87 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -9,7 +9,7 @@ Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) -from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType +from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor from torch.utils._python_dispatch import return_and_correct_aliasing from .quant_primitives import ( @@ -477,6 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): def from_float(cls, weight): return weight +class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): + """ + AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn + """ + target_dtype: torch.dtype = torch.float8_e4m3fn + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias) + + @classmethod + def from_float(cls, weight): + block_size = (1, weight.shape[1]) + return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType()) + + # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison DEFAULT_AUTOQUANT_CLASS_LIST = [ AQFloatLinearWeight, @@ -493,6 +509,11 @@ def from_float(cls, weight): AQInt4G64WeightOnlyQuantizedLinearWeight ] +OTHER_AUTOQUANT_CLASS_LIST = [ + AQFloat8WeightOnlyQuantizedLinearWeight, +] + + def _change_linears_to_autoquantizable(model, **kwargs): """ Converts all linear weight tensors to the @@ -617,6 +638,8 @@ def autoquant( if set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() + if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST: + assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9" # perform initial swap from linear weights # to AutoQuantizableLinearWeight From 3a9fdb0274aee3e7902beac41fde39eb01545c32 Mon Sep 17 00:00:00 2001 From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com> Date: Wed, 25 Sep 2024 02:07:59 +0000 Subject: [PATCH 22/41] Fix failing FP6 benchmark (#931) --- benchmarks/benchmark_fp6.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index e9f9d21398..9b8dcf3387 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -1,7 +1,7 @@ import torch import pandas as pd import torch.nn.functional as F -from torchao.dtypes import to_affine_quantized_floatx +from torchao.dtypes import to_affine_quantized_fpx from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType from torchao.utils import benchmark_torch_function_in_microseconds from tqdm import tqdm @@ -9,7 +9,7 @@ def benchmark(m: int, k: int, n: int): float_data = torch.randn(n, k, dtype=torch.half, device="cuda") - fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2)) + fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2)) fp16_weight = fp6_weight.dequantize(torch.half) fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") From fc6c393004736d9636e89d23e024f874cbe02ce2 Mon Sep 17 00:00:00 2001 From: Shuqi Yang Date: Wed, 25 Sep 
2024 19:00:00 +0000 Subject: [PATCH 23/41] Remove two if statements in fp8 padding (#935) Reviewed By: vkuzo Differential Revision: D63051205 Pull Request resolved: https://github.com/pytorch/ao/pull/935 Approved by: https://github.com/vkuzo --- torchao/float8/float8_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 03636b7370..535c870890 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -198,9 +198,7 @@ def _get_min_alignment(size: int, alignment_value: int) -> int: 16 ``` """ - if size % alignment_value == 0: - return size - return (1 + (size // alignment_value)) * alignment_value + return (1 + ((size - 1) // alignment_value)) * alignment_value def pad_tensor_for_matmul( @@ -236,10 +234,6 @@ def pad_tensor_for_matmul( dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1 dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2 - # Check if padding is needed for either dimension - if dim1 == dim1_aligned and dim2 == dim2_aligned: - return tensor - # Calculate padding values for both dimensions pad_dim1 = dim1_aligned - dim1 pad_dim2 = dim2_aligned - dim2 From 0043ace48085f3fa60f0fa5262bbed0ec9f9c43a Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 26 Sep 2024 03:40:27 +0800 Subject: [PATCH 24/41] [Distributed] Improve sharding example (#937) * [Distributed] Improve sharding example * Add comment --- .../developer_api_guide/tensor_parallel.py | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py index a94d84fe05..db610a71fa 100644 --- a/tutorials/developer_api_guide/tensor_parallel.py +++ b/tutorials/developer_api_guide/tensor_parallel.py @@ -1,8 +1,9 @@ import os import torch import torch.distributed as dist +from typing import Sequence from torch.distributed import DeviceMesh -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed.tensor import DTensor, Replicate, Shard, Placement from torch.utils._python_dispatch import return_and_correct_aliasing from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults @@ -101,18 +102,40 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module: ) return m +def shard( + full_tensor: torch.Tensor, + device_mesh: DeviceMesh, + placements: Sequence[Placement], +) -> DTensor: + """ + Add a shard function to simplify both colwise_shard and rowwise_shard. The + shard function accepts a full tensor, and returns a DTensor based on + indicated placements. Goal is to move the shard function as a static method + of DTensor, e.g. + dtensor = DTensor.shard(full_tensor, device_mesh, placement) + """ + from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, device_mesh, placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + local_tensor = full_tensor[slices] + return DTensor.from_local( + local_tensor, device_mesh, placements + ) + def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: """ Shard linear layer of the model in column-wise fashion """ # Column-wise is wrt to A^T, so for A it is row-wise. 
- # Number of rows per rank orig_weight = m.linear.weight - n_local_rows = orig_weight.size(0) // mesh.size() - rank = mesh.get_local_rank() - local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :] # Construct DTensor from local shard - dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) + dtensor = shard(orig_weight, mesh, [Shard(0)]) # Replace parameter in module m.linear.weight = torch.nn.Parameter( dtensor, requires_grad=False @@ -124,13 +147,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: Shard linear layer of the model in row-wise fashion """ # Row-wise is wrt to A^T, so for A it is column-wise. - # Number of rows per rank orig_weight = m.linear.weight - n_local_cols = orig_weight.size(1) // mesh.size() - rank = mesh.get_local_rank() - local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols] # Construct DTensor from local shard - dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)]) + dtensor = shard(orig_weight, mesh, [Shard(1)]) # Replace parameter in module m.linear.weight = torch.nn.Parameter( dtensor, requires_grad=False From ab3435c904217a0514b2648feba473ff3897fe4a Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 25 Sep 2024 18:34:27 -0400 Subject: [PATCH 25/41] Add composable QAT quantizer (#938) Summary: This is a utility for users who wish to apply multiple QAT quantizers to their models. In the near future, we expect to add an embedding QAT quantizer that composes with the existing linear QAT quantizers. Test Plan: python test/quantization/test_qat.py -k test_composable_qat_quantizer --- test/quantization/test_qat.py | 42 +++++++++++++++++ .../quantization/prototype/qat/__init__.py | 2 + torchao/quantization/prototype/qat/api.py | 46 +++++++++++++++++-- torchao/quantization/unified.py | 6 +-- 4 files changed, 89 insertions(+), 7 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 397283a59b..457e3a060f 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -15,6 +15,9 @@ from torchao.dtypes import ( TensorCoreTiledLayoutType, ) +from torchao.quantization.prototype.qat.api import ( + ComposableQATQuantizer, +) from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( AffineFakeQuantizedTensor, ) @@ -34,6 +37,9 @@ MappingType, ZeroPointDomain, ) +from torchao.quantization.unified import ( + TwoStepQuantizer, +) from torchao.quantization.utils import ( get_group_qparams_symmetric, get_groupwise_affine_qparams, @@ -626,6 +632,42 @@ def test_qat_4w_quantizer_module_swap(self): module_swap_out = module_swap_model(*x2) torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + class _MyQATQuantizer(TwoStepQuantizer): + """ + Dummy quantizer that attaches a certain value to each nn.Linear's + `_temp_quantizer_values` attribute. 
+ """ + ATTR_NAME = "_temp_quantizer_values" + + def __init__(self, value: str): + self.value = value + + def _attach_value(self, module: torch.nn.Module): + if isinstance(module, torch.nn.Linear): + if not hasattr(module, self.ATTR_NAME): + setattr(module, self.ATTR_NAME, []) + getattr(module, self.ATTR_NAME).append(self.value) + + def prepare(self, model: torch.nn.Module): + model.apply(self._attach_value) + return model + + def convert(self, model: torch.nn.Module): + model.apply(self._attach_value) + return model + + def test_composable_qat_quantizer(self): + quantizer1 = self._MyQATQuantizer("quantizer1") + quantizer2 = self._MyQATQuantizer("quantizer2") + composable_quantizer = ComposableQATQuantizer([quantizer1, quantizer2]) + model = M() + model = composable_quantizer.prepare(model) + self.assertTrue(hasattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)) + values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME) + self.assertEqual(values_list, ["quantizer1", "quantizer2"]) + composable_quantizer.convert(model) + values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME) + self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"]) if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py index c16b3ead44..9f8dd74e44 100644 --- a/torchao/quantization/prototype/qat/__init__.py +++ b/torchao/quantization/prototype/qat/__init__.py @@ -5,6 +5,7 @@ enable_8da4w_fake_quant, int4_weight_only_fake_quantize, int8_dynamic_activation_int4_weight_fake_quantize, + ComposableQATQuantizer, Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer, ) @@ -20,6 +21,7 @@ "enable_8da4w_fake_quant", "int4_weight_only_fake_quantize", "int8_dynamic_activation_int4_weight_fake_quantize", + "ComposableQATQuantizer", "Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer", "Int8DynActInt4WeightQATLinear", diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index 2f3368ff1c..e1c5221e1e 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Optional +from typing import Any, List, Optional import torch import torch.nn.functional as F @@ -34,6 +34,44 @@ ) +class ComposableQATQuantizer(TwoStepQuantizer): + """ + Composable quantizer that users can use to apply multiple QAT quantizers easily. + Quantizers will be applied in the order they are specified in the constructor. + + Note: the quantizers provided must apply to different modules in the model, + e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined. 
+ + Example usage:: + + my_quantizer = ComposableQATQuantizer([ + QATQuantizer1(), + QATQuantizer2(), + QATQuantizer3(), + ]) + model = my_quantizer.prepare(model) + train(model) + model = my_quantizer.convert(model) + """ + + def __init__(self, quantizers: List[TwoStepQuantizer]): + self.quantizers = quantizers + + def prepare( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + for quantizer in self.quantizers: + model = quantizer.prepare(model) + return model + + def convert( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + for quantizer in self.quantizers: + model = quantizer.convert(model) + return model + + # ================= # | 8da4w QAT | # ================= @@ -44,7 +82,8 @@ def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32): int4 per group weight symmetric fake quantization to linear. Please see :func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details. - Example usage: + Example usage:: + from torchao.quantization import quantize_ quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32)) """ @@ -151,7 +190,8 @@ def int4_weight_only_fake_quantize(group_size=128): Applies uint4 weight-only asymmetric per-group fake quantization to linear layers. Please see :func:`~torchao.quantization.int4_weight_only` for more details. - Example usage: + Example usage:: + from torchao.quantization import quantize_ quantize_(model, int4_weight_only_fake_quantize(group_size=32)) """ diff --git a/torchao/quantization/unified.py b/torchao/quantization/unified.py index 7da915dec7..1bd62b8979 100644 --- a/torchao/quantization/unified.py +++ b/torchao/quantization/unified.py @@ -1,5 +1,5 @@ import torch -from typing import Any +from typing import Any, List from abc import ABC, abstractmethod """ @@ -17,7 +17,6 @@ class Quantizer(ABC): def quantize( self, model: torch.nn.Module, *args: Any, **kwargs: Any ) -> torch.nn.Module: - pass @@ -27,11 +26,10 @@ class TwoStepQuantizer: def prepare( self, model: torch.nn.Module, *args: Any, **kwargs: Any ) -> torch.nn.Module: - pass + @abstractmethod def convert( self, model: torch.nn.Module, *args: Any, **kwargs: Any ) -> torch.nn.Module: - pass From a05a40f8a018c293bf60e2dc2bb3c9bc8add0f4f Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Thu, 26 Sep 2024 11:27:37 -0700 Subject: [PATCH 26/41] resolve conflict with latest main Differential Revision: D63048850 Pull Request resolved: https://github.com/pytorch/ao/pull/912 --- test/float8/test_fsdp2/test_fsdp2.py | 2 +- torchao/testing/__init__.py | 0 torchao/testing/float8/__init__.py | 0 .../fsdp2_common.py => torchao/testing/float8/fsdp2_utils.py | 4 ++-- 4 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 torchao/testing/__init__.py create mode 100644 torchao/testing/float8/__init__.py rename test/float8/test_fsdp2/fsdp2_common.py => torchao/testing/float8/fsdp2_utils.py (92%) diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index e2e7097f6b..ecde051e36 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -18,7 +18,7 @@ from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp +from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, 
check_parity_no_mp from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torch.distributed._tensor import DTensor from torch.testing._internal.common_cuda import TEST_CUDA diff --git a/torchao/testing/__init__.py b/torchao/testing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/testing/float8/__init__.py b/torchao/testing/float8/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/float8/test_fsdp2/fsdp2_common.py b/torchao/testing/float8/fsdp2_utils.py similarity index 92% rename from test/float8/test_fsdp2/fsdp2_common.py rename to torchao/testing/float8/fsdp2_utils.py index 0a0a12eb46..62a571e156 100644 --- a/test/float8/test_fsdp2/fsdp2_common.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -48,7 +48,7 @@ def check_parity_no_mp( ): precompute_float8_dynamic_scale_for_fsdp(model) - test_cls.assertEqual(losses[0], losses[1]) + test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") def check_parity_bf16_mp( @@ -83,4 +83,4 @@ def check_parity_bf16_mp( ref_model.parameters(), ref_model_bf16.parameters() ): param_bf16.detach().copy_(param_fp32) - test_cls.assertEqual(losses[0], losses[1]) + test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") From 334891bf5ad14e15d7c1573c65c3f85b665b56a6 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:31:27 -0700 Subject: [PATCH 27/41] Add torchchat quantizer Differential Revision: D62394341 Pull Request resolved: https://github.com/pytorch/ao/pull/897 --- .../kernels/cpu/aarch64/CMakeLists.txt | 8 +- .../examples/torch_custom_op/CMakeLists.txt | 8 +- .../torch_custom_op/build_custom_op.sh | 8 +- .../examples/torch_custom_op/run_custom_op.py | 72 ++-- .../torch_custom_op/test_custom_op.py | 56 --- ...test_int8_dyn_act_intx_weight_quantizer.py | 79 +++++ .../torch_custom_op/torch_custom_op.py | 231 ------------- torchao/experimental/quant_api.py | 321 ++++++++++++++++++ 8 files changed, 432 insertions(+), 351 deletions(-) delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py create mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py create mode 100644 torchao/experimental/quant_api.py diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index ecffb579c1..a13737d874 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -6,8 +6,8 @@ add_library( kernel_aarch64 - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp + 
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp ) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt index 55bcdfbc23..10e44a79a8 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt @@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release) add_compile_options("-Wall" "-Werror") include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) +message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") +include_directories(${TORCHAO_INCLUDE_DIRS}) -add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) +add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) -include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake) +include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake) set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH") string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh index 94cb9587c6..c657857fcc 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh @@ -6,14 +6,14 @@ # LICENSE file in the root directory of this source tree. SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../.. +export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../.. export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" -export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ +export CMAKE_OUT=/tmp/cmake-out/torchao +cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \ -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ -DPLATFORM="ATEN" \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \ + -S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \ -B ${CMAKE_OUT} cmake --build ${CMAKE_OUT} diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py index 0b85583f76..e3d96df63c 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py @@ -5,12 +5,21 @@ # LICENSE file in the root directory of this source tree. 
import copy +import glob +import os + +import sys import torch -from torch_custom_op import ( - linear_a8sz_w_lowbit_reference_impl, - replace_linear_with_quantized_linear, + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) ) +from quant_api import Int8DynActIntxWeightQuantizer + +libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") +libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) +torch.ops.load_library(libs[0]) group_size = 256 m = 1 @@ -27,15 +36,15 @@ print("Quantizing random model") quantized_model = copy.deepcopy(model) -quantized_model = quantized_model.eval() -replace_linear_with_quantized_linear( - quantized_model, - kwargs={ - "group_size": group_size, - "nbit": nbit, - "has_weight_zeros": has_weight_zeros, - }, +quantizer = Int8DynActIntxWeightQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=nbit, + groupsize=group_size, + has_weight_zeros=has_weight_zeros, ) +quantized_model = quantizer.quantize(quantized_model) +quantized_model = quantized_model.eval() print("Creating random activations") activations = torch.randn(m, k, dtype=torch.float32) @@ -58,44 +67,3 @@ print("Running AOTI") fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu") fn(activations) - - -print("\nChecking correctness on layer 0") -linear = model[0] -quantized_linear = quantized_model[0] - -with torch.no_grad(): - result = quantized_linear(activations) - expected_result = linear_a8sz_w_lowbit_reference_impl( - linear.weight, activations, group_size, nbit, has_weight_zeros - ) - non_quantized_result = linear(activations) - - -# Check that entries in result match entries in expected_result -num_mismatch_at_low_tol = 0 -num_total = result.reshape(-1).shape[0] -for i in range(num_total): - actual_val = result.reshape(-1)[i] - expected_val = expected_result.reshape(-1)[i] - if not torch.allclose(actual_val, expected_val): - num_mismatch_at_low_tol += 1 - - # If results are not close at a relaxed tolerance, exit with failure - if not torch.allclose(actual_val, expected_val, atol=1e-6): - assert False, "Correctness check failed" - -# Assert at most 5% of entries are not close at a low tolerance -assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed" -print( - "Correctness check passed. All results are close, and ", - (num_total - num_mismatch_at_low_tol), - "/", - num_total, - " entries are close at a low tolerance.", -) -print("Quantization errors:") -print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item()) -print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item()) -print("\tquantized_result[0:5]: ", result[0][0:5]) -print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5]) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py deleted file mode 100644 index e4e108b901..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-
-import unittest
-
-import torch
-from torch_custom_op import (
-    linear_a8sz_w_lowbit_reference_impl,
-    replace_linear_with_quantized_linear,
-)
-import copy
-
-class TestTorchCustomOp(unittest.TestCase):
-    def test_accuracy(self):
-        group_size = 128
-        m = 1
-        n = 1071
-        k = 4096
-        activations = torch.randn(m, k, dtype=torch.float32)
-        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
-
-        for nbit in [2, 3, 4, 5]:
-            for has_weight_zeros in [False, True]:
-                quantized_model = copy.deepcopy(model)
-                replace_linear_with_quantized_linear(
-                    quantized_model,
-                    kwargs={
-                        "group_size": group_size,
-                        "nbit": nbit,
-                        "has_weight_zeros": has_weight_zeros,
-                    },
-                )
-
-                with torch.no_grad():
-                    result = quantized_model(activations)
-                    expected_result = linear_a8sz_w_lowbit_reference_impl(
-                        model[0].weight, activations, group_size, nbit, has_weight_zeros
-                    )
-
-                num_mismatch_at_low_tol = 0
-                num_total = result.reshape(-1).shape[0]
-                for i in range(num_total):
-                    actual_val = result.reshape(-1)[i]
-                    expected_val = expected_result.reshape(-1)[i]
-                    self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
-                    if not torch.allclose(actual_val, expected_val):
-                        num_mismatch_at_low_tol += 1
-
-                # Assert at most 5% of entries are not close at a low tolerance
-                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py
new file mode 100644
index 0000000000..513088d2f0
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+import glob
+import os
+
+import sys
+import unittest
+
+import torch
+
+sys.path.insert(
+    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
+)
+from quant_api import (
+    _Int8DynActIntxWeightQuantizedLinearFallback,
+    Int8DynActIntxWeightQuantizer,
+)
+
+libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+if len(libs) == 0:
+    print(
+        "Could not find library lowbit_op_aten; please run `sh build_custom_op.sh` to build the library. A slow fallback kernel will be used instead."
+ ) +else: + torch.ops.load_library(libs[0]) + + +class TestInt8DynActIntxWeightQuantizer(unittest.TestCase): + def test_accuracy(self): + group_size = 128 + m = 1 + n = 1071 + k = 4096 + activations = torch.randn(m, k, dtype=torch.float32) + model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) + + for nbit in [1, 2, 3, 4, 5, 6, 7]: + for has_weight_zeros in [True, False]: + print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}") + quantized_model = copy.deepcopy(model) + quantizer = Int8DynActIntxWeightQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=nbit, + groupsize=group_size, + has_weight_zeros=has_weight_zeros, + ) + quantized_model = quantizer.quantize(quantized_model) + + with torch.no_grad(): + result = quantized_model(activations) + reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback() + reference_impl.quantize_and_pack_weights( + model[0].weight, nbit, group_size, has_weight_zeros + ) + expected_result = reference_impl(activations) + + num_mismatch_at_low_tol = 0 + num_total = result.reshape(-1).shape[0] + for i in range(num_total): + actual_val = result.reshape(-1)[i] + expected_val = expected_result.reshape(-1)[i] + self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6)) + if not torch.allclose(actual_val, expected_val): + num_mismatch_at_low_tol += 1 + + # Assert at most 5% of entries are not close at a low tolerance + self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py deleted file mode 100644 index 46117db15a..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Optional - -import torch -import torch.nn as nn - -import glob -libs = glob.glob("/tmp/cmake-out/torch_ao/examples/torch_custom_op/libtorch_custom_op.*") -libs = list(filter(lambda l:(l.endswith("so") or l.endswith("dylib")), libs)) -torch.ops.load_library(libs[0]) - -def quantize(vals: torch.Tensor, group_size: int, nbit: int, scale_only: bool): - assert nbit >= 2 and nbit <= 8 - qmin = -(1 << (nbit - 1)) - qmax = (1 << (nbit - 1)) - 1 - - n, k = vals.shape - vals = vals.reshape(-1, group_size) - vmins, _ = torch.min(vals, axis=1) - vmaxs, _ = torch.max(vals, axis=1) - group_scales = (vmaxs - vmins) / (qmax - qmin) - - if scale_only: - group_qvals = torch.round(vals / group_scales.reshape(-1, 1)) - else: - group_zeros = qmin - torch.round(vmins / group_scales) - group_qvals = torch.round( - group_zeros.reshape(-1, 1) + vals / group_scales.reshape(-1, 1) - ) - - group_qvals = torch.clip(group_qvals, qmin, qmax).reshape(n, k).to(torch.int8) - - if scale_only: - return group_qvals, group_scales - return group_qvals, group_scales, group_zeros - - -def linear_a8sz_w_lowbit_reference_impl( - weights, activations, group_size, nbit, has_weight_zeros -): - n, k = weights.shape - m, k = activations.shape - assert m == 1 - assert k % group_size == 0 - - if has_weight_zeros: - weight_qvals, weight_scales, weight_zeros = quantize( - weights, group_size, nbit, scale_only=False - ) - weights_dequantized = ( - weight_scales.reshape(-1, 1) - * (weight_qvals.reshape(-1, group_size) - weight_zeros.reshape(-1, 1)) - ).reshape(n, k) - else: - weight_qvals, weight_scales = quantize( - weights, group_size, nbit, scale_only=True - ) - weights_dequantized = ( - weight_scales.reshape(-1, 1) * (weight_qvals.reshape(-1, group_size)) - ).reshape(n, k) - - activation_qvals, activations_scales, activations_zeros = quantize( - activations, k, 8, False - ) - activations_dequantized = activations_scales * ( - activation_qvals - activations_zeros - ).reshape(m, k) - return torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0)) - - -class _quantized_linear(nn.Module): - def __init__( - self, - nbit, - has_weight_zeros, - pack_weight_op, - linear_op, - squeeze_unsqueeze_dim0=False, - ): - super().__init__() - self.squeeze_unsqueeze_dim0 = squeeze_unsqueeze_dim0 - self.nbit = nbit - - self._has_weight_zeros = has_weight_zeros - self._pack_weights_op = pack_weight_op - self._linear_op = linear_op - - def pack_weights(self, weight_qvals, weight_scales_and_zeros, group_size): - n, k = weight_qvals.shape - - # TODO(T200095131): convert self.n, self.k, self.group_size to - # int when supported by AOTI - self.n = torch.empty(n) - self.k = torch.empty(k) - self.group_size = torch.empty(group_size) - - if self._has_weight_zeros: - weight_scales, weight_zeros = weight_scales_and_zeros - self.packed_weights = self._pack_weights_op( - weight_qvals, weight_scales, weight_zeros, self.group_size - ) - else: - weight_scales = weight_scales_and_zeros - self.packed_weights = self._pack_weights_op( - weight_qvals, weight_scales, self.group_size - ) - - def forward(self, x): - if self.squeeze_unsqueeze_dim0: - x = x.squeeze(0) - - res = self._linear_op(self.packed_weights, self.n, self.k, self.group_size, x) - - if self.squeeze_unsqueeze_dim0: - res = res.unsqueeze(0) - return res - - -def replace_linear_with_quantized_linear(module: nn.Module, kwargs={}): - group_size = kwargs["group_size"] - nbit = kwargs["nbit"] - has_weight_zeros = kwargs["has_weight_zeros"] - squeeze_unsqueeze_dim0 = ( - 
kwargs["squeeze_unsqueeze_dim0"] - if "squeeze_unsqueeze_dim0" in kwargs - else False - ) - - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - assert child.bias is None - - if not has_weight_zeros: - weight_qvals, weight_scales = quantize( - child.weight, group_size=group_size, nbit=nbit, scale_only=True - ) - weight_scales_and_zeros = weight_scales - else: - weight_qvals, weight_scales, weight_zeros = quantize( - child.weight, group_size=group_size, nbit=nbit, scale_only=False - ) - weight_scales_and_zeros = (weight_scales, weight_zeros.to(torch.int8)) - - qlinear = None - if nbit == 2: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2sz, - linear_op=torch.ops.torchao._linear_a8sz_w2sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2s, - linear_op=torch.ops.torchao._linear_a8sz_w2s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - elif nbit == 3: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3sz, - linear_op=torch.ops.torchao._linear_a8sz_w3sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3s, - linear_op=torch.ops.torchao._linear_a8sz_w3s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - elif nbit == 4: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4sz, - linear_op=torch.ops.torchao._linear_a8sz_w4sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4s, - linear_op=torch.ops.torchao._linear_a8sz_w4s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - elif nbit == 5: - if has_weight_zeros: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5sz, - linear_op=torch.ops.torchao._linear_a8sz_w5sz, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - qlinear = _quantized_linear( - nbit=nbit, - has_weight_zeros=has_weight_zeros, - pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5s, - linear_op=torch.ops.torchao._linear_a8sz_w5s, - squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, - ) - else: - raise ValueError( - f"Unsupported nbit ({nbit}) and has_weight_zeros ({has_weight_zeros}) combination" - ) - - assert qlinear is not None - setattr(module, name, qlinear) - getattr(module, name).pack_weights( - weight_qvals, - weight_scales_and_zeros, - group_size, - ) - else: - replace_linear_with_quantized_linear(child, kwargs) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py new file mode 100644 index 0000000000..26797bdb1c --- /dev/null +++ b/torchao/experimental/quant_api.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from typing import Optional + +import torch +import torch.nn as nn +from torch.ao.quantization.fx._decomposed import ( + dequantize_per_channel_group, + quantize_per_channel_group, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +import sys + +handler = logging.StreamHandler(sys.stdout) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool): + assert nbit >= 1 and nbit <= 8 + qmin = -(1 << (nbit - 1)) + qmax = (1 << (nbit - 1)) - 1 + + n, k = vals.shape + vals = vals.reshape(-1, group_size) + vmins, _ = torch.min(vals, axis=1) + vmaxs, _ = torch.max(vals, axis=1) + group_scales = (vmaxs - vmins) / (qmax - qmin) + + if not has_weight_zeros: + group_zeros = torch.zeros_like(group_scales) + else: + group_zeros = qmin - torch.round(vmins / group_scales) + + vals = vals.reshape(n, k) + group_scales = group_scales.reshape(n, -1) + group_zeros = group_zeros.reshape(n, -1) + + group_qvals = quantize_per_channel_group( + input=vals, + scales=group_scales, + zero_points=group_zeros, + quant_min=qmin, + quant_max=qmax, + dtype=torch.int8, + group_size=group_size, + ) + + if not has_weight_zeros: + group_zeros = None + + return group_qvals, group_scales, group_zeros + + +class _Int8DynActIntxWeightQuantizedLinearNative(nn.Module): + def __init__( + self, + pack_weight_op, + linear_op, + ): + super().__init__() + self._pack_weights_op = pack_weight_op + self._linear_op = linear_op + + def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros): + self.nbit = nbit + self.group_size = group_size + self.has_weight_zeros = has_weight_zeros + + n, k = weights.shape + + # TODO(T200095131): convert self.n, self.k, self.group_size to + # int when supported by AOTI + self._n = torch.empty(n, dtype=torch.int8) + self._k = torch.empty(k, dtype=torch.int8) + self._group_size = torch.empty(self.group_size, dtype=torch.int8) + + weight_qvals, weight_scales, weight_zeros = _quantize( + weights, self.group_size, self.nbit, self.has_weight_zeros + ) + if self.has_weight_zeros: + self.packed_weights = self._pack_weights_op( + weight_qvals, + weight_scales.reshape(-1), + weight_zeros.to(torch.int8).reshape(-1), + self._group_size, + ) + else: + self.packed_weights = self._pack_weights_op( + weight_qvals, weight_scales.reshape(-1), self._group_size + ) + + def forward(self, x): + assert x.dim() >= 2 + if x.dim() == 2: + return self._linear_op( + self.packed_weights, self._n, self._k, self._group_size, x + ) + + assert x.dim() >= 3 + lead_shape = x.shape[0:-2] + m, k = x.shape[-2], x.shape[-1] + n = self._n.shape[0] + x = x.reshape(-1, m, k) + + res = [ + self._linear_op( + self.packed_weights, self._n, self._k, self._group_size, x[i, :, :] + ) + for i in range(x.shape[0]) + ] + res = torch.stack(res) + res = res.reshape(*lead_shape, m, n) + return res + + +# Python-based reference implementation of Int8DynActLowbitWeightQuantizedLinear +# It is arithmetically equivalent to Int8DynActLowbitWeightQuantizedLinear +# This is used to test Int8DynActLowbitWeightQuantizedLinear, and as a fallback when +# Int8DynActLowbitWeightQuantizedLinear is not available +class _Int8DynActIntxWeightQuantizedLinearFallback(nn.Module): + def __init__(self): + super().__init__() + + def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros): + self.nbit = nbit + 
self.group_size = group_size + self.has_weight_zeros = has_weight_zeros + + self._n, self._k = weights.shape + assert self._k % group_size == 0, "group_size must divide k" + + self.weight_qvals, self.weight_scales, self.weight_zeros = _quantize( + weights, self.group_size, self.nbit, self.has_weight_zeros + ) + + def _forward_2d(self, x): + assert x.dim() == 2 + + n, k = self._n, self._k + m, k_ = x.shape + assert k_ == k + + weights_dequantized = dequantize_per_channel_group( + w_int8=self.weight_qvals, + scales=self.weight_scales, + zero_points=( + self.weight_zeros + if self.has_weight_zeros + else torch.zeros_like(self.weight_scales) + ), + quant_min=None, # TODO: why is this an arg for this function + quant_max=None, # TODO: why is this an arg for this function + dtype=None, # TODO: why is this an arg for this function + group_size=self.group_size, + output_dtype=torch.float32, + ) + + activation_qvals, activation_scales, activation_zeros = _quantize( + x, group_size=k, nbit=8, has_weight_zeros=True + ) + activations_dequantized = dequantize_per_channel_group( + w_int8=activation_qvals, + scales=activation_scales, + zero_points=activation_zeros, + quant_min=None, # TODO: why is this an arg for this function + quant_max=None, # TODO: why is this an arg for this function + dtype=None, # TODO: why is this an arg for this function + group_size=k, + output_dtype=torch.float32, + ) + + res = torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0)) + return res + + def forward(self, x): + assert x.dim() >= 2 + if x.dim() == 2: + return self._forward_2d(x) + + assert x.dim() >= 3 + lead_shape = x.shape[0:-2] + m, k = x.shape[-2], x.shape[-1] + n = self._n + x = x.reshape(-1, m, k) + + res = [self._forward_2d(x[i, :, :]) for i in range(x.shape[0])] + res = torch.stack(res) + res = res.reshape(*lead_shape, m, n) + return res + + +def _maybe_get_quantized_linear_native(nbit, has_weight_zeros): + try: + if nbit in [2, 3, 4, 5]: + wzp_suffix = "z" if has_weight_zeros else "" + return _Int8DynActIntxWeightQuantizedLinearNative( + pack_weight_op=getattr( + torch.ops.torchao, f"_pack_weights_a8sz_w{nbit}s{wzp_suffix}" + ), + linear_op=getattr( + torch.ops.torchao, f"_linear_a8sz_w{nbit}s{wzp_suffix}" + ), + ) + else: + logger.warning( + f"_Int8DynActIntxWeightQuantizedLinearNative does not support: nbit={nbit}, has_weight_zeros={has_weight_zeros}." + ) + except Exception as e: + logger.warning( + f"_Int8DynActIntxWeightQuantizedLinearNative raised an exception during initialization: {e}" + ) + + logger.warning( + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback." + ) + return _Int8DynActIntxWeightQuantizedLinearFallback() + + +def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}): + group_size = kwargs["group_size"] + nbit = kwargs["nbit"] + has_weight_zeros = kwargs["has_weight_zeros"] + + assert not isinstance(module, nn.Linear) + assert nbit >= 1 and nbit <= 7 + + for name, child in module.named_children(): + if not isinstance(child, nn.Linear): + _replace_linear_with_quantized_linear(child, kwargs) + else: + assert child.bias is None + qlinear = _maybe_get_quantized_linear_native( + nbit=nbit, has_weight_zeros=has_weight_zeros + ) + try: + # The packing function may raise some error from the C++ layer (e.g., if group_size is unsupported) + # so calling quantize_and_pack_weights can fail. 
In this case, we still switch to the fallback
+                # implementation.
+                setattr(module, name, qlinear)
+                getattr(module, name).quantize_and_pack_weights(
+                    child.weight, nbit, group_size, has_weight_zeros
+                )
+            except Exception as e:
+                if not isinstance(qlinear, _Int8DynActIntxWeightQuantizedLinearNative):
+                    raise e
+                logger.warning(
+                    f"_Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n"
+                    + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback."
+                )
+                qlinear = _Int8DynActIntxWeightQuantizedLinearFallback()
+                setattr(module, name, qlinear)
+                getattr(module, name).quantize_and_pack_weights(
+                    child.weight, nbit, group_size, has_weight_zeros
+                )
+
+
+class Int8DynActIntxWeightQuantizer:
+    def __init__(
+        self,
+        device,
+        precision,
+        *,
+        bitwidth: Optional[int] = None,
+        groupsize: Optional[int] = None,
+        has_weight_zeros: Optional[bool] = None,
+    ):
+        if device != "cpu":
+            raise NotImplementedError(
+                "Only device=cpu is currently supported in Int8DynActLowbitWeightQuantizer"
+            )
+        else:
+            self.device = device
+
+        if precision != torch.float32:
+            raise NotImplementedError(
+                "Only precision=torch.float32 is currently supported in Int8DynActLowbitWeightQuantizer"
+            )
+        else:
+            self.precision = precision
+
+        if bitwidth is None:
+            self.bitwidth = 4
+            logger.warning(f"bitwidth not specified, defaulting to {self.bitwidth}.")
+        else:
+            self.bitwidth = bitwidth
+
+        if groupsize is None:
+            self.groupsize = 128
+            logger.warning(f"groupsize not specified, defaulting to {self.groupsize}.")
+        else:
+            self.groupsize = groupsize
+
+        if has_weight_zeros is None:
+            self.has_weight_zeros = False
+            logger.warning(
+                f"has_weight_zeros not specified, defaulting to {self.has_weight_zeros}."
+            )
+        else:
+            self.has_weight_zeros = has_weight_zeros
+
+    def quantize(self, model: nn.Module) -> nn.Module:
+        model = model.to(self.device).to(self.precision)
+        _replace_linear_with_quantized_linear(
+            model,
+            kwargs={
+                "group_size": self.groupsize,
+                "nbit": self.bitwidth,
+                "has_weight_zeros": self.has_weight_zeros,
+            },
+        )
+        return model

From c706139d5efb4d736c210554143f07db03822631 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 25 Sep 2024 18:47:49 -0700
Subject: [PATCH 28/41] Add compile tests to test suite (#906)

* Add compile tests to test suite

Summary:
This is a follow-up PR addressing https://github.com/pytorch/ao/pull/839#discussion_r1750720771
We can add more compiler-related tests in the future.
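
For reference, a rough sketch of how an existing API test file could reuse the new
test case (the subclass name is a placeholder and the import paths are assumptions,
not code from this diff; the values below mirror the defaults already set in
TorchAOCompileTestCase):

    import torch
    from torch.testing._internal import common_utils
    from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
    from torchao.quantization.quant_primitives import MappingType
    from torchao.testing.utils import TorchAOCompileTestCase

    class MyAPICompileTest(TorchAOCompileTestCase):
        # Override the tensor subclass / factory under test for this API.
        TENSOR_SUBCLASS = AffineQuantizedTensor
        FACTORY_FN = to_affine_quantized_intx
        kwargs = {
            "mapping_type": MappingType.ASYMMETRIC,
            "block_size": (1, 32),
            "target_dtype": torch.uint8,
        }
        COMPILE_MIN_SQNR = 50

    common_utils.instantiate_parametrized_tests(MyAPICompileTest)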
Next * refactor a bit to use quantize_ API directly * use the test suite in existing API tests Test Plan: python torchao/testing/utils.py Reviewers: Subscribers: Tasks: Tags: * rename * add result check --- torchao/testing/utils.py | 65 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index a6c5bf7e0a..48a171a75f 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -69,8 +69,6 @@ def new_test(self, value=value): class TorchAOBasicTestCase(common_utils.TestCase): - """Basic test case for tensor subclasses - """ COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -142,6 +140,66 @@ def test_linear(self, device, dtype): lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor) self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + +class TorchAOCompileTestCase(common_utils.TestCase): + COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + + TENSOR_SUBCLASS = AffineQuantizedTensor + FACTORY_FN = to_affine_quantized_intx + kwargs = { + "mapping_type": MappingType.ASYMMETRIC, + "block_size": (1, 32), + "target_dtype": torch.uint8, + } + # minimum sqnr for linear operation when the weight is quantized to low precision + # with the above setting + LINEAR_MIN_SQNR = 40 + COMPILE_MIN_SQNR = 50 + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_input_output_tensor_subclass(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): + return tensor + + ref = f(lp_tensor) + f = torch.compile(f) + compiled = f(lp_tensor) + self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS)) + self.assertEqual(ref.dequantize(), compiled.dequantize()) + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_input_tensor_subclass(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): + return tensor.dequantize() + + ref = f(lp_tensor) + f = torch.compile(f) + compiled = f(lp_tensor) + self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS)) + self.assertEqual(ref, compiled) + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_output_tensor_subclass(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + def f(hp_tensor): + return self.FACTORY_FN(hp_tensor, **self.kwargs) + + ref = f(hp_tensor) + f = torch.compile(f) + compiled = f(hp_tensor) + self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS)) + # bfloat16 seems to result in much larger numerical differences + if dtype != torch.bfloat16: + self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR) + @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) def test_linear_compile(self, device, dtype): @@ -155,7 +213,10 @@ def test_linear_compile(self, device, dtype): lp_res = torch.compile(l)(hp_act_tensor) 
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + + common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) +common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) if __name__ == "__main__": unittest.main() From 93554c0275fbc1569ba1c8ad19c96e26131fc5cd Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 25 Sep 2024 19:33:03 -0700 Subject: [PATCH 29/41] Fix up CMakeLists and reorganize some code locations Differential Revision: D62711903 Pull Request resolved: https://github.com/pytorch/ao/pull/948 --- torchao/experimental/CMakeLists.txt | 50 ++++++++++++++ .../{kernels/cpu => }/Utils.cmake | 0 ...uild_custom_op.sh => build_torchao_ops.sh} | 15 ++-- .../kernels/cpu/aarch64/CMakeLists.txt | 16 +++-- .../cpu/linear/benchmarks/CMakeLists.txt | 57 --------------- .../cpu/linear/examples/CMakeLists.txt | 38 ---------- .../examples/torch_custom_op/CMakeLists.txt | 58 ---------------- .../examples/torch_custom_op/run_custom_op.py | 69 ------------------- .../kernels/cpu/linear/tests/CMakeLists.txt | 41 ----------- .../cpu/linear/tests/build_and_run_tests.sh | 12 ---- .../experimental/ops/linear/CMakeLists.txt | 12 ++++ .../ops/linear/benchmarks/CMakeLists.txt | 40 +++++++++++ .../benchmarks/benchmark_linear_operator.cpp | 40 +++++++++-- .../benchmarks/build_and_run_benchmarks.sh | 6 +- ...it_activation_groupwise_lowbit_weight.cpp} | 57 +++------------ ..._8bit_activation_groupwise_lowbit_weight.h | 21 +----- .../ops/linear/examples/CMakeLists.txt | 42 +++++++++++ ...ationGroupwiseLowbitWeightLinearOperator.h | 20 +++--- .../linear/examples/build_and_run_examples.sh | 10 +-- .../examples/separate_function_wrappers.cpp | 36 ++++++++-- .../examples/stateful_class_wrapper.cpp | 32 ++++++++- .../linear/linear_a8wxdq_op/CMakeLists.txt | 45 ++++++++++++ .../linear_a8wxdq_op/linear_a8wxdq-impl.h} | 40 +++++++++-- .../linear_a8wxdq_op/linear_a8wxdq_aten.cpp} | 2 +- .../linear_a8wxdq_executorch}/w2s.cpp | 2 +- .../linear_a8wxdq_executorch}/w2sz.cpp | 2 +- .../linear_a8wxdq_executorch}/w3s.cpp | 2 +- .../linear_a8wxdq_executorch}/w3sz.cpp | 2 +- .../linear_a8wxdq_executorch}/w4s.cpp | 2 +- .../linear_a8wxdq_executorch}/w4sz.cpp | 2 +- .../linear_a8wxdq_executorch}/w5s.cpp | 2 +- .../linear_a8wxdq_executorch}/w5sz.cpp | 2 +- .../ops/linear/tests/CMakeLists.txt | 43 ++++++++++++ .../ops/linear/tests/build_and_run_tests.sh | 14 ++++ .../linear/tests/test_linear_operator.cpp | 51 ++++++++++---- .../experimental/{kernels/cpu => ops}/macro.h | 0 .../{kernels/cpu => ops}/memory.h | 0 .../{kernels/cpu => ops}/parallel-aten-impl.h | 0 .../cpu => ops}/parallel-openmp-impl.h | 0 .../cpu => ops}/parallel-pthreadpool-impl.h | 0 .../parallel-single_threaded-impl.h | 0 .../cpu => ops}/parallel-test_dummy-impl.h | 0 .../{kernels/cpu => ops}/parallel.h | 22 +++--- ...test_int8_dyn_act_intx_weight_quantizer.py | 47 ++++++++++++- 44 files changed, 521 insertions(+), 431 deletions(-) create mode 100644 torchao/experimental/CMakeLists.txt rename torchao/experimental/{kernels/cpu => }/Utils.cmake (100%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh => build_torchao_ops.sh} (51%) delete mode 100644 torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt delete mode 100644 
torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py delete mode 100644 torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt delete mode 100644 torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh create mode 100644 torchao/experimental/ops/linear/CMakeLists.txt create mode 100644 torchao/experimental/ops/linear/benchmarks/CMakeLists.txt rename torchao/experimental/{kernels/cpu => ops}/linear/benchmarks/benchmark_linear_operator.cpp (77%) rename torchao/experimental/{kernels/cpu => ops}/linear/benchmarks/build_and_run_benchmarks.sh (70%) rename torchao/experimental/{kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h => ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp} (83%) rename torchao/experimental/{kernels/cpu => ops}/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h (81%) create mode 100644 torchao/experimental/ops/linear/examples/CMakeLists.txt rename torchao/experimental/{kernels/cpu => ops}/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h (91%) rename torchao/experimental/{kernels/cpu => ops}/linear/examples/build_and_run_examples.sh (67%) rename torchao/experimental/{kernels/cpu => ops}/linear/examples/separate_function_wrappers.cpp (80%) rename torchao/experimental/{kernels/cpu => ops}/linear/examples/stateful_class_wrapper.cpp (71%) create mode 100644 torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h => ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h} (87%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp => ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp} (98%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w2s.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w2sz.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w3s.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w3sz.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w4s.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w4sz.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w5s.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w5sz.cpp (90%) create mode 100644 torchao/experimental/ops/linear/tests/CMakeLists.txt create mode 100644 torchao/experimental/ops/linear/tests/build_and_run_tests.sh rename torchao/experimental/{kernels/cpu => ops}/linear/tests/test_linear_operator.cpp (78%) rename torchao/experimental/{kernels/cpu => ops}/macro.h (100%) rename torchao/experimental/{kernels/cpu => ops}/memory.h (100%) rename torchao/experimental/{kernels/cpu => 
ops}/parallel-aten-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel-openmp-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel-pthreadpool-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel-single_threaded-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel-test_dummy-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel.h (73%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op => tests}/test_int8_dyn_act_intx_weight_quantizer.py (63%) diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt new file mode 100644 index 0000000000..198e9ebd44 --- /dev/null +++ b/torchao/experimental/CMakeLists.txt @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao) + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_CXX_STANDARD 17) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + + +# Source root directory for torchao/experimental +if(NOT TORCHAO_ROOT) + set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) +endif() + +if(NOT TORCHAO_INCLUDE_DIRS) + set(TORCHAO_INCLUDE_DIRS ${TORCHAO_ROOT}/../..) +endif() + +if (NOT TORCHAO_PARALLEL_BACKEND) + if (TORCHAO_OP_TARGET STREQUAL "ATEN") + set(TORCHAO_PARALLEL_BACKEND "ATEN_OPENMP") + elseif(TORCHAO_OP_TARGET STREQUAL "EXECUTORCH") + set(TORCHAO_PARALLEL_BACKEND "PTHREADPOOL") + else() + message(TORCHAO_PARALLEL_BACKEND "TORCHAO_PARALLEL_BACKEND is not set. Please set it directly or set TORCHAO_OP_TARGET to get a default.") + endif() +endif() + +include(CMakePrintHelpers) + +add_compile_options("-Wall" "-Werror") + +include(CMakePrintHelpers) +message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") +include_directories(${TORCHAO_INCLUDE_DIRS}) + +if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + # Defines target torchao_kernels_aarch64 + add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64) + add_subdirectory(${TORCHAO_ROOT}/ops/linear) + add_subdirectory(${TORCHAO_ROOT}/ops/linear/linear_a8wxdq_op) +endif() diff --git a/torchao/experimental/kernels/cpu/Utils.cmake b/torchao/experimental/Utils.cmake similarity index 100% rename from torchao/experimental/kernels/cpu/Utils.cmake rename to torchao/experimental/Utils.cmake diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh b/torchao/experimental/build_torchao_ops.sh similarity index 51% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh rename to torchao/experimental/build_torchao_ops.sh index c657857fcc..de6d8e17d8 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh +++ b/torchao/experimental/build_torchao_ops.sh @@ -5,15 +5,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../.. 
- export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" export CMAKE_OUT=/tmp/cmake-out/torchao -cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \ - -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ - -DPLATFORM="ATEN" \ - -S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \ +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ + -DTORCHAO_OP_TARGET="$1" \ + -DEXECUTORCH_LIBRARIES=${EXECUTORCH_LIBRARIES} \ + -DEXECUTORCH_INCLUDE_DIRS=${EXECUTORCH_INCLUDE_DIRS} \ + -S . \ -B ${CMAKE_OUT} -cmake --build ${CMAKE_OUT} +cmake --build ${CMAKE_OUT} --target install --config Release diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index a13737d874..ec497a1871 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -4,10 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -add_library( - kernel_aarch64 - ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) +if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + add_library( + torchao_kernels_aarch64 + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp + ) +endif() diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt deleted file mode 100644 index 61e5eeae27..0000000000 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -cmake_minimum_required(VERSION 3.19) -project(benchmarks) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -include(FetchContent) -FetchContent_Declare(googlebenchmark - GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG main) # need main for benchmark::benchmark - -set(BENCHMARK_ENABLE_TESTING OFF) -FetchContent_MakeAvailable( - googlebenchmark) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) - -add_executable(benchmark_linear_operator benchmark_linear_operator.cpp) -target_link_libraries( - benchmark_linear_operator - PRIVATE - benchmark::benchmark - dep -) - -option(TORCHAO_PARALLEL_OMP "" OFF) -option(TORCHAO_PARALLEL_SINGLE_THREADED "" ON) - -if (TORCHAO_PARALLEL_OMP) - message("OpenMP_ROOT: ${OpenMP_ROOT}") - add_definitions(-DTORCHAO_PARALLEL_OMP=1) - find_package(OpenMP REQUIRED) - if(OpenMP_CXX_FOUND) - target_link_libraries(benchmark_linear_operator PUBLIC OpenMP::OpenMP_CXX) - endif() -endif() - -if (TORCHAO_PARALLEL_SINGLE_THREADED) - add_definitions(-DTORCHAO_PARALLEL_SINGLE_THREADED=1) -endif() diff --git a/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt deleted file mode 100644 index 4489dc7c36..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -project(examples) - -cmake_minimum_required(VERSION 3.19) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) - -add_executable(separate_function_wrappers separate_function_wrappers.cpp) -target_link_libraries( - separate_function_wrappers - PRIVATE - kernel_aarch64 -) - -add_executable(stateful_class_wrapper stateful_class_wrapper.cpp) -target_link_libraries( - stateful_class_wrapper - PRIVATE - kernel_aarch64 -) - -include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake) - -target_link_torchao_parallel_backend(stateful_class_wrapper "openmp") -target_link_torchao_parallel_backend(separate_function_wrappers "openmp") diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt deleted file mode 100644 index 10e44a79a8..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -project(torch_custom_op) - -cmake_minimum_required(VERSION 3.19) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") -include_directories(${TORCHAO_INCLUDE_DIRS}) - -add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) - -include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake) - -set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH") -string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER) - -if(PLATFORM_TO_UPPER STREQUAL "ATEN") -message(STATUS "Building with PLATFORM=ATEN") - -find_package(Torch REQUIRED) -add_library(lowbit_op_aten SHARED lowbit_op_aten.cpp) -target_link_libraries(lowbit_op_aten PRIVATE kernel_aarch64) -target_include_directories(lowbit_op_aten PRIVATE "${TORCH_INCLUDE_DIRS}") -target_link_libraries(lowbit_op_aten PRIVATE "${TORCH_LIBRARIES}") -target_compile_definitions(lowbit_op_aten PRIVATE USE_ATEN=1) -target_link_torchao_parallel_backend(lowbit_op_aten "ATEN_OPENMP") - -elseif(PLATFORM_TO_UPPER STREQUAL "EXECUTORCH") -message(STATUS "Building with PLATFORM=EXECUTORCH") - -add_library(lowbit_op_executorch SHARED - lowbit_op_executorch/w2s.cpp - lowbit_op_executorch/w2sz.cpp - lowbit_op_executorch/w3s.cpp - lowbit_op_executorch/w3sz.cpp - lowbit_op_executorch/w4s.cpp - lowbit_op_executorch/w4sz.cpp - lowbit_op_executorch/w5s.cpp - lowbit_op_executorch/w5sz.cpp -) -target_include_directories(lowbit_op_executorch PRIVATE ${EXECUTORCH_INCLUDE_DIRS}) -target_compile_definitions(lowbit_op_executorch PRIVATE USE_EXECUTORCH=1) -target_link_torchao_parallel_backend(lowbit_op_executorch "SINGLE_THREADED") -target_link_libraries(lowbit_op_executorch PRIVATE ${EXECUTORCH_LIBRARIES}) -target_link_libraries(lowbit_op_executorch PRIVATE kernel_aarch64) - -else() -message(FATAL_ERROR "Unknown PLATFORM: ${PLATFORM}. Please choose one of: ATEN, EXECUTORCH.") -endif() diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py deleted file mode 100644 index e3d96df63c..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import copy -import glob -import os - -import sys - -import torch - -sys.path.insert( - 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) -) -from quant_api import Int8DynActIntxWeightQuantizer - -libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") -libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) -torch.ops.load_library(libs[0]) - -group_size = 256 -m = 1 -n = 4096 -k = 4096 -nbit = 4 -has_weight_zeros = False -n_layers = 5 - -print("Creating random model") -layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)] -model = torch.nn.Sequential(*layers) -model = model.eval() - -print("Quantizing random model") -quantized_model = copy.deepcopy(model) -quantizer = Int8DynActIntxWeightQuantizer( - device="cpu", - precision=torch.float32, - bitwidth=nbit, - groupsize=group_size, - has_weight_zeros=has_weight_zeros, -) -quantized_model = quantizer.quantize(quantized_model) -quantized_model = quantized_model.eval() - -print("Creating random activations") -activations = torch.randn(m, k, dtype=torch.float32) - -print("Exporting quantized model") -exported = torch.export.export(quantized_model, (activations,)) - -print("Using torch.compile on quantized model") -quantized_model_compiled = torch.compile(quantized_model) -with torch.no_grad(): - quantized_model_compiled(activations) - -print("Compiling quantized model with AOTI") -torch._export.aot_compile( - quantized_model, - (activations,), - options={"aot_inductor.output_path": "/tmp/torch_custom_op_example_model.so"}, -) - -print("Running AOTI") -fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu") -fn(activations) diff --git a/torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt deleted file mode 100644 index 3a415d8edd..0000000000 --- a/torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt +++ /dev/null @@ -1,41 +0,0 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
- -cmake_minimum_required(VERSION 3.19) -project(tests) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Debug) - -include(FetchContent) -FetchContent_Declare( - googletest - URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip -) -FetchContent_MakeAvailable(googletest) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) - -enable_testing() - -add_definitions(-DTORCHAO_PARALLEL_TEST_DUMMY=1) -add_executable(test_linear_operator test_linear_operator.cpp) -target_link_libraries( - test_linear_operator - PRIVATE - GTest::gtest_main - dep -) - -include(GoogleTest) -gtest_discover_tests(test_linear_operator) diff --git a/torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh deleted file mode 100644 index ad9a855084..0000000000 --- a/torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/tests -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -# Run -${CMAKE_OUT}/test_linear_operator diff --git a/torchao/experimental/ops/linear/CMakeLists.txt b/torchao/experimental/ops/linear/CMakeLists.txt new file mode 100644 index 0000000000..2f7b91bbf9 --- /dev/null +++ b/torchao/experimental/ops/linear/CMakeLists.txt @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +include(${TORCHAO_ROOT}/Utils.cmake) + +add_library(torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} STATIC channelwise_8bit_activation_groupwise_lowbit_weight.cpp) +target_link_torchao_parallel_backend(torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} "${TORCHAO_PARALLEL_BACKEND}") diff --git a/torchao/experimental/ops/linear/benchmarks/CMakeLists.txt b/torchao/experimental/ops/linear/benchmarks/CMakeLists.txt new file mode 100644 index 0000000000..70d6bf2cba --- /dev/null +++ b/torchao/experimental/ops/linear/benchmarks/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) +project(benchmarks) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Release) +add_compile_options("-Wall" "-Werror") + +set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) 
+ +include(FetchContent) +FetchContent_Declare(googlebenchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG main) # need main for benchmark::benchmark + +set(BENCHMARK_ENABLE_TESTING OFF) +FetchContent_MakeAvailable( + googlebenchmark) + +include_directories(${TORCHAO_INCLUDE_DIRS}) + +set(TORCHAO_PARALLEL_BACKEND "OPENMP") +add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + +add_executable(benchmark_linear_operator benchmark_linear_operator.cpp) +target_link_libraries( + benchmark_linear_operator + PRIVATE + benchmark::benchmark + torchao_kernels_aarch64 + torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} +) +target_link_torchao_parallel_backend(benchmark_linear_operator "${TORCHAO_PARALLEL_BACKEND}") diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp b/torchao/experimental/ops/linear/benchmarks/benchmark_linear_operator.cpp similarity index 77% rename from torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp rename to torchao/experimental/ops/linear/benchmarks/benchmark_linear_operator.cpp index ad6563eabe..8d7cd4a908 100644 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp +++ b/torchao/experimental/ops/linear/benchmarks/benchmark_linear_operator.cpp @@ -5,11 +5,40 @@ // LICENSE file in the root directory of this source tree. #include +#include #include -#include -#include +#include +#include +#include #include +using namespace torchao::ops::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + template static void channelwise_8bit_activation_groupwise_lowbit_weight( benchmark::State& state) { @@ -24,9 +53,6 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight( int num_test_cases = state.range(5); // Initialize config and tiling params - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; - auto ukernel_config = get_ukernel_config(); auto pack_weight_data_tiling_params = @@ -66,7 +92,7 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight( std::vector> packed_weight_data; for (int i = 0; i < test_cases.size(); i++) { - packed_weight_data.emplace_back(torchao::make_aligned_byte_array_unique_ptr( + packed_weight_data.emplace_back(torchao::make_aligned_byte_ptr( packed_weight_data_alignment, packed_weight_data_size)); pack_weight_data_operator( ukernel_config, @@ -91,7 +117,7 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight( size_t activation_data_buffer_alignment = get_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = 
torchao::make_aligned_byte_array_unique_ptr( + auto activation_data_buffer = torchao::make_aligned_byte_ptr( activation_data_buffer_alignment, activation_data_buffer_size); auto output = std::vector(m * n); diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/ops/linear/benchmarks/build_and_run_benchmarks.sh similarity index 70% rename from torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh rename to torchao/experimental/ops/linear/benchmarks/build_and_run_benchmarks.sh index 18da0e992d..ed80d34e2f 100644 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh +++ b/torchao/experimental/ops/linear/benchmarks/build_and_run_benchmarks.sh @@ -7,11 +7,9 @@ # Call script with sh build_and_run_benchmarks.sh {BENCHAMRK} -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks +export CMAKE_OUT=/tmp/cmake-out/torchao/benchmarks cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/benchmarks \ + -S . \ -B ${CMAKE_OUT} \ -DOpenMP_ROOT=$(brew --prefix libomp) \ -DTORCHAO_PARALLEL_OMP=ON diff --git a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp similarity index 83% rename from torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h rename to torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp index 37ad74b0f0..ae611d3ccc 100644 --- a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h +++ b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp @@ -4,18 +4,18 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
-#pragma once #include -#include -#include +#include +#include +#include #include #include #include -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { -inline PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( +PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( const UKernelConfig& ukernel_config, int n, int target_panels_per_thread) { @@ -40,7 +40,7 @@ inline PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( return tiling_params; } -inline void pack_weight_data_operator( +void pack_weight_data_operator( const UKernelConfig& ukernel_config, const PackWeightDataTilingParams& tiling_params, // Outputs @@ -81,7 +81,7 @@ inline void pack_weight_data_operator( } // This default mimics XNNPACK behavior if target_tiles_per_thread = 5 -inline LinearTilingParams get_default_linear_tiling_params( +LinearTilingParams get_default_linear_tiling_params( const UKernelConfig& ukernel_config, int m, int n, @@ -118,8 +118,7 @@ inline LinearTilingParams get_default_linear_tiling_params( namespace internal { -inline int -get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( +inline int get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( const UKernelConfig& ukernel_config, const LinearTilingParams& tiling_params, int m, @@ -273,7 +272,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( } } // namespace internal -inline void linear_operator( +void linear_operator( const UKernelConfig& ukernel_config, const LinearTilingParams& tiling_params, LinearTileSchedulingPolicy scheduling_policy, @@ -333,7 +332,7 @@ inline void linear_operator( } } -inline int get_activation_data_buffer_size( +int get_activation_data_buffer_size( const UKernelConfig& ukernel_config, const LinearTilingParams& tiling_params, LinearTileSchedulingPolicy scheduling_policy, @@ -355,38 +354,4 @@ inline int get_activation_data_buffer_size( } } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - -// TODO: may move to different fil or namespace. 
This method is not part of the -// high-level interface, but specific to the universal kernels we wrote in -// torchao -#include -namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight { -template - -inline UKernelConfig get_ukernel_config() { - UKernelConfig config; - - namespace ukernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - - return config; -} -} // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - // torchao::kernels::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight diff --git a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h similarity index 81% rename from torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h rename to torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h index 5d8f11b821..c92c94acfb 100644 --- a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -7,8 +7,7 @@ #pragma once #include -// TODO: maybe move to operator directory -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { struct UKernelConfig { @@ -147,20 +146,4 @@ void linear_operator( float clamp_max); } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - -// TODO: may move to different file or namespace -// It is not part of the high-level interface, but specific to the universal -// kernels in torchao. -// Kleidi will need to implement their own get_ukernel_config -// In future, we may build a high-level get_ukernel_config with CPU-runtime -// selection -namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight { -template -UKernelConfig get_ukernel_config(); - -} // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - -#include + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight diff --git a/torchao/experimental/ops/linear/examples/CMakeLists.txt b/torchao/experimental/ops/linear/examples/CMakeLists.txt new file mode 100644 index 0000000000..2b69adb3d8 --- /dev/null +++ b/torchao/experimental/ops/linear/examples/CMakeLists.txt @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +project(examples) + +cmake_minimum_required(VERSION 3.19) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Release) + +include(CMakePrintHelpers) + +set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) + +include_directories(${TORCHAO_INCLUDE_DIRS}) + +set(TORCHAO_PARALLEL_BACKEND "OPENMP") +add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + +include(${TORCHAO_ROOT}/Utils.cmake) + +add_executable(separate_function_wrappers separate_function_wrappers.cpp) +target_link_libraries( + separate_function_wrappers + PRIVATE + torchao_kernels_aarch64 + torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} +) +target_link_torchao_parallel_backend(separate_function_wrappers "${TORCHAO_PARALLEL_BACKEND}") + +add_executable(stateful_class_wrapper stateful_class_wrapper.cpp) +target_link_libraries( + stateful_class_wrapper + PRIVATE + torchao_kernels_aarch64 + torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} +) +target_link_torchao_parallel_backend(stateful_class_wrapper "${TORCHAO_PARALLEL_BACKEND}") diff --git a/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h b/torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h similarity index 91% rename from torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h rename to torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h index 575093f21b..a7755dadf4 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h +++ b/torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h @@ -5,26 +5,22 @@ // LICENSE file in the root directory of this source tree. 
#pragma once -#include -#include -#include +#include +#include +#include #include #include -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator { private: - torchao::aligned_byte_ptr packed_weight_data_{ - nullptr, - nullptr}; + torchao::aligned_byte_ptr packed_weight_data_{nullptr, nullptr}; int packed_weight_data_size_{0}; int packed_weight_data_alignment_{0}; - torchao::aligned_byte_ptr activation_data_buffer_{ - nullptr, - nullptr}; + torchao::aligned_byte_ptr activation_data_buffer_{nullptr, nullptr}; int m_{0}; int n_{0}; @@ -114,7 +110,7 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator { get_packed_weight_data_size(ukernel_config_, n_, k_, group_size_); auto packed_weight_data_alignment = get_packed_weight_data_alignment(ukernel_config_); - + packed_weight_data_size_ = packed_weight_data_size; packed_weight_data_alignment_ = packed_weight_data_alignment; packed_weight_data_ = torchao::make_aligned_byte_ptr( @@ -199,4 +195,4 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator { } }; } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight diff --git a/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh b/torchao/experimental/ops/linear/examples/build_and_run_examples.sh similarity index 67% rename from torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh rename to torchao/experimental/ops/linear/examples/build_and_run_examples.sh index 9c244e54cc..01185fdd3f 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh +++ b/torchao/experimental/ops/linear/examples/build_and_run_examples.sh @@ -5,15 +5,11 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. - export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" -export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples \ +export CMAKE_OUT=/tmp/cmake-out/torchao/examples +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -S . \ -B ${CMAKE_OUT} \ -DOpenMP_ROOT=$(brew --prefix libomp) cmake --build ${CMAKE_OUT} diff --git a/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp b/torchao/experimental/ops/linear/examples/separate_function_wrappers.cpp similarity index 80% rename from torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp rename to torchao/experimental/ops/linear/examples/separate_function_wrappers.cpp index ba3e5b29b3..144fe5c08d 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp +++ b/torchao/experimental/ops/linear/examples/separate_function_wrappers.cpp @@ -4,9 +4,11 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
+#include #include -#include -#include +#include +#include +#include #include // This file contains an example of wrapping the torchao weight packing and // linear operators into two operators: one for weight packing and another @@ -20,9 +22,33 @@ // one stateful class, but not all surfaces support this (see // examples/stateful_class_wrapper.cpp for an example of this). -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + torchao::aligned_byte_ptr pack_weight_data_operator( UKernelConfig ukernel_config, int n, @@ -115,10 +141,10 @@ void linear_operator( } } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight int main() { - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; torchao::set_num_threads(8); diff --git a/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp b/torchao/experimental/ops/linear/examples/stateful_class_wrapper.cpp similarity index 71% rename from torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp rename to torchao/experimental/ops/linear/examples/stateful_class_wrapper.cpp index 5fb24c683d..c1cd2d110b 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp +++ b/torchao/experimental/ops/linear/examples/stateful_class_wrapper.cpp @@ -4,9 +4,10 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. +#include #include -#include -#include +#include +#include #include #include @@ -21,9 +22,33 @@ // examples/separate_function_wrappers.cpp for an example of how to split the // operations into two steps. 
-using namespace torchao::operators::cpu::linear:: +using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + int main() { int m = 13; int n = 4096 + 1; @@ -54,6 +79,7 @@ int main() { std::cout << "Initializing linear_operator." << std::endl; auto ukernel_config = get_ukernel_config(); + auto linear_operator = Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator( ukernel_config, diff --git a/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt b/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt new file mode 100644 index 0000000000..f69d884cd8 --- /dev/null +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +include(${TORCHAO_ROOT}/Utils.cmake) + +if(TORCHAO_OP_TARGET STREQUAL "ATEN") + message(STATUS "Building with TORCHAO_OP_TARGET=ATEN") + find_package(Torch REQUIRED) + add_library(linear_a8wxdq_${TORCHAO_OP_TARGET} SHARED linear_a8wxdq_aten.cpp) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_kernels_aarch64) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) + target_include_directories(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE "${TORCH_INCLUDE_DIRS}") + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE "${TORCH_LIBRARIES}") + target_compile_definitions(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE USE_ATEN=1) +elseif(TORCHAO_OP_TARGET STREQUAL "EXECUTORCH") + message(STATUS "Building with TORCHAO_OP_TARGET=EXECUTORCH") + add_library(linear_a8wxdq_${TORCHAO_OP_TARGET} SHARED + linear_a8wxdq_executorch/w2s.cpp + linear_a8wxdq_executorch/w2sz.cpp + linear_a8wxdq_executorch/w3s.cpp + linear_a8wxdq_executorch/w3sz.cpp + linear_a8wxdq_executorch/w4s.cpp + linear_a8wxdq_executorch/w4sz.cpp + linear_a8wxdq_executorch/w5s.cpp + linear_a8wxdq_executorch/w5sz.cpp + ) + target_include_directories(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE ${EXECUTORCH_INCLUDE_DIRS}) + target_compile_definitions(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE USE_EXECUTORCH=1) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE ${EXECUTORCH_LIBRARIES}) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_kernels_aarch64) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +else() + message(FATAL_ERROR "Unknown TORCHAO_OP_TARGET: ${TORCHAO_OP_TARGET}. 
Please choose one of: ATEN, EXECUTORCH.") +endif() + + +install( + TARGETS linear_a8wxdq_${TORCHAO_OP_TARGET} + DESTINATION lib +) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h similarity index 87% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h index 01b1836981..eee51eafc6 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h @@ -5,7 +5,8 @@ // LICENSE file in the root directory of this source tree. #pragma once -#include +#include +#include #include #include @@ -28,6 +29,35 @@ using RuntimeContext = torch::executor::KernelRuntimeContext; #error "Must define either USE_ATEN or USE_EXECUTORCH" #endif +namespace { + +template +inline torchao::ops::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight::UKernelConfig + get_ukernel_config() { + torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight:: + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + #ifdef USE_ATEN template Tensor pack_weights_cpu( @@ -69,7 +99,7 @@ Tensor pack_weights_cpu( weight_zeros_ptr = weight_zeros.value().const_data_ptr(); } - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< @@ -137,7 +167,7 @@ Tensor pack_weights_meta( int n = weight_qvals.size(0); int k = weight_qvals.size(1); - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< @@ -221,7 +251,7 @@ Tensor linear_out_cpu( CHECK_MSG(out.size(1) == n, "out shape is incorrect"); #endif // USE_EXECUTORCH - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< @@ -311,3 +341,5 @@ Tensor linear_meta( return torch::empty({m, n}).to("meta"); } #endif // USE_ATEN + +} // namespace diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp similarity index 98% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp index 626b3e769f..b1d464e5b5 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp @@ -4,7 +4,7 @@ // This source code is licensed under 
the license found in the // LICENSE file in the root directory of this source tree. -#include +#include #define DEFINE_OP(weight_nbit) \ m.def( \ diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2s.cpp index 592a0190a9..c6ef089995 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2sz.cpp index d2683b36ce..e569e05812 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3s.cpp index d59db3e1c7..9f236bd7b3 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3sz.cpp index 7458311b91..24a381fdcc 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4s.cpp 
b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4s.cpp index 75143050fa..67263d209d 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4sz.cpp index 714192a19b..530ff44370 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5s.cpp index 08c2d42ee8..de04a09f6a 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5sz.cpp index c1e3e953d3..91c5a16312 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/ops/linear/tests/CMakeLists.txt b/torchao/experimental/ops/linear/tests/CMakeLists.txt new file mode 100644 index 0000000000..866d832ccd --- /dev/null +++ b/torchao/experimental/ops/linear/tests/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) +project(tests) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Debug) +add_compile_options("-Wall" "-Werror") + +set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) + +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip +) +FetchContent_MakeAvailable(googletest) +enable_testing() + +include_directories(${TORCHAO_INCLUDE_DIRS}) + +set(TORCHAO_PARALLEL_BACKEND "TEST_DUMMY") +add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + +include(${TORCHAO_ROOT}/Utils.cmake) +add_executable(test_linear_operator test_linear_operator.cpp) +target_link_libraries( + test_linear_operator + PRIVATE + GTest::gtest_main + torchao_kernels_aarch64 + torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} +) +target_link_torchao_parallel_backend(test_linear_operator "${TORCHAO_PARALLEL_BACKEND}") + +include(GoogleTest) +gtest_discover_tests(test_linear_operator) diff --git a/torchao/experimental/ops/linear/tests/build_and_run_tests.sh b/torchao/experimental/ops/linear/tests/build_and_run_tests.sh new file mode 100644 index 0000000000..3fbe78c172 --- /dev/null +++ b/torchao/experimental/ops/linear/tests/build_and_run_tests.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +export CMAKE_OUT=/tmp/cmake-out/torchao/tests +cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S . -B ${CMAKE_OUT} + +cmake --build ${CMAKE_OUT} + +# Run +${CMAKE_OUT}/test_linear_operator diff --git a/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp b/torchao/experimental/ops/linear/tests/test_linear_operator.cpp similarity index 78% rename from torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp rename to torchao/experimental/ops/linear/tests/test_linear_operator.cpp index 5408e426bf..6d563111cc 100644 --- a/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp +++ b/torchao/experimental/ops/linear/tests/test_linear_operator.cpp @@ -1,22 +1,52 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
#include // TODO: move test_utils.h out of aarch64 +#include #include -#include -#include -#include +#include +#include +#include const float kTol = 1.0e-5; +using namespace torchao::ops::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + template void test_channelwise_8bit_activation_groupwise_lowbit_weight( int m, int n, int k, int group_size) { - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config(); @@ -47,7 +77,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight( get_packed_weight_data_size(ukernel_config, n, k, group_size); auto packed_weight_data_alignment = get_packed_weight_data_alignment(ukernel_config); - auto packed_weight_data = torchao::make_aligned_byte_array_unique_ptr( + auto packed_weight_data = torchao::make_aligned_byte_ptr( packed_weight_data_alignment, packed_weight_data_size); pack_weight_data_operator( @@ -74,7 +104,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight( group_size); auto activation_data_buffer_alignment = get_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = torchao::make_aligned_byte_array_unique_ptr( + auto activation_data_buffer = torchao::make_aligned_byte_ptr( activation_data_buffer_alignment, activation_data_buffer_size); // Run linear @@ -153,9 +183,6 @@ TEST( int n = 1; int k = 16 + 1; int group_size = 16; - - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< 3 /*weight_nbit*/, true /*has_weight_zeros*/, @@ -187,8 +214,6 @@ TEST( int k = 20; int group_size = 10; - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< 3 /*weight_nbit*/, true /*has_weight_zeros*/, diff --git a/torchao/experimental/kernels/cpu/macro.h b/torchao/experimental/ops/macro.h similarity index 100% rename from torchao/experimental/kernels/cpu/macro.h rename to torchao/experimental/ops/macro.h diff --git a/torchao/experimental/kernels/cpu/memory.h b/torchao/experimental/ops/memory.h similarity index 100% rename from torchao/experimental/kernels/cpu/memory.h rename to torchao/experimental/ops/memory.h diff --git a/torchao/experimental/kernels/cpu/parallel-aten-impl.h b/torchao/experimental/ops/parallel-aten-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-aten-impl.h rename to torchao/experimental/ops/parallel-aten-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-openmp-impl.h b/torchao/experimental/ops/parallel-openmp-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-openmp-impl.h rename to torchao/experimental/ops/parallel-openmp-impl.h diff 
--git a/torchao/experimental/kernels/cpu/parallel-pthreadpool-impl.h b/torchao/experimental/ops/parallel-pthreadpool-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-pthreadpool-impl.h rename to torchao/experimental/ops/parallel-pthreadpool-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-single_threaded-impl.h b/torchao/experimental/ops/parallel-single_threaded-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-single_threaded-impl.h rename to torchao/experimental/ops/parallel-single_threaded-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-test_dummy-impl.h b/torchao/experimental/ops/parallel-test_dummy-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-test_dummy-impl.h rename to torchao/experimental/ops/parallel-test_dummy-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel.h b/torchao/experimental/ops/parallel.h similarity index 73% rename from torchao/experimental/kernels/cpu/parallel.h rename to torchao/experimental/ops/parallel.h index 0d12c3acf9..e3949b8551 100644 --- a/torchao/experimental/kernels/cpu/parallel.h +++ b/torchao/experimental/ops/parallel.h @@ -10,7 +10,7 @@ namespace torchao { // F has signature [&](int64_t idx) template -void parallel_1d(const int64_t begin, const int64_t end, const F& f); +void parallel_1d(const int64_t begin, const int64_t end, const F& f); void set_num_threads(int num_threads); @@ -18,16 +18,17 @@ int get_num_threads(); } // namespace torchao - #ifdef TORCHAO_PARALLEL_ATEN #pragma message("TORCHAO_PARALLEL_ATEN is set. Using ATen parallel backend.") #ifndef INTRA_OP_PARALLEL - #pragma message("INTRA_OP_PARALLEL is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") +#pragma message( \ + "INTRA_OP_PARALLEL is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") #endif #ifndef AT_PARALLEL_OPENMP - #pragma message("AT_PARALLEL_OPENMP is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") +#pragma message( \ + "AT_PARALLEL_OPENMP is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") #endif -#include +#include #else #ifdef TORCHAO_PARALLEL_EXECUTORCH @@ -40,24 +41,25 @@ int get_num_threads(); #ifdef TORCHAO_PARALLEL_PTHREADPOOL #pragma message( \ "TORCHAO_PARALLEL_PTHREADPOOL is set. Using pthreadpool parallel backend.") -#include +#include #else #ifdef TORCHAO_PARALLEL_OPENMP -#pragma message("TORCHAO_PARALLEL_OPENMP is set. Using OPENMP parallel backend.") -#include +#pragma message( \ + "TORCHAO_PARALLEL_OPENMP is set. Using OPENMP parallel backend.") +#include #else #if defined TORCHAO_PARALLEL_SINGLE_THREADED #pragma message( \ "TORCHAO_PARALLEL_SINGLE_THREADED is set. Using single-threaded parallel backend.") -#include +#include #else #if defined TORCHAO_PARALLEL_TEST_DUMMY #pragma message( \ "TORCHAO_PARALLEL_TEST_DUMMY is set. 
Using test dummy parallel backend.") -#include +#include #else #error \ diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py b/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py similarity index 63% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py rename to torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py index 513088d2f0..d431d26939 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py +++ b/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py @@ -11,18 +11,19 @@ import sys import unittest +import tempfile import torch sys.path.insert( - 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) ) from quant_api import ( _Int8DynActIntxWeightQuantizedLinearFallback, Int8DynActIntxWeightQuantizer, ) -libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") +libs = glob.glob("/tmp/cmake-out/torchao/lib/liblinear_a8wxdq_ATEN.*") libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) if len(libs) == 0: print( @@ -73,7 +74,49 @@ def test_accuracy(self): # Assert at most 5% of entries are not close at a low tolerance self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) + + def test_export_compile_aoti(self): + group_size = 32 + m = 1 + n = 256 + k = 256 + nbit = 4 + has_weight_zeros = False + n_layers = 3 + layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)] + model = torch.nn.Sequential(*layers) + + activations = torch.randn(m, k, dtype=torch.float32) + + print("Quantizing model") + quantizer = Int8DynActIntxWeightQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=nbit, + groupsize=group_size, + has_weight_zeros=has_weight_zeros, + ) + quantized_model = quantizer.quantize(model) + + print("Exporting quantized model") + exported = torch.export.export(quantized_model, (activations,)) + + print("Compiling quantized model") + quantized_model_compiled = torch.compile(quantized_model) + with torch.no_grad(): + quantized_model_compiled(activations) + + with tempfile.TemporaryDirectory() as tmpdirname: + print("Exporting quantized model with AOTI") + torch._export.aot_compile( + quantized_model, + (activations,), + options={"aot_inductor.output_path": f"{tmpdirname}/model.so"}, + ) + print("Running quantized model in AOTI") + fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu") + fn(activations) if __name__ == "__main__": unittest.main() From efd9bb94fdbee8dba55e818f3d51949ea5f81f11 Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Wed, 25 Sep 2024 19:59:19 -0700 Subject: [PATCH 30/41] [float8] all-reduce amax on dp mesh instead of global pg (#933) * [float8] all-reduce amax on dp mesh instead of global pg Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * liner Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * improve comments Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * move hp tensor inside if Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * linter Summary: Test Plan: Reviewers: Subscribers: 
Tasks: Tags: * linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_fsdp2/test_fsdp2.py | 32 +++++++++++++++++++++++++- torchao/float8/float8_scaling_utils.py | 3 ++- torchao/float8/float8_utils.py | 14 +++++++---- torchao/float8/fsdp_utils.py | 1 + 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index ecde051e36..1ad5586513 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -17,10 +17,12 @@ import torch.nn as nn from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy -from torch.distributed._tensor import DTensor +from torch.distributed._tensor import DTensor, init_device_mesh +from torchao.float8.float8_tensor import GemmInputRole from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( @@ -293,6 +295,34 @@ def _get_curr_active_memory_mb(self) -> int: return round(mem_stats["active_bytes.all.current"] / 1e6) +class Test2DParallelMultiThread(FSDPTestMultiThread, TestFloat8Common): + @property + def world_size(self) -> int: + return 4 + + def test_amax_allreduce_device_mesh(self): + dp_size = 2 + pp_size = self.world_size // dp_size + global_mesh = init_device_mesh("cuda", (pp_size, dp_size), mesh_dim_names=("pp", "dp")) + dp_mesh = global_mesh["dp"] + pp_mesh = global_mesh["pp"] + + if self.rank in [0, 1]: + # rank 0 and 1 are the 1st stage in the pipeline + # rank 2 and 4 are doing nothing but waiting for the 1st stage + torch.manual_seed(42 + self.rank) + hp_tensor = torch.randn(768, 32, device="cuda") + float8_tensor = hp_tensor_to_float8_dynamic( + hp_tensor, + torch.float8_e4m3fn, + Float8LinearConfig( + cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC), + ), + gemm_input_role=GemmInputRole.WEIGHT, + reduce_amax=True, + device_mesh=dp_mesh + ) + class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common): @property def world_size(self) -> int: diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index d2ae896320..e9e1951763 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + device_mesh = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -52,7 +53,7 @@ def hp_tensor_to_float8_dynamic( """ if tensor_already_casted_to_fp8(hp_tensor): return hp_tensor - scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) + scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax, device_mesh) return hp_tensor_and_scale_to_float8( hp_tensor, scale, diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 535c870890..d8ad315f16 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -100,23 +100,29 @@ def amax_history_to_scale_stack( @torch.no_grad() 
-def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: +def tensor_to_amax( + x: torch.Tensor, reduce_amax: bool = False, device_mesh=None +) -> torch.Tensor: amax = torch.max(torch.abs(x)) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will # happen elsewhere. if reduce_amax and dist.is_initialized(): - dist.all_reduce(amax, op=dist.ReduceOp.MAX) + pg = device_mesh.get_group() if device_mesh is not None else None + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg) return amax @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False + x: torch.Tensor, + float8_dtype: torch.dtype, + reduce_amax: bool = False, + device_mesh=None, ) -> torch.Tensor: - amax = tensor_to_amax(x, reduce_amax=reduce_amax) + amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 19386d932b..d3c0f73c6c 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -216,6 +216,7 @@ def fsdp_pre_all_gather(self, mesh): self._linear_mm_config, reduce_amax=True, gemm_input_role=GemmInputRole.WEIGHT, + device_mesh=mesh, ) return (float8_tensor._data,), (float8_tensor._scale,) From 85126cc8f65656705b83755626133603b7d961ce Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 25 Sep 2024 21:16:06 -0700 Subject: [PATCH 31/41] int8 dynamic quant + bsr support (#821) This PR, adds in int8 dynamicquant + bsr support. Changes: * Use i8i8 -> bf16 matmul to maintain accuracy * Added a block sparse layout type to AffineQuantizedTensor + check/impl. * Cleaned up benchmark.py script and add a single line `benchmark.sh` file for acceleration numbers * Updated eval.py and added a single line `evaluate.sh` file for accuracy numbers * Lots of lint formatting and README updates * torch.compile now working and is correct --- test/sparsity/test_sparse_api.py | 139 +++- torchao/dtypes/affine_quantized_tensor.py | 183 +++++- .../sparsity/prototype/superblock/README.md | 162 +---- .../prototype/superblock/benchmark.py | 167 +++-- .../prototype/superblock/benchmark.sh | 39 ++ .../superblock/benchmark_results.txt | 30 + .../prototype/superblock/blocksparse.py | 132 +++- .../sparsity/prototype/superblock/evaluate.py | 131 ++-- .../sparsity/prototype/superblock/evaluate.sh | 23 + .../superblock/evaluation_results.txt | 19 + .../sparsity/prototype/superblock/presets.py | 73 --- .../sparsity/prototype/superblock/sampler.py | 64 -- .../sparsity/prototype/superblock/train.py | 382 +++++------ .../prototype/superblock/transforms.py | 185 ------ .../sparsity/prototype/superblock/utils.py | 619 ++++++++++++++++-- torchao/sparsity/sparse_api.py | 16 +- 16 files changed, 1442 insertions(+), 922 deletions(-) create mode 100644 torchao/sparsity/prototype/superblock/benchmark.sh create mode 100644 torchao/sparsity/prototype/superblock/benchmark_results.txt create mode 100644 torchao/sparsity/prototype/superblock/evaluate.sh create mode 100644 torchao/sparsity/prototype/superblock/evaluation_results.txt delete mode 100644 torchao/sparsity/prototype/superblock/presets.py delete mode 100644 torchao/sparsity/prototype/superblock/sampler.py delete mode 100644 torchao/sparsity/prototype/superblock/transforms.py diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index a54d902f1f..9a23a49520 100644 --- a/test/sparsity/test_sparse_api.py 
+++ b/test/sparsity/test_sparse_api.py @@ -4,27 +4,24 @@ import torch from torch import nn - -from torchao.sparsity import ( - apply_fake_sparsity, - sparsify_, - semi_sparse_weight, -) +from torch.testing._internal import common_utils from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType from torchao.quantization.quant_api import ( + int4_weight_only, int8_dynamic_activation_int8_weight, quantize_, - int4_weight_only, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 -from torch.testing._internal.common_utils import TestCase + +from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ +from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4 logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) -class TestSemiStructuredSparse(TestCase): + +class TestSemiStructuredSparse(common_utils.TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @@ -37,6 +34,7 @@ def test_sparse(self): ) .half() .cuda() + .eval() ) apply_fake_sparsity(model) @@ -45,13 +43,17 @@ def test_sparse(self): sparsify_(model, semi_sparse_weight()) sparse_result = model(input) - assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) -class TestQuantSemiSparse(TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature") +class TestQuantSemiSparse(common_utils.TestCase): + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_quant_semi_sparse(self): + @common_utils.parametrize("compile", [True, False]) + def test_quant_semi_sparse(self, compile): + torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False + input = torch.rand((128, 128)).half().cuda() model = ( nn.Sequential( @@ -60,19 +62,27 @@ def test_quant_semi_sparse(self): ) .half() .cuda() + .eval() ) apply_fake_sparsity(model) model_copy = copy.deepcopy(model) quantize_(model_copy, int8_dynamic_activation_int8_weight()) dense_result = model_copy(input) - quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())) + quantize_( + model, + int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), + ) + if compile: + model = torch.compile(model) sparse_result = model(input) - assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_sparse_marlin(self): + @common_utils.parametrize("compile", [True, False]) + def test_sparse_marlin(self, compile): input = torch.rand((256, 256)).half().cuda() model = ( nn.Sequential( @@ -81,6 +91,7 @@ def test_sparse_marlin(self): ) .half() .cuda() + .eval() ) apply_fake_sparsity(model) @@ -92,9 +103,101 @@ def test_sparse_marlin(self): # Sparse + quantized quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) + if compile: + model = torch.compile(model) sparse_result = model(input) - assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" + torch.testing.assert_close(dense_result, sparse_result, 
atol=3e-1, rtol=3e-1) + + +class TestBlockSparseWeight(common_utils.TestCase): + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @common_utils.parametrize("compile", [True, False]) + def test_sparse(self, compile): + input = torch.rand((1024, 1024)).half().cuda() + model = ( + nn.Sequential( + nn.Linear(1024, 2048), + nn.Linear(2048, 1024), + ) + .half() + .cuda() + .eval() + ) + + from torchao.sparsity.utils import create_block_sparse_tensor + + M, N = model[0].weight.shape + model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16) + M, N = model[1].weight.shape + model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16) + dense_result = model(input) + + from torchao.sparsity.prototype.superblock.blocksparse import ( + block_sparse_weight, + ) + + sparsify_(model, block_sparse_weight(blocksize=64)) + # if compile: + # model = torch.compile(model) + sparse_result = model(input) + + torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + + +class TestQuantBlockSparseWeight(common_utils.TestCase): + @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "pytorch 2.6+ feature") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @common_utils.parametrize("compile", [True, False]) + def test_sparse(self, compile): + input = torch.rand((256, 128)).to(torch.bfloat16).cuda() + model = ( + nn.Sequential( + nn.Linear(128, 256), + nn.Linear(256, 128), + ) + .to(torch.bfloat16) + .cuda() + .eval() + ) + from torchao.sparsity.prototype.superblock.blocksparse import ( + blocksparse_int_addmm, + ) + from torchao.sparsity.utils import create_block_sparse_tensor + + M, N = model[0].weight.shape + model[0].weight.data = ( + create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16) + * torch.rand(M, N, dtype=torch.bfloat16).cuda() + ) + M, N = model[1].weight.shape + model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16) + + model_copy = copy.deepcopy(model) + + quantize_(model_copy, int8_dynamic_activation_int8_weight()) + reference = model_copy(input) + + from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType + + quantize_( + model, + int8_dynamic_activation_int8_weight( + layout_type=BlockSparseLayoutType(blocksize=64) + ), + ) + if compile: + model = torch.compile(model) + sparse_result = model(input) + + torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1) + + +common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse) +common_utils.instantiate_parametrized_tests(TestQuantSemiSparse) +common_utils.instantiate_parametrized_tests(TestBlockSparseWeight) +common_utils.instantiate_parametrized_tests(TestQuantBlockSparseWeight) if __name__ == "__main__": unittest.main() diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e00576263f..43ee82ffaa 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -47,7 +47,6 @@ from torchao.float8.inference import Float8MMConfig aten = torch.ops.aten - ############################### # Base Layout Tensor Subclass # ############################### @@ -473,6 +472,11 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: return temp +@dataclass(frozen=True) +class BlockSparseLayoutType(LayoutType): + blocksize: int = 64 + + @dataclass(frozen=True) class 
TensorCoreTiledLayoutType(LayoutType): inner_k_tiles: int = 8 @@ -669,6 +673,145 @@ def from_plain( int_data_compressed = torch._cslt_compress(int_data) return cls(int_data_compressed, scale, zero_point, layout_type) +@register_layout_cls(BlockSparseLayoutType) +class BlockSparseAQTLayout(PlainAQTLayout): + bsr_crow_indices: Optional[torch.Tensor] + bsr_col_indices: Optional[torch.Tensor] + bsr_values: Optional[torch.Tensor] + scale: Optional[torch.Tensor] + zero_point: Optional[torch.Tensor] + + __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + layout_type: LayoutType, + requires_grad: bool = False, + ): + if bsr_values is None: + raise ValueError("bsr values must be provided!") + else: + previous_tensor = bsr_values + + kwargs = { + "device": previous_tensor.device, + "dtype": previous_tensor.dtype, + "layout": previous_tensor.layout, + "requires_grad": requires_grad, + } + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( # noqa: PYI034 + self, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + layout_type: LayoutType, + requires_grad: bool = False, + ): + self.bsr_crow_indices = bsr_crow_indices + self.bsr_col_indices = bsr_col_indices + self.bsr_values = bsr_values + self.scale = scale + self.zero_point = zero_point + self.layout_type = layout_type + + def __tensor_flatten__(self): + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = (self.shape, self.layout_type, self.requires_grad) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta: Tuple[torch.Size, bool], + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, layout_type, requires_grad = tensor_meta + return cls( + shape=shape, + bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), + bsr_col_indices=inner_tensors.get("bsr_col_indices", None), + bsr_values=inner_tensors.get("bsr_values", None), + scale=inner_tensors.get("scale", None), + zero_point=inner_tensors.get("zero_point", None), + layout_type=layout_type, + requires_grad=requires_grad, + ) + + @classmethod + def from_plain(cls, int_data, scale, zero_point, layout_type): + bsr_tensor = int_data.to_sparse_bsr(layout_type.blocksize) + return cls( + shape=int_data.shape, + bsr_crow_indices=bsr_tensor.crow_indices(), + bsr_col_indices=bsr_tensor.col_indices(), + bsr_values=bsr_tensor.values(), + scale=scale, + zero_point=zero_point, + layout_type = layout_type, + requires_grad=False, + ) + + def get_plain(self): + int_data_expanded = torch.ops.blocksparse.bsr_to_dense(self.crow_indices(), self.col_indices(), self.values(), self.shape[0], self.shape[1]) + return int_data_expanded, self.scale, self.zero_point + + def _apply_fn_to_data(self, func): + return self.__class__( + shape = self.shape, + bsr_crow_indices=func(self.bsr_crow_indices), + bsr_col_indices=func(self.bsr_col_indices), + bsr_values=func(self.bsr_values), + scale=self.scale, + zero_point=self.zero_point, + layout_type=self.layout_type, + 
requires_grad=self.requires_grad, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + # Need the following for bsr specific functions + if func is aten.crow_indices.default: + return args[0].bsr_crow_indices.detach() + + if func is aten.col_indices.default: + return args[0].bsr_col_indices.detach() + + if func is aten.values.default: + return args[0].bsr_values.detach() + + if func is aten._nnz.default: + return args[0].bsr_values.shape[0] + + raise NotImplementedError( + f"BlockSparseAQTLayout dispatch: attempting to run {func}, this is not supported" + ) @register_layout_cls(MarlinSparseLayoutType) class MarlinSparseAQTLayout(AQTLayout): @@ -1221,6 +1364,43 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh y += bias return y +def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) and + _aqt_is_int8_reduced_range(input_tensor) and + isinstance(weight_tensor, AffineQuantizedTensor) and + weight_tensor.is_cuda and + input_tensor.dtype == weight_tensor.dtype and + isinstance(input_tensor.layout_type, PlainLayoutType) and + isinstance(weight_tensor.layout_type, BlockSparseLayoutType) + ) + + +def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): + x_vals_int8 = input_tensor.layout_tensor.int_data + x_scales = input_tensor.layout_tensor.scale + w_vals = weight_tensor.layout_tensor + w_scales = weight_tensor.layout_tensor.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + tmp_t = tmp.t() + + y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1)) + y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) + y = y.reshape(*y_shape) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y + + def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): return ( # input is native bfloat16 tensor @@ -1473,6 +1653,7 @@ def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), + (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl), (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), diff --git a/torchao/sparsity/prototype/superblock/README.md b/torchao/sparsity/prototype/superblock/README.md index 54a6964b17..6fea1a0e3a 100644 --- a/torchao/sparsity/prototype/superblock/README.md +++ b/torchao/sparsity/prototype/superblock/README.md @@ -36,76 +36,33 @@ At least one GPU: conda create -n superblock conda activate superblock ``` -* Install PyTorch. For best performance, we recommend `2.3.0.dev20240305+cu121` nightly +* Install PyTorch. 
For best performance, we recommend the PyTorch nightlies. ``` - pip install --pre torch==2.3.0.dev20240305+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121 - pip install --pre torchvision==0.18.0 --no-deps + pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 ``` + We ran our experiments with torch==2.6.0.dev20240924+cu121 -## Benchmarking -Baseline: -``` -python benchmark.py \ - --model vit_h_14 \ - --batch-size 256 \ -``` -Result: -``` -532.1160546875 ms -``` +# Results +### Benchmarking +For all our benchmarking results, you can run `benchmark.sh`. +These benchmarks were run on an NVIDIA-A100-80GB, with cuSPARSELt v0.5.2. -80% sparsity, block size 64 (random weights): -``` -python benchmark.py \ - --model vit_h_14 \ - --batch-size 256 \ - --sparsity-linear 0.8 \ - --sp-linear-tile-size 64 \ - --bsr 64 \ - --sparsity bsr -``` -Result: -``` -393.864453125 ms -``` -Semi-structured sparsity +### Evaluation + +To reproduce our accuracy results, you can run `evaluate.sh`. +You will need to set the following environment variables first to run the script: + ``` -python benchmark.py \ - --model vit_h_14 \ - --batch-size 256 \ - --sparsity semi_structured +IMAGENET_PATH= +NGPUS=1 # put number of available GPUS here ``` - ## Training Please refer to [TRAINING.md](TRAINING.md) for training from scratch. We use [Torchvision](https://github.com/pytorch/vision/tree/main/references/classification) as our framework for training. Supermask can be applied during training. -To apply supermask, we have the following arguments at our disposal, - -* Apply Supermask to linear layers: - ``` - --sparsity-linear - --sp-linear-tile-size - ``` -* Apply Supermask to conv1x1 layers: - ``` - --sparsity-conv1x1 - --sp-conv1x1-tile-size - ``` -* Apply Supermask to all other convolutional layers: - ``` - --sparsity-conv - --sp-conv-tile-size - ``` -* Skip the first transformer layer and/or last linear layer (ViT only): - ``` - --skip-last-layer-sparsity - --skip-first-transformer-sparsity - ``` - For example, if you would like to train a `vit_b_16` from scratch using Supermask, you can use the respective torchvision command found in [TRAINING.md](TRAINING.md) and append the supermask arguments: ``` torchrun --nproc_per_node=8 train.py\ @@ -119,59 +76,6 @@ Through this command, we are training a `vit_b_16` with 90% sparsity to linear l Please run `python train.py --help` for a full list of available arguments. -## Evaluation - -To run an evaluation of a Supermask-trained model, you can use [evaluate.py](evaluate.py). Our current version has signficant speedup with float32 only and not float16, hence, to illustrate speedup, we don't pass `--amp` in the example commands below. - -``` -MODEL_PATH= -IMAGENET_PATH= -NGPUS=1 # put number of available GPUS here -``` - -* Offline sparsification with BSR: - ``` - python evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear 0.9 --sp-linear-tile-size 32 --weights-path ${MODEL_PATH} --data-path ${IMAGENET_PATH} --sparsity bsr --bsr 64 - ``` - This command applies 90% sparsity to linear layers using 32x32 tiles, loads the model weights from ${MODEL_PATH}, loads the ImageNet validation set located at the specified path, applies offline sparsification to the weights, and converts the sparse weights to BSR format with a block size of 32. It is recommended to set `--bsr` the same as tile size.
- -* Online sparsification without BSR: - ``` - torchrun --nproc_per_node=${NGPUS} evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear 0.9 --sp-linear-tile-size 32 --weights-path ${MODEL_PATH} --data-path ${IMAGENET_PATH} - ``` - This is similar to the previous command, but it does not apply offline sparsification or BSR conversion. Instead, the sparsity is applied on-the-fly during evaluation. - -* Semi-structured sparsity - ``` - python evaluate.py --model vit_b_16 --batch-size 256 --data-path $IMAGENET_PATH --weights-path checkpoints/2x4_sparse_ft_1_epoch.pth --sparsity semi_structured --skip-last-layer-sparsity - ``` - -Please run `python evaluate.py --help` for a full list of available arguments. - -Results (1x A100): -* Baseline - ``` - Test: Total time: 0:02:11 - Test: Acc@1 78.392 Acc@5 93.592 - ``` - -* Sparsity= 0.9, Tile Size = 32, Online Sparsification, BSR = None - ``` - Test: Total time: 0:01:52 - Test: Acc@1 76.092 Acc@5 92.656 - ``` - -* Sparsity= 0.9, Tile Size = 32, Offline Sparsification, BSR = None - ``` - Test: Total time: 0:01:54 - Test: Acc@1 76.092 Acc@5 92.656 - ``` - -* Sparsity= 0.9, Tile Size = 32, Offline Sparsification, BSR = 32 - ``` - Test: Total time: 0:01:25 - Test: Acc@1 76.092 Acc@5 92.656 - ``` ## Pretrained Weights @@ -189,43 +93,5 @@ wget https://huggingface.co/facebook/superblock-vit-b-16/resolve/main/checkpoint # For sparsified checkpoints, wget https://huggingface.co/facebook/superblock-vit-b-16/resolve/main/checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth -P checkpoints/ ``` - -### Benchmark: -``` -python benchmark.py --model vit_b_16 \ - --batch-size 256 \ - --sparsity-linear ${SPARSITY} \ - --sp-linear-tile-size ${BLOCK_SIZE} \ - --sparsity bsr\ - --bsr ${BLOCK_SIZE} \ - --weights-path ./checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth \ - > /dev/null -``` -Result: -``` -530.342578125 ms -``` - -### Evaluate: -8 x A100 GPUs: -``` -torchrun --nproc_per_node=8 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsity bsr --weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH} -``` -Result: -``` -Test: Total time: 0:01:01 -Test: Acc@1 77.644 Acc@5 93.554 -``` - -1 x A100 GPUs: -``` -torchrun --nproc_per_node=1 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsity bsr--weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH} -``` -Result: -``` -Test: Total time: 0:01:51 -Test: Acc@1 77.644 Acc@5 93.554 -``` - ## License SuperBlock is released under the [MIT license](https://github.com/pytorch-labs/superblock?tab=MIT-1-ov-file#readme). diff --git a/torchao/sparsity/prototype/superblock/benchmark.py b/torchao/sparsity/prototype/superblock/benchmark.py index 18de0bc2d5..a0fb27022c 100644 --- a/torchao/sparsity/prototype/superblock/benchmark.py +++ b/torchao/sparsity/prototype/superblock/benchmark.py @@ -1,26 +1,25 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
-import os -import time -import sys -import warnings -import hashlib +import torch import torchvision -import presets -import torch -import torch.utils.data -import utils -from torch import nn -from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm -from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity +from torch.sparse._triton_ops_meta import ( +    dump as store_tuned_kernel_params, +    optimize_bsr_dense_addmm, +) +from torchao.sparsity.prototype.superblock.utils import ( +    accelerate_with_sparsity, +    get_args_parser, +    simulate_sparsity, +) from torchao.utils import benchmark_model, profiler_runner torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False +torch.backends.mha.set_fastpath_enabled(False) + @torch.inference_mode def main(args): -    print(args) device = torch.device(args.device) # We disable the cudnn benchmarking because it can noticeably affect the accuracy @@ -29,90 +28,114 @@ def main(args): num_classes = 1000 dtype = getattr(torch, args.dtype) -    print(f"Using dtype: {dtype}") # BSR kernel tuning if args.bsr and args.tune_kernel_params: -        print("Tuning kernel params") +        kwargs = dict( +            dtype=torch.int8 if args.quantization else dtype, +            sparsity=args.sparsity_linear, +            verbose=True, +            # per blocksparse_int_addmm: +            alpha=1, +            beta=0, +            use_left_alpha=True, +            use_right_alpha=True, +            # force tuning because existing tuning parameters were +            # computed for use_left/right_alpha=False; however, it +            # turns out that re-tuning with use_left/right_alpha=True +            # leads to the same set of tuning parameters: +            # force=True +        ) if args.model == "vit_b_16": -        optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True) -        optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True) +        optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, **kwargs) +        optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, **kwargs) elif args.model == "vit_h_14": -        optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True) -        optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True) +        optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, **kwargs) +        optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, **kwargs) else: -        raise NotImplementedError("Tuning kernel params for this model is not supported yet.") - -    print("Creating model") -    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) - -    # Fake sparsity necessary for BSR -    simulate_sparsity(model, args) +        raise NotImplementedError( +            "Tuning kernel params for this model is not supported yet."
+            ) +    # Warning: the following call will overwrite the source code +    # of torch.sparse._triton_ops_meta (hence it is commented out +    # by default) but when used, it'll enable reusing the tuned +    # parameters in subsequent runs of this script: +    # store_tuned_kernel_params() +    model = torchvision.models.get_model( +        args.model, weights=args.weights, num_classes=num_classes +    ).eval() + +    # Fake sparsity necessary for BSR, since we find the sparsity masks based on SuperBlock +    sparsifier_or_none = simulate_sparsity(model, args) +    if sparsifier_or_none is not None: +        sparsifier_or_none.squash_mask() if args.weights_path: try: checkpoint = torch.load(args.weights_path, map_location="cpu") model.load_state_dict(checkpoint["model"]) -            print(f"Loaded checkpoint successfully from: {args.weights_path}") except FileNotFoundError: raise FileNotFoundError(f"No checkpoint found at {args.weights_path}.") model.to(device).to(dtype) -    # Fake sparsity necessary for BSR +    # With quantization, we must use cuSPARSELt to fuse one of the scalar matmuls. +    # Otherwise, we observe the CUTLASS kernels to be faster, so we use those instead. accelerate_with_sparsity(model, args) -    # compile -    model = torch.compile(model, mode='max-autotune', fullgraph=True) +    # compile +    model = torch.compile(model, mode="max-autotune", fullgraph=True) # define image -    image = torch.randn(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=dtype, device=device) +    image = torch.randn( +        args.batch_size, +        3, +        args.val_crop_size, +        args.val_crop_size, +        dtype=dtype, +        device=device, +    ) # warmup -    benchmark_model(model, 10, args=(image,)) +    benchmark_model(model, 10, args=(image,)) if args.profile: -        return profiler_runner("test.json.gz", benchmark_model, model, 10, (image,)) +        return profiler_runner("test.json.gz", benchmark_model, model, 10, (image,)) else: -        return benchmark_model(model, 100, args=(image,)) - - - -def get_args_parser(add_help=True): -    import argparse - -    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) -    parser.add_argument("--model", default="resnet18", type=str, help="model name") -    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") -    parser.add_argument( -        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" -    ) -    parser.add_argument( -        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" -    ) -    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") -    parser.add_argument("--weights-path", type=str, help="path of pretrained weights to load") -    # NOTE: sparsity args -    parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply') -    parser.add_argument("--sparsity-linear", type=float, default=0.0) -    parser.add_argument("--sp-linear-tile-size", type=int, default=1) -    parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) -    parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1) -    parser.add_argument("--sparsity-conv", type=float, default=0.0) -    parser.add_argument("--sp-conv-tile-size", type=int, default=1) -    parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") -    parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer 
(for vit only)") - parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') - parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="data type", default="bfloat16") - parser.add_argument("--float16", action="store_true", help="Use float16") - parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params") - parser.add_argument("--profile", action="store_true", help="Profile the run and dump Prefetto trace") - parser.add_argument("--quantization", action="store_true", help="Profile the run and dump Prefetto trace") - - return parser + return benchmark_model(model, 100, args=(image,)) if __name__ == "__main__": - args = get_args_parser().parse_args() + args = get_args_parser(benchmark=True).parse_args() result = main(args) - print(f"{result:.3f} ms", file=sys.stderr) - print(f"{1000/result:.3f} img/s") + header = [ + "model", + "batch_size", + "dtype", + "sparsity", + "bsr", + "sparsity_level", + "quantization", + "tune_kernel_params", + "latency", + "img/s", + ] + result_string = ",".join( + str(_) + for _ in [ + args.model, + args.batch_size, + args.dtype, + args.sparsity, + args.bsr, + args.sparsity_linear, + args.quantization, + args.tune_kernel_params, + result, + 1000 / result, + ] + ) + with open("benchmark_results.txt", "a") as f: + if args.header: + f.write(",".join(header) + "\n") + f.write(result_string + "\n") + print(result_string) diff --git a/torchao/sparsity/prototype/superblock/benchmark.sh b/torchao/sparsity/prototype/superblock/benchmark.sh new file mode 100644 index 0000000000..3fc2a9869b --- /dev/null +++ b/torchao/sparsity/prototype/superblock/benchmark.sh @@ -0,0 +1,39 @@ +MODEL=vit_h_14 +BATCH_SIZE=256 + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --header +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --quantization + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization --tune-kernel-params +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization --tune-kernel-params +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization --tune-kernel-params + +MODEL=vit_b_16 +BATCH_SIZE=256 + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --header +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --quantization + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured 
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization --tune-kernel-params +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization --tune-kernel-params +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization --tune-kernel-params diff --git a/torchao/sparsity/prototype/superblock/benchmark_results.txt b/torchao/sparsity/prototype/superblock/benchmark_results.txt new file mode 100644 index 0000000000..3e18d9faec --- /dev/null +++ b/torchao/sparsity/prototype/superblock/benchmark_results.txt @@ -0,0 +1,30 @@ +model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s +vit_h_14,256,bfloat16,None,None,0.0,False,False,489.645859375,2.0422923646825746 +vit_h_14,256,bfloat16,None,None,0.0,True,False,454.5648828125,2.1999059712064963 +vit_h_14,256,bfloat16,semi_structured,None,0.0,False,False,458.638046875,2.180368608347371 +vit_h_14,256,bfloat16,bsr,64,0.8,False,False,361.5827734375,2.765618479257699 +vit_h_14,256,bfloat16,bsr,64,0.84,False,False,343.1771484375,2.9139469354327407 +vit_h_14,256,bfloat16,bsr,64,0.9,False,False,315.37119140625,3.170866671559215 +vit_h_14,256,bfloat16,semi_structured,None,0.0,True,False,438.1652734375,2.2822438486619143 +vit_h_14,256,bfloat16,bsr,64,0.8,True,False,439.5409765625,2.2751007376392045 +vit_h_14,256,bfloat16,bsr,64,0.84,True,False,416.799375,2.3992358433838823 +vit_h_14,256,bfloat16,bsr,64,0.9,True,False,381.9370703125,2.6182323679181034 +vit_h_14,256,bfloat16,bsr,64,0.8,True,True,439.1569921875,2.277090010610706 +vit_h_14,256,bfloat16,bsr,64,0.84,True,True,416.18,2.4028064779662643 +vit_h_14,256,bfloat16,bsr,64,0.9,True,True,384.2584765625,2.6024149394069362 + +model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s +vit_b_16,256,bfloat16,None,None,0.0,False,False,61.407705078125,16.284601398599175 +vit_b_16,256,bfloat16,None,None,0.0,True,False,60.934091796875,16.41117427881784 +vit_b_16,256,bfloat16,semi_structured,None,0.0,False,False,59.9600732421875,16.677764817945665 +vit_b_16,256,bfloat16,bsr,64,0.8,False,False,47.6238916015625,20.997864020990484 +vit_b_16,256,bfloat16,bsr,64,0.84,False,False,45.7176416015625,21.873394273378768 +vit_b_16,256,bfloat16,bsr,64,0.9,False,False,42.708759765625,23.414400359264707 +vit_b_16,256,bfloat16,semi_structured,None,0.0,True,False,58.783828125,17.011481420937148 +vit_b_16,256,bfloat16,bsr,64,0.8,True,False,58.1029541015625,17.210828872005806 +vit_b_16,256,bfloat16,bsr,64,0.84,True,False,55.8751025390625,17.89705887878946 
+vit_b_16,256,bfloat16,bsr,64,0.9,True,False,52.3257763671875,19.111039900921202 +vit_b_16,256,bfloat16,bsr,64,0.8,True,True,58.649736328125,17.050375033322325 +vit_b_16,256,bfloat16,bsr,64,0.84,True,True,56.46744140625,17.709320186930174 +vit_b_16,256,bfloat16,bsr,64,0.9,True,True,52.528623046875,19.037239927413086 +vit_b_16,256,bfloat16,bsr,64,0.8,True,False,57.6839794921875,17.335835856044508 diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py index 06b3548c55..69c98f6afc 100644 --- a/torchao/sparsity/prototype/superblock/blocksparse.py +++ b/torchao/sparsity/prototype/superblock/blocksparse.py @@ -1,24 +1,114 @@ from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple import torch -from typing import Optional, Tuple, List, Dict, Any, Callable +from torch.sparse._triton_ops import broadcast_batch_dims, bsr_dense_addmm, bsr_dense_mm from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TorchAOBaseTensor from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten + +# quantization support +@torch.library.custom_op("blocksparse::bsr_to_dense", mutates_args=()) +def bsr_to_dense( + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, +) -> torch.Tensor: + return torch.sparse_bsr_tensor( + crow_indices=crow_indices, col_indices=col_indices, values=values, size=(M, K) + ).to_dense() + + +@torch.library.register_fake("blocksparse::bsr_to_dense") +def bsr_to_dense_abstract( + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, +) -> torch.Tensor: + return torch.empty((M, K), dtype=values.dtype, device=values.device) + + +@torch.library.custom_op("blocksparse::int_addmm", mutates_args=()) +def blocksparse_int_addmm( + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + A: torch.Tensor, + left_alpha: torch.Tensor, + right_alpha: torch.Tensor, +) -> torch.Tensor: + assert values.dtype == torch.int8 + M = left_alpha.shape[-1] + K = A.shape[-2] + N = A.shape[-1] + weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) + original_batch_dims_broadcasted = broadcast_batch_dims( + blocksparse_int_addmm, weight_bsr, A + ) + out = A.new_empty(original_batch_dims_broadcasted + (M, N), dtype=torch.bfloat16) + return bsr_dense_addmm( + out, + weight_bsr, + A, + alpha=1, + beta=0, + out=out, + left_alpha=left_alpha, + right_alpha=right_alpha, + ).t() + + +@torch.library.register_fake("blocksparse::int_addmm") +def blocksparse_int_addmm_abstract( + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + A: torch.Tensor, + left_alpha: torch.Tensor, + right_alpha: torch.Tensor, +) -> torch.Tensor: + N = A.shape[-1] + M = left_alpha.shape[-1] + # to have the same strides as the transposed result + return torch.empty((M, N), dtype=torch.bfloat16, device=A.device).t() + + # bsr wrapper custom op @torch.library.custom_op("blocksparse::linear", mutates_args=()) -def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor: +def blocksparse_linear( + A: torch.Tensor, + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, + bias: torch.Tensor, +) -> 
torch.Tensor: weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) return torch.nn.functional.linear(A, weight_bsr, bias) + @torch.library.register_fake("blocksparse::linear") -def blocksparse_linear_abstract(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K:int , bias: torch.Tensor) -> torch.Tensor: - new_shape = A.shape[:-1] + (bias.shape[0],) +def blocksparse_linear_abstract( + A: torch.Tensor, + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, + bias: torch.Tensor, +) -> torch.Tensor: + new_shape = A.shape[:-1] + (M,) return torch.empty(new_shape, dtype=A.dtype, device=A.device) + # Subclass definition class BlockSparseTensor(TorchAOBaseTensor): bsr_crow_indices: Optional[torch.Tensor] @@ -37,7 +127,9 @@ def __new__( # noqa: PYI034 requires_grad: bool = False, ): if bsr_values is None: - raise ValueError("bsr values must be provided!") + raise ValueError( + "No values passed to BlockSparseTensor: bsr_values must be provided!" + ) else: previous_tensor = bsr_values @@ -72,7 +164,7 @@ def __tensor_unflatten__( outer_size, outer_stride, ) -> torch.Tensor: - shape, requires_grad = tensor_meta + shape, requires_grad = tensor_meta return cls( shape=shape, bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), @@ -94,44 +186,54 @@ def from_dense(cls, dense_tensor, blocksize): def apply_fn_to_shard(self, func): return BlockSparseTensor( - shape = self.shape, + shape=self.shape, bsr_crow_indices=func(self.bsr_crow_indices), bsr_col_indices=func(self.bsr_col_indices), bsr_values=func(self.bsr_values), requires_grad=self.requires_grad, ) + # Subclass op dispatch registration implements = BlockSparseTensor.implements + @implements(aten.detach.default) def block_sparse_detach(func, types, args, kwargs): - return return_and_correct_aliasing(func, args, kwargs, args[0].apply_fn_to_shard(torch.detach)) + return return_and_correct_aliasing( + func, args, kwargs, args[0].apply_fn_to_shard(torch.detach) + ) + @implements(aten.values.default) def block_sparse_values(func, types, args, kwargs): return args[0].bsr_values.detach() + @implements(aten.crow_indices.default) def block_sparse_crow_indices(func, types, args, kwargs): return args[0].bsr_crow_indices.detach() + @implements(aten.col_indices.default) def block_sparse_col_indices(func, types, args, kwargs): return args[0].bsr_col_indices.detach() + @implements(aten._nnz.default) def block_sparse__nnz(func, types, args, kwargs): return args[0].bsr_values.shape[0] + @implements(torch.nn.functional.linear) def block_sparse_linear(func, types, args, kwargs): x, w, bias = args - return torch.ops.blocksparse.linear(x, - w.crow_indices(), - w.col_indices(), - w.values(), - w.shape[0], w.shape[1], bias) + return torch.ops.blocksparse.linear( + x, w.crow_indices(), w.col_indices(), w.values(), w.shape[0], w.shape[1], bias + ) + def block_sparse_weight(blocksize=64): - return _get_linear_subclass_inserter(partial(BlockSparseTensor.from_dense, blocksize=blocksize)) + return _get_linear_subclass_inserter( + partial(BlockSparseTensor.from_dense, blocksize=blocksize) + ) diff --git a/torchao/sparsity/prototype/superblock/evaluate.py b/torchao/sparsity/prototype/superblock/evaluate.py index 09f34ebb64..5db9fc9e38 100644 --- a/torchao/sparsity/prototype/superblock/evaluate.py +++ b/torchao/sparsity/prototype/superblock/evaluate.py @@ -1,29 +1,23 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
- import os -import sys -import warnings -import hashlib -from functools import partial - -import presets import torch -import torch.utils.data import torchvision -import utils -from torch import nn -from torchvision.transforms.functional import InterpolationMode -from torchao.sparsity import sparsify_, semi_sparse_weight -from torchao.sparsity.prototype.superblock.supermask import apply_supermask -from torchao.sparsity.prototype.superblock.utils import apply_sparsity, verify_sparsity, mlp_only_with_args, simulate_sparsity, accelerate_with_sparsity -from torchao.sparsity.prototype.superblock.train import evaluate, _get_cache_path, load_data -from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import WeightNormSparsifier +from torchao.sparsity.prototype.superblock.train import evaluate, load_data +from torchao.sparsity.prototype.superblock.utils import ( + accelerate_with_sparsity, + apply_sparsity, + get_args_parser, + init_distributed_mode, + simulate_sparsity, +) torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False +torch.backends.mha.set_fastpath_enabled(False) + def main(args): - utils.init_distributed_mode(args) + init_distributed_mode(args) print(args) device = torch.device(args.device) @@ -35,13 +29,20 @@ def main(args): val_dir = os.path.join(args.data_path, "val") dataset_test, test_sampler = load_data(None, val_dir, args) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True + dataset_test, + batch_size=args.batch_size, + sampler=test_sampler, + num_workers=args.workers, + pin_memory=True, + drop_last=True, ) num_classes = len(dataset_test.classes) # Create Model print("Creating model") - model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) + model = torchvision.models.get_model( + args.model, weights=args.weights, num_classes=num_classes + ) sparsifier_or_none = simulate_sparsity(model, args) @@ -58,62 +59,44 @@ def main(args): if sparsifier_or_none is not None: sparsifier_or_none.squash_mask() accelerate_with_sparsity(model, args) - - criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) - evaluate(model, criterion, data_loader_test, device=device, dtype=torch.bfloat16) - - -def get_args_parser(add_help=True): - import argparse - - parser = argparse.ArgumentParser(description="Superblock evaluation", add_help=add_help) - parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417", type=str, help="dataset path") - parser.add_argument("--model", default="vit-", type=str, help="model name") - parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") - parser.add_argument( - "-b", "--batch-size", default=256, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" - ) - parser.add_argument( - "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" - ) - parser.add_argument( - "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" - ) - parser.add_argument("--print-freq", default=10, type=int, help="print frequency") - parser.add_argument( - "--cache-dataset", - dest="cache_dataset", - help="Cache the datasets for quicker initialization. 
It also serializes the transforms", - action="store_true", - ) - parser.add_argument( - "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" - ) - parser.add_argument( - "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)" - ) - parser.add_argument( - "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" - ) - parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") - parser.add_argument("--weights-path", type=str, help="path of pretrained weights to load") + model = torch.compile(model, mode="max-autotune", fullgraph=True) - # NOTE: sparsity args - parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply') - parser.add_argument("--sparsity-linear", type=float, default=0.0) - parser.add_argument("--sp-linear-tile-size", type=int, default=1) - parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) - parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1) - parser.add_argument("--sparsity-conv", type=float, default=0.0) - parser.add_argument("--sp-conv-tile-size", type=int, default=1) - parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") - parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") - parser.add_argument('--bsr', type=int, nargs='?', default=64, help='Convert sparsified weights to BSR format with optional block size (default: 64)') - parser.add_argument('--meta', action='store_true', help='Use Meta internal imagenet structure') - - return parser + criterion = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + return evaluate(model, criterion, data_loader_test, device=device, dtype=torch.bfloat16) if __name__ == "__main__": - args = get_args_parser().parse_args() - main(args) + args = get_args_parser(evaluate=True).parse_args() + accuracy, throughput, max_mem = main(args) + header = [ + "model", + "batch_size", + "dtype", + "sparsity", + "bsr", + "sparsity_level", + "quantization", + "top-1_acc", + "encoder img/s", + "max_mem (MB)", + ] + result_string = ",".join( + str(_) + for _ in [ + args.model, + args.batch_size, + "bfloat16", + args.sparsity, + args.bsr, + args.sparsity_linear, + args.quantization, + accuracy, + throughput, + max_mem + ] + ) + with open("evaluation_results.txt", "a") as f: + if args.header: + f.write(",".join(header) + "\n") + f.write(result_string + "\n") + print(result_string) diff --git a/torchao/sparsity/prototype/superblock/evaluate.sh b/torchao/sparsity/prototype/superblock/evaluate.sh new file mode 100644 index 0000000000..68be5175fd --- /dev/null +++ b/torchao/sparsity/prototype/superblock/evaluate.sh @@ -0,0 +1,23 @@ +MODEL=vit_b_16 +BATCH_SIZE=256 + +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_B_16_Weights.IMAGENET1K_V1 --header +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_B_16_Weights.IMAGENET1K_V1 --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_B_16_Weights.IMAGENET1K_V1 --sparsity semi_structured +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights 
ViT_B_16_Weights.IMAGENET1K_V1 --sparsity semi_structured --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.80 --bsr 64 --weights-path checkpoints/$MODEL/sp0.80-ts64.pth +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.80 --bsr 64 --weights-path checkpoints/$MODEL/sp0.80-ts64.pth --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.84 --bsr 64 --weights-path checkpoints/$MODEL/sp0.84-ts64.pth +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.84 --bsr 64 --weights-path checkpoints/$MODEL/sp0.84-ts64.pth --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.90 --bsr 64 --weights-path checkpoints/$MODEL/sp0.90-ts64.pth +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.90 --bsr 64 --weights-path checkpoints/$MODEL/sp0.90-ts64.pth --quantization + +MODEL=vit_h_14 +BATCH_SIZE=128 + +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 --header +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 --sparsity semi_structured +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 --sparsity semi_structured --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.90 --bsr 64 --weights-path checkpoints/$MODEL/sp0.90-ts64.pth +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.90 --bsr 64 --weights-path checkpoints/$MODEL/sp0.90-ts64.pth --quantization diff --git a/torchao/sparsity/prototype/superblock/evaluation_results.txt b/torchao/sparsity/prototype/superblock/evaluation_results.txt new file mode 100644 index 0000000000..58dcade663 --- /dev/null +++ b/torchao/sparsity/prototype/superblock/evaluation_results.txt @@ -0,0 +1,19 @@ +model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,top-1_acc,encoder img/s,max_mem (MB) +vit_b_16,256,bfloat16,None,None,0.0,False,81.97716346153847,734.904399886552,247.97265625 +vit_b_16,256,bfloat16,None,None,0.0,True,81.89503205128206,230.83627917226997,196.841796875 +vit_b_16,256,bfloat16,semi_structured,None,0.0,False,77.05729166666667,1386.7278781133518,316.40234375 +vit_b_16,256,bfloat16,semi_structured,None,0.0,True,76.74078525641026,150.53603093207843,249.25390625 +vit_b_16,256,bfloat16,bsr,64,0.8,False,77.13541666666667,1469.2705176409308,179.55322265625 +vit_b_16,256,bfloat16,bsr,64,0.8,True,77.13341346153847,87.8480561274922,158.70361328125 +vit_b_16,256,bfloat16,bsr,64,0.84,False,76.14983974358974,1752.835540513905,174.01953125 +vit_b_16,256,bfloat16,bsr,64,0.84,True,76.0556891025641,1013.7495284783578,156.630859375 +vit_b_16,256,bfloat16,bsr,64,0.9,False,62.99879807692308,1702.289195236525,164.2822265625 
+vit_b_16,256,bfloat16,bsr,64,0.9,True,62.946714743589745,987.5488468441617,152.5732421875 + +model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,top-1_acc,encoder img/s,max_mem (MB) +vit_h_14,128,bfloat16,None,None,0.0,False,89.29286858974359,81.02922135697278,1430.05615234375 +vit_h_14,128,bfloat16,None,None,0.0,True,89.3349358974359,56.076129157634355,1025.00927734375 +vit_h_14,128,bfloat16,semi_structured,None,0.0,False,82.03725961538461,75.83586253901329,1900.36279296875 +vit_h_14,128,bfloat16,semi_structured,None,0.0,True,82.06330128205128,36.36097831133589,1390.98779296875 +vit_h_14,128,bfloat16,bsr,64,0.9,False,78.21113782051282,350.91330496491446,599.6201171875 +vit_h_14,128,bfloat16,bsr,64,0.9,True,78.2051282051282,108.84048044884008,531.5810546875 diff --git a/torchao/sparsity/prototype/superblock/presets.py b/torchao/sparsity/prototype/superblock/presets.py deleted file mode 100644 index c5a242c549..0000000000 --- a/torchao/sparsity/prototype/superblock/presets.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. - -import torch -from torchvision.transforms import autoaugment, transforms -from torchvision.transforms.functional import InterpolationMode - - -class ClassificationPresetTrain: - def __init__( - self, - *, - crop_size, - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225), - interpolation=InterpolationMode.BILINEAR, - hflip_prob=0.5, - auto_augment_policy=None, - ra_magnitude=9, - augmix_severity=3, - random_erase_prob=0.0, - ): - trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] - if hflip_prob > 0: - trans.append(transforms.RandomHorizontalFlip(hflip_prob)) - if auto_augment_policy is not None: - if auto_augment_policy == "ra": - trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) - elif auto_augment_policy == "ta_wide": - trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) - elif auto_augment_policy == "augmix": - trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity)) - else: - aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) - trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) - trans.extend( - [ - transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), - ] - ) - if random_erase_prob > 0: - trans.append(transforms.RandomErasing(p=random_erase_prob)) - - self.transforms = transforms.Compose(trans) - - def __call__(self, img): - return self.transforms(img) - - -class ClassificationPresetEval: - def __init__( - self, - *, - crop_size, - resize_size=256, - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225), - interpolation=InterpolationMode.BILINEAR, - ): - - self.transforms = transforms.Compose( - [ - transforms.Resize(resize_size, interpolation=interpolation), - transforms.CenterCrop(crop_size), - transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), - ] - ) - - def __call__(self, img): - return self.transforms(img) diff --git a/torchao/sparsity/prototype/superblock/sampler.py b/torchao/sparsity/prototype/superblock/sampler.py deleted file mode 100644 index bf36a17954..0000000000 --- a/torchao/sparsity/prototype/superblock/sampler.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
- -import math - -import torch -import torch.distributed as dist - - -class RASampler(torch.utils.data.Sampler): - """Sampler that restricts data loading to a subset of the dataset for distributed, - with repeated augmentation. - It ensures that different each augmented version of a sample will be visible to a - different process (GPU). - Heavily based on 'torch.utils.data.DistributedSampler'. - - This is borrowed from the DeiT Repo: - https://github.com/facebookresearch/deit/blob/main/samplers.py - """ - - def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3): - if num_replicas is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available!") - num_replicas = dist.get_world_size() - if rank is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available!") - rank = dist.get_rank() - self.dataset = dataset - self.num_replicas = num_replicas - self.rank = rank - self.epoch = 0 - self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)) - self.total_size = self.num_samples * self.num_replicas - self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) - self.shuffle = shuffle - self.seed = seed - self.repetitions = repetitions - - def __iter__(self): - if self.shuffle: - # Deterministically shuffle based on epoch - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() - else: - indices = list(range(len(self.dataset))) - - # Add extra samples to make it evenly divisible - indices = [ele for ele in indices for i in range(self.repetitions)] - indices += indices[: (self.total_size - len(indices))] - assert len(indices) == self.total_size - - # Subsample - indices = indices[self.rank : self.total_size : self.num_replicas] - assert len(indices) == self.num_samples - - return iter(indices[: self.num_selected_samples]) - - def __len__(self): - return self.num_selected_samples - - def set_epoch(self, epoch): - self.epoch = epoch diff --git a/torchao/sparsity/prototype/superblock/train.py b/torchao/sparsity/prototype/superblock/train.py index 7fd1ce4d20..acfed09bc6 100644 --- a/torchao/sparsity/prototype/superblock/train.py +++ b/torchao/sparsity/prototype/superblock/train.py @@ -1,26 +1,35 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
import datetime -import os import glob +import os import sys import time import warnings -import presets import torch import torch.utils.data import torchvision -import transforms import utils -from sampler import RASampler from torch import nn from torch.utils.data.dataloader import default_collate -from torchvision.transforms.functional import InterpolationMode from torchao.sparsity.prototype.superblock.utils import simulate_sparsity - -def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): +from torchvision.transforms.functional import InterpolationMode +from utils import RASampler + + +def train_one_epoch( + model, + criterion, + optimizer, + data_loader, + device, + epoch, + args, + model_ema=None, + scaler=None, +): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) @@ -29,7 +38,9 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg header = f"Epoch: [{epoch}]" accumulation_counter = 0 # Counter for tracking accumulated gradients - for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + for i, (image, target) in enumerate( + metric_logger.log_every(data_loader, args.print_freq, header) + ): start_time = time.time() image, target = image.to(device), target.to(device) @@ -65,23 +76,48 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) batch_size = image.shape[0] - metric_logger.update(loss=loss.item() * args.accumulation_steps, lr=optimizer.param_groups[0]["lr"]) # Scale back up for logging + metric_logger.update( + loss=loss.item() * args.accumulation_steps, + lr=optimizer.param_groups[0]["lr"], + ) # Scale back up for logging metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time)) - -def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="", dtype=torch.float32): + + +def evaluate( + model, + criterion, + data_loader, + device, + print_freq=100, + log_suffix="", + dtype=torch.float32, +): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") header = f"Test: {log_suffix}" - + encoder_time = 0 num_processed_samples = 0 with torch.inference_mode(): for image, target in metric_logger.log_every(data_loader, print_freq, header): image = image.to(device, non_blocking=True).to(dtype) target = target.to(device, non_blocking=True).to(dtype) + # intialize encoder measurements + torch.cuda.reset_max_memory_allocated() + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + + # run encoder output = model(image) - # loss = criterion(output, target) + + # measure time in encoder + end_event.record() + torch.cuda.synchronize() + encoder_time += start_event.elapsed_time(end_event) + max_mem = torch.cuda.max_memory_allocated() / (1024**2) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) # FIXME need to take into account that the datasets @@ -90,6 +126,7 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=" # metric_logger.update(loss=loss.item()) metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) + 
metric_logger.meters["batch_time"].update(encoder_time, n=batch_size) num_processed_samples += batch_size # gather the stats from all processes @@ -97,7 +134,6 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=" if ( hasattr(data_loader.dataset, "__len__") and len(data_loader.dataset) != num_processed_samples - and torch.distributed.get_rank() == 0 ): # See FIXME above warnings.warn( @@ -109,21 +145,31 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=" metric_logger.synchronize_between_processes() - print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}") - return metric_logger.acc1.global_avg + print( + f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}" + ) + total_time = encoder_time / 1000.0 + return metric_logger.acc1.global_avg, num_processed_samples.item() / total_time, max_mem + def _get_cache_path(filepath): import hashlib h = hashlib.sha1(filepath.encode()).hexdigest() - cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt") + cache_path = os.path.join( + "~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt" + ) cache_path = os.path.expanduser(cache_path) return cache_path + def load_data(traindir, valdir, args): # Data loading code print("Loading data") - val_resize_size, val_crop_size, = ( + ( + val_resize_size, + val_crop_size, + ) = ( args.val_resize_size, args.val_crop_size, ) @@ -142,7 +188,7 @@ def load_data(traindir, valdir, args): random_erase_prob = getattr(args, "random_erase", 0.0) ra_magnitude = args.ra_magnitude augmix_severity = args.augmix_severity - preprocessing = presets.ClassificationPresetTrain( + preprocessing = utils.ClassificationPresetTrain( crop_size=train_crop_size, interpolation=interpolation, auto_augment_policy=auto_augment_policy, @@ -150,9 +196,7 @@ def load_data(traindir, valdir, args): ra_magnitude=ra_magnitude, augmix_severity=augmix_severity, ) - dataset = torchvision.datasets.ImageFolder( - traindir, - preprocessing) + dataset = torchvision.datasets.ImageFolder(traindir, preprocessing) # ) if args.meta else torchvision.datasets.ImageNet( # traindir, # split="train", @@ -166,7 +210,9 @@ def load_data(traindir, valdir, args): print(f"Number of training images: {len(dataset)}") if args.distributed: if hasattr(args, "ra_sampler") and args.ra_sampler: - train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps) + train_sampler = RASampler( + dataset, shuffle=True, repetitions=args.ra_reps + ) else: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) else: @@ -177,30 +223,38 @@ def load_data(traindir, valdir, args): if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! 
print(f"Loading dataset_test from {cache_path}") - dataset_test, _ = torch.load(cache_path) + dataset_test, test_sampler = torch.load(cache_path) else: if args.weights: weights = torchvision.models.get_weight(args.weights) preprocessing = weights.transforms() else: - preprocessing = presets.ClassificationPresetEval( - crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation + preprocessing = utils.ClassificationPresetEval( + crop_size=val_crop_size, + resize_size=val_resize_size, + interpolation=interpolation, + ) + dataset_test = ( + torchvision.datasets.ImageFolder( + valdir, + preprocessing, + ) + if args.meta + else torchvision.datasets.ImageNet( + valdir, split="val", transform=preprocessing ) - dataset_test = torchvision.datasets.ImageFolder( - valdir, - preprocessing, - ) if args.meta else torchvision.datasets.ImageNet( - valdir, - split='val', - transform=preprocessing ) if args.cache_dataset: print(f"Saving dataset_test to {cache_path}") utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) - + print(f"Number of validation images: {len(dataset_test)}") - test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) if args.distributed else torch.utils.data.SequentialSampler(dataset_test) + test_sampler = ( + torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) + if args.distributed + else torch.utils.data.SequentialSampler(dataset_test) + ) # for evaluation if traindir is None: @@ -208,6 +262,7 @@ def load_data(traindir, valdir, args): return dataset, dataset_test, train_sampler, test_sampler + def main(args): if args.output_dir: utils.mkdir(args.output_dir) @@ -225,15 +280,21 @@ def main(args): train_dir = os.path.join(args.data_path, "train_blurred") val_dir = os.path.join(args.data_path, "val") - dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) + dataset, dataset_test, train_sampler, test_sampler = load_data( + train_dir, val_dir, args + ) collate_fn = None num_classes = len(dataset.classes) mixup_transforms = [] if args.mixup_alpha > 0.0: - mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) + mixup_transforms.append( + utils.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha) + ) if args.cutmix_alpha > 0.0: - mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) + mixup_transforms.append( + utils.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha) + ) if mixup_transforms: mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) @@ -249,11 +310,17 @@ def collate_fn(batch): collate_fn=collate_fn, ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True + dataset_test, + batch_size=args.batch_size, + sampler=test_sampler, + num_workers=args.workers, + pin_memory=True, ) print("Creating model") - model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) + model = torchvision.models.get_model( + args.model, weights=args.weights, num_classes=num_classes + ) if args.weights_path is not None: sd = torch.load(args.weights_path, map_location="cpu") @@ -262,7 +329,7 @@ def collate_fn(batch): model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - + sparsifier = simulate_sparsity(model, args) criterion = 
nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) @@ -270,13 +337,19 @@ def collate_fn(batch): if args.bias_weight_decay is not None: custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) if args.transformer_embedding_decay is not None: - for key in ["class_token", "position_embedding", "relative_position_bias_table"]: + for key in [ + "class_token", + "position_embedding", + "relative_position_bias_table", + ]: custom_keys_weight_decay.append((key, args.transformer_embedding_decay)) parameters = utils.set_weight_decay( model, args.weight_decay, norm_weight_decay=args.norm_weight_decay, - custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None, + custom_keys_weight_decay=( + custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None + ), ) opt_name = args.opt.lower() @@ -290,24 +363,37 @@ def collate_fn(batch): ) elif opt_name == "rmsprop": optimizer = torch.optim.RMSprop( - parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9 + parameters, + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + eps=0.0316, + alpha=0.9, ) elif opt_name == "adamw": - optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay) + optimizer = torch.optim.AdamW( + parameters, lr=args.lr, weight_decay=args.weight_decay + ) else: - raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.") + raise RuntimeError( + f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported." + ) scaler = torch.cuda.amp.GradScaler() if args.amp else None args.lr_scheduler = args.lr_scheduler.lower() if args.lr_scheduler == "steplr": - main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) + main_lr_scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma + ) elif args.lr_scheduler == "cosineannealinglr": main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min ) elif args.lr_scheduler == "exponentiallr": - main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) + main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer, gamma=args.lr_gamma + ) else: raise RuntimeError( f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR " @@ -317,25 +403,33 @@ def collate_fn(batch): if args.lr_warmup_epochs > 0: if args.lr_warmup_method == "linear": warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR( - optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs + optimizer, + start_factor=args.lr_warmup_decay, + total_iters=args.lr_warmup_epochs, ) elif args.lr_warmup_method == "constant": warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR( - optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs + optimizer, + factor=args.lr_warmup_decay, + total_iters=args.lr_warmup_epochs, ) else: raise RuntimeError( f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported." 
) lr_scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs] + optimizer, + schedulers=[warmup_lr_scheduler, main_lr_scheduler], + milestones=[args.lr_warmup_epochs], ) else: lr_scheduler = main_lr_scheduler model_without_ddp = model if args.distributed: - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], find_unused_parameters=True + ) model_without_ddp = model.module model_ema = None @@ -349,16 +443,20 @@ def collate_fn(batch): adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs alpha = 1.0 - args.model_ema_decay alpha = min(1.0, alpha * adjust) - model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha) + model_ema = utils.ExponentialMovingAverage( + model_without_ddp, device=device, decay=1.0 - alpha + ) - #TODO: need to test resume functionality + # TODO: need to test resume functionality if args.resume: checkpoint_pattern = os.path.join(args.output_dir, "model_*.pth") checkpoint_files = glob.glob(checkpoint_pattern) - epochs = [int(f.split('_')[-1].split('.')[0]) for f in checkpoint_files] + epochs = [int(f.split("_")[-1].split(".")[0]) for f in checkpoint_files] if epochs: latest_epoch = max(epochs) - latest_checkpoint = os.path.join(args.output_dir, f"model_{latest_epoch}.pth") + latest_checkpoint = os.path.join( + args.output_dir, f"model_{latest_epoch}.pth" + ) try: checkpoint = torch.load(latest_checkpoint, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) @@ -371,7 +469,9 @@ def collate_fn(batch): scaler.load_state_dict(checkpoint["scaler"]) print(f"Resumed training from epoch {args.start_epoch}.") except FileNotFoundError: - print(f"No checkpoint found at {latest_checkpoint}. Starting training from scratch.") + print( + f"No checkpoint found at {latest_checkpoint}. Starting training from scratch." + ) args.start_epoch = 0 else: print("No checkpoint found. 
Starting training from scratch.") @@ -380,7 +480,9 @@ def collate_fn(batch): args.start_epoch = 0 print("Zero-shot evaluation") if model_ema: - evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA") + evaluate( + model_ema, criterion, data_loader_test, device=device, log_suffix="EMA" + ) else: evaluate(model, criterion, data_loader_test, device=device) @@ -389,11 +491,23 @@ def collate_fn(batch): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler) + train_one_epoch( + model, + criterion, + optimizer, + data_loader, + device, + epoch, + args, + model_ema, + scaler, + ) lr_scheduler.step() evaluate(model, criterion, data_loader_test, device=device) if model_ema: - evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA") + evaluate( + model_ema, criterion, data_loader_test, device=device, log_suffix="EMA" + ) if args.output_dir: checkpoint = { "model": model_without_ddp.state_dict(), @@ -408,152 +522,18 @@ def collate_fn(batch): checkpoint["model_ema"] = model_ema.state_dict() if scaler: checkpoint["scaler"] = scaler.state_dict() - utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) - utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) + utils.save_on_master( + checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth") + ) + utils.save_on_master( + checkpoint, os.path.join(args.output_dir, "checkpoint.pth") + ) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print(f"Training time {total_time_str}") -def get_args_parser(add_help=True): - import argparse - - parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) - parser.add_argument("--data-path", type=str, help="dataset path") - parser.add_argument("--model", default="resnet18", type=str, help="model name") - parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") - parser.add_argument( - "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" - ) - parser.add_argument("--accumulation-steps", default=1, type=int, help="Number of steps to accumulate gradients over") - parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") - parser.add_argument( - "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" - ) - parser.add_argument("--opt", default="sgd", type=str, help="optimizer") - parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") - parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") - parser.add_argument( - "--wd", - "--weight-decay", - default=1e-4, - type=float, - metavar="W", - help="weight decay (default: 1e-4)", - dest="weight_decay", - ) - parser.add_argument( - "--norm-weight-decay", - default=None, - type=float, - help="weight decay for Normalization layers (default: None, same value as --wd)", - ) - parser.add_argument( - "--bias-weight-decay", - default=None, - type=float, - help="weight decay for bias parameters of all layers (default: None, same value as --wd)", - ) - parser.add_argument( - "--transformer-embedding-decay", - default=None, - type=float, - help="weight decay for embedding parameters 
for vision transformer models (default: None, same value as --wd)", - ) - parser.add_argument( - "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" - ) - parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)") - parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)") - parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)") - parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)") - parser.add_argument( - "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)" - ) - parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") - parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") - parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") - parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)") - parser.add_argument("--print-freq", default=10, type=int, help="print frequency") - parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") - parser.add_argument('--resume', action='store_true', help='Resumes training from latest available checkpoint ("model_.pth")') - parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") - parser.add_argument( - "--cache-dataset", - dest="cache_dataset", - help="Cache the datasets for quicker initialization. It also serializes the transforms", - action="store_true", - ) - parser.add_argument( - "--sync-bn", - dest="sync_bn", - help="Use sync batch norm", - action="store_true", - ) - parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") - parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy") - parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy") - parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") - - # Mixed precision training parameters - parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") - - # distributed training parameters - parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") - parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - parser.add_argument( - "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters" - ) - parser.add_argument( - "--model-ema-steps", - type=int, - default=32, - help="the number of iterations that controls how often to update the EMA model (default: 32)", - ) - parser.add_argument( - "--model-ema-decay", - type=float, - default=0.99998, - help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)", - ) - parser.add_argument( - "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." 
- ) - parser.add_argument( - "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" - ) - parser.add_argument( - "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)" - ) - parser.add_argument( - "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" - ) - parser.add_argument( - "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" - ) - parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") - parser.add_argument( - "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" - ) - parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") - parser.add_argument("--weights-path", type=str) - - # NOTE: sparsity args - parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply') - parser.add_argument("--sparsity-linear", type=float, default=0.0) - parser.add_argument("--sp-linear-tile-size", type=int, default=1) - parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) - parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1) - parser.add_argument("--sparsity-conv", type=float, default=0.0) - parser.add_argument("--sp-conv-tile-size", type=int, default=1) - parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") - parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") - parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') - parser.add_argument('--meta', action='store_true', help='Use Meta internal imagenet structure') - return parser - - if __name__ == "__main__": - args = get_args_parser().parse_args() + args = utils.get_args_parser(train=True).parse_args() main(args) diff --git a/torchao/sparsity/prototype/superblock/transforms.py b/torchao/sparsity/prototype/superblock/transforms.py deleted file mode 100644 index 2375e3fc41..0000000000 --- a/torchao/sparsity/prototype/superblock/transforms.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. - -import math -from typing import Tuple - -import torch -from torch import Tensor -from torchvision.transforms import functional as F - - -class RandomMixup(torch.nn.Module): - """Randomly apply Mixup to the provided batch and targets. - The class implements the data augmentations as described in the paper - `"mixup: Beyond Empirical Risk Minimization" `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for mixup. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: - super().__init__() - - if num_classes < 1: - raise ValueError( - f"Please provide a valid positive value for the num_classes. 
Got num_classes={num_classes}" - ) - - if alpha <= 0: - raise ValueError("Alpha param can't be zero.") - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") - if target.ndim != 1: - raise ValueError(f"Target ndim should be 1. Got {target.ndim}") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - if target.dtype != torch.int64: - raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1, 0) - - # Implemented as on mixup paper, page 3. - lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - batch_rolled.mul_(1.0 - lambda_param) - batch.mul_(lambda_param).add_(batch_rolled) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"num_classes={self.num_classes}" - f", p={self.p}" - f", alpha={self.alpha}" - f", inplace={self.inplace}" - f")" - ) - return s - - -class RandomCutmix(torch.nn.Module): - """Randomly apply Cutmix to the provided batch and targets. - The class implements the data augmentations as described in the paper - `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" - `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for cutmix. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: - super().__init__() - if num_classes < 1: - raise ValueError("Please provide a valid positive value for the num_classes.") - if alpha <= 0: - raise ValueError("Alpha param can't be zero.") - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") - if target.ndim != 1: - raise ValueError(f"Target ndim should be 1. Got {target.ndim}") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - if target.dtype != torch.int64: - raise TypeError(f"Target dtype should be torch.int64. 
Got {target.dtype}") - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1, 0) - - # Implemented as on cutmix paper, page 12 (with minor corrections on typos). - lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - _, H, W = F.get_dimensions(batch) - - r_x = torch.randint(W, (1,)) - r_y = torch.randint(H, (1,)) - - r = 0.5 * math.sqrt(1.0 - lambda_param) - r_w_half = int(r * W) - r_h_half = int(r * H) - - x1 = int(torch.clamp(r_x - r_w_half, min=0)) - y1 = int(torch.clamp(r_y - r_h_half, min=0)) - x2 = int(torch.clamp(r_x + r_w_half, max=W)) - y2 = int(torch.clamp(r_y + r_h_half, max=H)) - - batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] - lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"num_classes={self.num_classes}" - f", p={self.p}" - f", alpha={self.alpha}" - f", inplace={self.inplace}" - f")" - ) - return s diff --git a/torchao/sparsity/prototype/superblock/utils.py b/torchao/sparsity/prototype/superblock/utils.py index e779613f5c..cf865fd369 100644 --- a/torchao/sparsity/prototype/superblock/utils.py +++ b/torchao/sparsity/prototype/superblock/utils.py @@ -1,91 +1,190 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import argparse import copy import datetime import errno import hashlib +import math import os import time from collections import defaultdict, deque, OrderedDict from typing import List, Optional, Tuple import torch -import torch.distributed as dist -from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight -from torchao.sparsity import sparsify_, semi_sparse_weight -from torchao.sparsity.prototype.superblock.supermask import SupermaskLinear, apply_supermask +from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_ +from torchao.sparsity import semi_sparse_weight, sparsify_ +from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import ( + WeightNormSparsifier, +) from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight -from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import WeightNormSparsifier +from torchao.sparsity.prototype.superblock.supermask import ( + apply_supermask, + SupermaskLinear, +) +from torchvision.transforms import autoaugment, functional as F, transforms +from torchvision.transforms.functional import InterpolationMode + +def get_args_parser(train=False, evaluate=False, benchmark=False): + assert sum([train, evaluate, benchmark]) == 1, "One and only one of training, evaluation, or benchmark can be true" + + # Shared common args + parser = argparse.ArgumentParser(description="SuperBlock Imagenet Training/Evaluation/Benchmarking Script", add_help=True) + parser.add_argument("--data-path", type=str, help="IMAGENET dataset path") + parser.add_argument("--model", default="vit_b_16", choices=["vit_b_16", "vit_h_14"], type=str, help="ViT base model") + parser.add_argument("--device", default="cuda", type=str, help="device (Default: 
cuda)") + parser.add_argument("-b", "--batch-size", default=32, type=int, help="per device batch size") + parser.add_argument("--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)") + parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply') + parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') + parser.add_argument("--sparsity-linear", type=float, default=0.0) + parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) + parser.add_argument("--sparsity-conv", type=float, default=0.0) + parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") + parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") + parser.add_argument("--quantization", action="store_true", help="Run with int8 dynamic quantization") + parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-path", type=str, help="optional checkpoint to load weights after intialization") + parser.add_argument("--header", action="store_true", help="Print header for first run") + + # Eval a subset of training args + # lots of training args + if train or evaluate: + parser.add_argument("-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers") + parser.add_argument("--accumulation-steps", default=1, type=int, help="Number of steps to accumulate gradients over") + parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument("--opt", default="sgd", type=str, help="optimizer") + parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, metavar="W", help="weight decay", dest="weight_decay") + parser.add_argument("--norm-weight-decay", default=None, type=float, help="weight decay for Normalization layers (default: None, same value as --wd)") + parser.add_argument("--bias-weight-decay", default=None, type=float, help="weight decay for bias parameters of all layers (default: None, same value as --wd)") + parser.add_argument("--transformer-embedding-decay", default=None, type=float, help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)") + parser.add_argument("--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing") + parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)") + parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)") + parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)") + parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)") + parser.add_argument("--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)") + parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") + 
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") + parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)") + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") + parser.add_argument('--resume', action='store_true', help='Resumes training from latest available checkpoint ("model_.pth")') + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") + parser.add_argument("--cache-dataset", dest="cache_dataset", help="Cache the datasets for quicker initialization. It also serializes the transforms", action="store_true") + parser.add_argument("--sync-bn", dest="sync_bn", help="Use sync batch norm", action="store_true") + parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") + parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy") + parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy") + parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") + # Mixed precision training parameters + parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") + # distributed training parameters + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") + parser.add_argument("--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters") + parser.add_argument("--model-ema-steps", type=int, default=32, help="the number of iterations that controls how often to update the EMA model (default: 32)") + parser.add_argument("--model-ema-decay", type=float, default=0.99998, help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)") + parser.add_argument("--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only.") + parser.add_argument("--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)") + parser.add_argument("--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)") + parser.add_argument("--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)") + parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") + parser.add_argument("--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)") + parser.add_argument('--meta', action='store_true', help='Use Meta internal imagenet structure') + + if benchmark: + parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="Data type", default="bfloat16") + parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params for BSR") + parser.add_argument("--profile", action="store_true", help="Dump Prefetto trace") + + return parser + -### Custom sparsification utils -def apply_sparsity(model): - for name, module in model.named_modules(): - if 
isinstance(module, SupermaskLinear) and "mlp" in name: - module.sparsify_offline() - -def verify_sparsity(model): - for name, module in model.named_modules(): - if isinstance(module, torch.nn.Linear): - total_weights = module.weight.numel() - sparse_weights = (module.weight == 0).sum().item() - sparsity_percentage = (sparse_weights / total_weights) * 100 - print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%") # filter functions def mlp_0_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'mlp.0' in name + return isinstance(mod, torch.nn.Linear) and "mlp.0" in name + def mlp_3_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'mlp.3' in name + return isinstance(mod, torch.nn.Linear) and "mlp.3" in name + def mlp_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'mlp' in name - + return isinstance(mod, torch.nn.Linear) and "mlp" in name + + def superblock_only(mod, name): return isinstance(mod, SupermaskLinear) and "mlp" in name -def mlp_only_with_args(mod, name, skip_last_layer_sparsity=False, skip_first_transformer_sparsity=False): + +def mlp_only_with_args( + mod, name, skip_last_layer_sparsity=False, skip_first_transformer_sparsity=False +): if skip_last_layer_sparsity and "heads.head" in name: return False if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in name: return False - if isinstance(mod, torch.nn.Linear) and "mlp" in name: + if isinstance(mod, torch.nn.Linear) and "mlp" in name: return True return False -### other + +### Custom sparsification utils +def apply_sparsity(model): + for name, module in model.named_modules(): + if isinstance(module, SupermaskLinear) and "mlp" in name: + module.sparsify_offline() + def accelerate_with_sparsity(model, args): if args.sparsity == "bsr": apply_sparsity(model) - verify_sparsity(model) - assert args.bsr is not None, "BSR requires a block size" - sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only) - + if args.quantization: + from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType + + quantize_( + model, + int8_dynamic_activation_int8_weight( + layout_type=BlockSparseLayoutType(blocksize=args.bsr) + ), + superblock_only, + ) + else: + assert args.bsr is not None, "BSR requires a block size" + sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only) elif args.sparsity == "semi_structured": if args.quantization: - quantize_(model, - int8_dynamic_activation_int8_semi_sparse_weight(), - mlp_0_only) - sparsify_(model, - semi_sparse_weight(), - mlp_3_only) + from torchao.dtypes.affine_quantized_tensor import SemiSparseLayoutType + + quantize_( + model, + int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), + mlp_0_only, + ) + sparsify_(model, semi_sparse_weight(), mlp_3_only) else: - sparsify_(model, - semi_sparse_weight(), - mlp_only) + sparsify_(model, semi_sparse_weight(), mlp_only) + else: + if args.quantization: + quantize_(model, int8_dynamic_activation_int8_weight(), mlp_only) + def simulate_sparsity(model, args): if args.sparsity == "bsr": apply_supermask( model, linear_sparsity=args.sparsity_linear, - linear_sp_tilesize=args.sp_linear_tile_size, + linear_sp_tilesize=args.bsr, conv1x1_sparsity=args.sparsity_conv1x1, - conv1x1_sp_tilesize=args.sp_conv1x1_tile_size, + conv1x1_sp_tilesize=args.bsr, conv_sparsity=args.sparsity_conv, - conv_sp_tilesize=args.sp_conv_tile_size, + conv_sp_tilesize=args.bsr, skip_last_layer_sparsity=args.skip_last_layer_sparsity, 
skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, device=args.device, @@ -94,24 +193,27 @@ def simulate_sparsity(model, args): elif args.sparsity == "semi_structured": sparse_config = [] for name, mod in model.named_modules(): - if mlp_only_with_args(mod, name, - skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, - skip_last_layer_sparsity=args.skip_last_layer_sparsity): + if mlp_only_with_args( + mod, + name, + skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, + skip_last_layer_sparsity=args.skip_last_layer_sparsity, + ): sparse_config.append({"tensor_fqn": f"{name}.weight"}) sparsifier = WeightNormSparsifier( sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2 ) sparsifier.prepare(model, sparse_config) - for line in sparse_config: - print(line) sparsifier.step() return sparsifier - else: - print("No sparsity applied!") -### Existing torchvision utils +# ------------------------------------------------------------ +# The following code contains torchvision reference code, +# largely copied from: https://github.com/pytorch/vision/tree/main/references/classification +# Please open issues in the original repository if you have questions. + class SmoothedValue: """Track a series of values and provide access to smoothed values over a @@ -164,7 +266,11 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, ) @@ -185,7 +291,9 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{attr}'" + ) def __str__(self): loss_str = [] @@ -223,7 +331,14 @@ def log_every(self, iterable, print_freq, header=None): ) else: log_msg = self.delimiter.join( - [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] ) MB = 1024.0 * 1024.0 for obj in iterable: @@ -248,7 +363,12 @@ def log_every(self, iterable, print_freq, header=None): else: print( log_msg.format( - i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), ) ) i += 1 @@ -316,9 +436,9 @@ def print(*args, **kwargs): def is_dist_avail_and_initialized(): - if not dist.is_available(): + if not torch.distributed.is_available(): return False - if not dist.is_initialized(): + if not torch.distributed.is_initialized(): return False return True @@ -326,13 +446,13 @@ def is_dist_avail_and_initialized(): def get_world_size(): if not is_dist_avail_and_initialized(): return 1 - return dist.get_world_size() + return torch.distributed.get_world_size() def get_rank(): if not is_dist_avail_and_initialized(): return 0 - return dist.get_rank() + return torch.distributed.get_rank() def is_main_process(): @@ -363,9 +483,12 @@ def init_distributed_mode(args): torch.cuda.set_device(args.gpu) args.dist_backend = "nccl" - print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) + print(f"| distributed init (rank {args.rank})", flush=True) torch.distributed.init_process_group( - 
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, ) torch.distributed.barrier() setup_for_distributed(args.rank == 0) @@ -390,7 +513,9 @@ def average_checkpoints(inputs): with open(fpath, "rb") as f: state = torch.load( f, - map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), ) # Copies over the settings from the first checkpoint if new_state is None: @@ -475,7 +600,9 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T # and remove unnecessary weights (such as auxiliaries, etc) if checkpoint_key == "model_ema": del checkpoint[checkpoint_key]["n_averaged"] - torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.") + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( + checkpoint[checkpoint_key], "module." + ) model.load_state_dict(checkpoint[checkpoint_key], strict=strict) tmp_path = os.path.join(output_dir, str(model.__hash__())) @@ -500,8 +627,8 @@ def reduce_across_processes(val): return torch.tensor(val) t = torch.tensor(val, device="cuda") - dist.barrier() - dist.all_reduce(t) + torch.distributed.barrier() + torch.distributed.all_reduce(t) return t @@ -543,7 +670,9 @@ def _add_params(module, prefix=""): continue is_custom_key = False for key in custom_keys: - target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name + target_name = ( + f"{prefix}.{name}" if prefix != "" and "." in key else name + ) if key == target_name: params[key].append(p) is_custom_key = True @@ -563,5 +692,365 @@ def _add_params(module, prefix=""): param_groups = [] for key in params: if len(params[key]) > 0: - param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]}) + param_groups.append( + {"params": params[key], "weight_decay": params_weight_decay[key]} + ) return param_groups + + +# Presets for ImageNet training/eval taken from: https://github.com/pytorch/vision/blob/main/references/classification/presets.py + + +class ClassificationPresetTrain: + def __init__( + self, + *, + crop_size, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + interpolation=InterpolationMode.BILINEAR, + hflip_prob=0.5, + auto_augment_policy=None, + ra_magnitude=9, + augmix_severity=3, + random_erase_prob=0.0, + ): + trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + if hflip_prob > 0: + trans.append(transforms.RandomHorizontalFlip(hflip_prob)) + if auto_augment_policy is not None: + if auto_augment_policy == "ra": + trans.append( + autoaugment.RandAugment( + interpolation=interpolation, magnitude=ra_magnitude + ) + ) + elif auto_augment_policy == "ta_wide": + trans.append( + autoaugment.TrivialAugmentWide(interpolation=interpolation) + ) + elif auto_augment_policy == "augmix": + trans.append( + autoaugment.AugMix( + interpolation=interpolation, severity=augmix_severity + ) + ) + else: + aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) + trans.append( + autoaugment.AutoAugment( + policy=aa_policy, interpolation=interpolation + ) + ) + trans.extend( + [ + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + if random_erase_prob > 0: + trans.append(transforms.RandomErasing(p=random_erase_prob)) + + 
self.transforms = transforms.Compose(trans) + + def __call__(self, img): + return self.transforms(img) + + +class ClassificationPresetEval: + def __init__( + self, + *, + crop_size, + resize_size=256, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + interpolation=InterpolationMode.BILINEAR, + ): + + self.transforms = transforms.Compose( + [ + transforms.Resize(resize_size, interpolation=interpolation), + transforms.CenterCrop(crop_size), + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + + def __call__(self, img): + return self.transforms(img) + + +# transforms taken from: https://github.com/pytorch/vision/blob/main/references/classification/transforms.py + + +class RandomMixup(torch.nn.Module): + """Randomly apply Mixup to the provided batch and targets. + The class implements the data augmentations as described in the paper + `"mixup: Beyond Empirical Risk Minimization" `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for mixup. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__( + self, + num_classes: int, + p: float = 0.5, + alpha: float = 1.0, + inplace: bool = False, + ) -> None: + super().__init__() + + if num_classes < 1: + raise ValueError( + f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}" + ) + + if alpha <= 0: + raise ValueError("Alpha param can't be zero.") + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward( + self, batch: torch.Tensor, target: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") + if target.ndim != 1: + raise ValueError(f"Target ndim should be 1. Got {target.ndim}") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot( + target, num_classes=self.num_classes + ).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on mixup paper, page 3. + lambda_param = float( + torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] + ) + batch_rolled.mul_(1.0 - lambda_param) + batch.mul_(lambda_param).add_(batch_rolled) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_classes={self.num_classes}" + f", p={self.p}" + f", alpha={self.alpha}" + f", inplace={self.inplace}" + f")" + ) + return s + + +class RandomCutmix(torch.nn.Module): + """Randomly apply Cutmix to the provided batch and targets. 
+ The class implements the data augmentations as described in the paper + `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" + `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for cutmix. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__( + self, + num_classes: int, + p: float = 0.5, + alpha: float = 1.0, + inplace: bool = False, + ) -> None: + super().__init__() + if num_classes < 1: + raise ValueError( + "Please provide a valid positive value for the num_classes." + ) + if alpha <= 0: + raise ValueError("Alpha param can't be zero.") + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward( + self, batch: torch.Tensor, target: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") + if target.ndim != 1: + raise ValueError(f"Target ndim should be 1. Got {target.ndim}") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot( + target, num_classes=self.num_classes + ).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). + lambda_param = float( + torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] + ) + _, H, W = F.get_dimensions(batch) + + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) + + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + + batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_classes={self.num_classes}" + f", p={self.p}" + f", alpha={self.alpha}" + f", inplace={self.inplace}" + f")" + ) + return s + + +# RA Sampler implementaion taken from: https://github.com/pytorch/vision/blob/main/references/classification/sampler.py + + +class RASampler(torch.utils.data.Sampler): + """Sampler that restricts data loading to a subset of the dataset for distributed, + with repeated augmentation. + It ensures that different each augmented version of a sample will be visible to a + different process (GPU). + Heavily based on 'torch.utils.data.DistributedSampler'. 
+ + This is borrowed from the DeiT Repo: + https://github.com/facebookresearch/deit/blob/main/samplers.py + """ + + def __init__( + self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3 + ): + if num_replicas is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available!") + num_replicas = torch.distributed.get_world_size() + if rank is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available!") + rank = torch.distributed.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int( + math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas) + ) + self.total_size = self.num_samples * self.num_replicas + self.num_selected_samples = int( + math.floor(len(self.dataset) // 256 * 256 / self.num_replicas) + ) + self.shuffle = shuffle + self.seed = seed + self.repetitions = repetitions + + def __iter__(self): + if self.shuffle: + # Deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # Add extra samples to make it evenly divisible + indices = [ele for ele in indices for i in range(self.repetitions)] + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # Subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices[: self.num_selected_samples]) + + def __len__(self): + return self.num_selected_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index e1d1a99627..ae343add9e 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -4,17 +4,17 @@ from torch.ao.pruning import WeightNormSparsifier from torch.sparse import to_sparse_semi_structured from torchao.quantization.quant_api import ( + _get_linear_subclass_inserter, _is_linear, _replace_with_custom_fn_if_matches_filter, - _get_linear_subclass_inserter, int8_dynamic_activation_int8_semi_sparse_weight, ) + # Sparsity helper functions def apply_fake_sparsity(model, **kwargs): """ This function simulates 2:4 sparsity on all linear layers in a model. - It uses the torch.ao.pruning flow. """ filter_fn = kwargs.pop("filter_fn", _is_linear) # torch.ao.pruning flow @@ -30,15 +30,19 @@ def apply_fake_sparsity(model, **kwargs): sparsifier.step() sparsifier.squash_mask() + def semi_sparse_weight(): """ Convert the weight of linear moduels to semi-structured (2:4) sparsity """ return _get_linear_subclass_inserter(to_sparse_semi_structured) -def sparsify_(model: torch.nn.Module, - apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], - filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module: + +def sparsify_( + model: torch.nn.Module, + apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, +) -> torch.nn.Module: """Convert the weight of linear modules in the model with `apply_tensor_subclass` This function is essentially the same as quantize, put for sparsity subclasses. 
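As an aside on the `sparsify_` API touched in this hunk, a minimal usage sketch follows. It is not part of the patch: the toy model, its layer sizes, and the `linear_only` filter are illustrative assumptions. It follows the pattern suggested by this file, where `apply_fake_sparsity` first prunes the selected weights to a 2:4 pattern and `sparsify_` then swaps them for the semi-structured subclass (the semi-structured kernels generally expect fp16/bf16 weights on CUDA).

import torch.nn as nn

from torchao.sparsity import semi_sparse_weight, sparsify_
from torchao.sparsity.sparse_api import apply_fake_sparsity

# Toy model with hypothetical sizes; any module containing nn.Linear layers works the same way.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).half().cuda()

def linear_only(mod: nn.Module, fqn: str) -> bool:
    # Plays the same role as the mlp_only-style filters in the superblock utils:
    # it decides which submodules get their weights converted.
    return isinstance(mod, nn.Linear)

# Simulate 2:4 sparsity on the selected weights, then replace them with the
# semi-structured (2:4) sparse tensor subclass via semi_sparse_weight().
apply_fake_sparsity(model, filter_fn=linear_only)
sparsify_(model, semi_sparse_weight(), filter_fn=linear_only)
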
@@ -73,6 +77,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: """ _replace_with_custom_fn_if_matches_filter( model, - apply_tensor_subclass, + apply_tensor_subclass, _is_linear if filter_fn is None else filter_fn, ) From a5a426e88cc183fa1369f39bb0c748747a79f4b2 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Thu, 26 Sep 2024 03:33:30 -0400 Subject: [PATCH 32/41] fixing some issues with our support for 70/405B models (#941) Summary: download and convert scripts needed to be updated alongside model.py config files Test Plan: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-70B/model.pth Reviewers: Subscribers: Tasks: Tags: --- scripts/convert_hf_checkpoint.py | 161 +++++++++++++++---------------- scripts/download.py | 2 +- torchao/_models/llama/model.py | 8 +- 3 files changed, 84 insertions(+), 87 deletions(-) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 7b0f76903c..3098c818bb 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -8,9 +8,10 @@ import json import re import shutil +import sys from pathlib import Path from typing import Optional - +from safetensors.torch import load_file as load_safetensors_file import torch from torchao._models.llama.model import ModelArgs @@ -24,63 +25,49 @@ def convert_hf_checkpoint( ) -> None: if model_name is None: model_name = checkpoint_dir.name - - # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files - # need to be copied into model.pth. - # Llama 3 70B can't be easily merged into one model.pth file, though, since names of the - # weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not - # currently supported. - # Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken - is_llama3 = "Llama-3" in model_name - if is_llama3: - # Check if we have multiple original/consolidated.NN.pth files and report error - # if we do for Llama 3. - original_dir = checkpoint_dir / "original" - pattern = re.compile(r"^consolidated\.\d{2}\.pth$") - bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)] - if len(bin_files) > 1: - raise ValueError( - f"Multiple consolidated.NN.pth files found in {original_dir}. 
" - "Merging them into one model.pth file is not supported for Llama 3.") - - config = ModelArgs.from_name(model_name) print(f"Model config {config.__dict__}") # Load the json file containing weight mapping - if not is_llama3: - model_map_json = checkpoint_dir / "pytorch_model.bin.index.json" - - assert model_map_json.is_file() - - with open(model_map_json) as json_map: - bin_index = json.load(json_map) - - weight_map = { - "model.embed_tokens.weight": "tok_embeddings.weight", - "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", - "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", - "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", - "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", - 'model.layers.{}.self_attn.rotary_emb.inv_freq': None, - 'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight', - "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", - "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", - "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", - "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", - "model.norm.weight": "norm.weight", - "lm_head.weight": "output.weight", - } - bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} - else: - # There is no separate pytorch_model.bin.index.json file for llama3. - # Instead, we will just use all original/consolidated.NN.pth files. - # so, we use model.safetensors.index.json - weight_map = None - original_dir = checkpoint_dir / "original" - pattern = re.compile(r"^consolidated\.\d{2}\.pth$") - bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)} - + model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json' + model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json" + model_map_json = None + + try: + assert model_map_json_safetensors.is_file() + model_map_json = model_map_json_safetensors + print(f"Found safetensors index at {model_map_json_safetensors}") + except AssertionError: + print(f"{model_map_json_safetensors} not found") + if model_map_json is None: + try: + assert model_map_json_pytorch.is_file() + model_map_json = model_map_json_pytorch + print(f"Found pytorch index at {model_map_json_pytorch}") + except AssertionError: + print(f"{model_map_json_pytorch} not found") + + if model_map_json is None: raise Exception("No model map found!") + + with open(model_map_json) as json_map: + bin_index = json.load(json_map) + + weight_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + 'model.layers.{}.self_attn.rotary_emb.inv_freq': None, + 'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight', + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + bin_files = 
{checkpoint_dir / bin for bin in bin_index["weight_map"].values()} def permute(w, n_head): dim = config.dim @@ -92,40 +79,44 @@ def permute(w, n_head): merged_result = {} for file in sorted(bin_files): - state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) - merged_result.update(state_dict) + if "safetensors" in str(file): + state_dict = load_safetensors_file(str(file), device="cpu") + merged_result.update(state_dict) + else: + state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) + merged_result.update(state_dict) final_result = {} - if weight_map is not None: - for key, value in merged_result.items(): - if "layers" in key: - abstract_key = re.sub(r'(\d+)', '{}', key) - layer_num = re.search(r'\d+', key).group(0) - new_key = weight_map[abstract_key] - if new_key is None: - continue - new_key = new_key.format(layer_num) - else: - new_key = weight_map[key] - - final_result[new_key] = value - - for key in tuple(final_result.keys()): - if "wq" in key: - q = final_result[key] - k = final_result[key.replace("wq", "wk")] - v = final_result[key.replace("wq", "wv")] - q = permute(q, config.n_head) - k = permute(k, config.n_local_heads) - final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) - del final_result[key] - del final_result[key.replace("wq", "wk")] - del final_result[key.replace("wq", "wv")] - else: - final_result = merged_result + for key, value in merged_result.items(): + if "layers" in key: + abstract_key = re.sub(r'(\d+)', '{}', key) + layer_num = re.search(r'\d+', key).group(0) + new_key = weight_map[abstract_key] + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = weight_map[key] + + final_result[new_key] = value + + for key in tuple(final_result.keys()): + if "wq" in key: + q = final_result[key] + k = final_result[key.replace("wq", "wk")] + v = final_result[key.replace("wq", "wv")] + q = permute(q, config.n_head) + k = permute(k, config.n_local_heads) + final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) + del final_result[key] + del final_result[key.replace("wq", "wk")] + del final_result[key.replace("wq", "wv")] print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") torch.save(final_result, checkpoint_dir / "model.pth") - if is_llama3: - original_dir = checkpoint_dir / "original" + if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower(): + if 'llama-3.1-405b' in model_name.lower(): + original_dir = checkpoint_dir / "original" / "mp16" + else: + original_dir = checkpoint_dir / "original" tokenizer_model = original_dir / "tokenizer.model" tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model" print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}") diff --git a/scripts/download.py b/scripts/download.py index 3fc89e7126..571e03adb5 100644 --- a/scripts/download.py +++ b/scripts/download.py @@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) - from huggingface_hub import snapshot_download os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) try: - snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors") + snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token) except HTTPError as e: if e.response.status_code == 401: print("You need to pass a valid `--hf_token=...` to download private checkpoints.") diff --git a/torchao/_models/llama/model.py 
b/torchao/_models/llama/model.py index 92448b5990..de1f311979 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -72,7 +72,13 @@ def from_name(cls, name: str): "stories15M": dict(n_layer=6, n_head=6, dim=288), "stories110M": dict(n_layer=12, n_head=12, dim=768), "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000), - "Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True) + "Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True), + "Llama-3.1-70B": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000, + use_scaled_rope=True + ), + "Llama-3.1-405B": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000, + use_scaled_rope=True + ), } # this is a model specific variable that controls whether index_put is used for the kv_cache update, From e7270f17f62ef5ea2770c5ede04b7a73db81def7 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 26 Sep 2024 23:42:41 +0800 Subject: [PATCH 33/41] Update INT8 mixed-precision training test to be less flaky (#950) --- test/prototype/test_quantized_training.py | 51 +++++++++-------------- 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index bffff16fc1..b07ade0b54 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -161,44 +161,31 @@ def test_int8_weight_only_training(self, compile, device): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_int8_mixed_precision_training(self, compile, config): _reset() - bsize = 4 - embed_dim = 32 + bsize = 64 + embed_dim = 64 device = "cuda" - # only use 1 matmul shape to reduce triton autotune time - model_ref = nn.Sequential( - nn.Linear(embed_dim, embed_dim, bias=False), - nn.GELU(), - nn.Linear(embed_dim, embed_dim), - ).to(device) - model_int8mp = copy.deepcopy(model_ref) - quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False) + linear = nn.Linear(embed_dim, embed_dim).cuda() + linear_int8mp = copy.deepcopy(linear) + quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False) if compile: - model_ref.compile() - model_int8mp.compile() + linear.compile() + linear_int8mp.compile() - optim_ref = torch.optim.AdamW(model_ref.parameters()) - optim_int8mp = torch.optim.AdamW(model_int8mp.parameters()) + inputs = torch.randn(bsize, embed_dim, device=device) + grad_outputs = torch.randn(bsize, embed_dim, device=device) - for i in range(5): - inputs = torch.randn(bsize, embed_dim, device=device) - labels = torch.randint(embed_dim, size=(bsize,), device=device) - loss_ref = F.cross_entropy(model_ref(inputs), labels) - loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels) - - rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item()) - assert rel_error < 3e-3, (i, rel_error) - - loss_ref.backward() - optim_ref.step() - optim_ref.zero_grad() - - loss_int8mp.backward() - for p in model_int8mp.parameters(): - assert p.grad is not None - 
optim_int8mp.step() - optim_int8mp.zero_grad() + inputs_ref, outputs_ref = self._forward_and_backward(linear, inputs, grad_outputs) + inputs_int8mp, outputs_int8mp = self._forward_and_backward(linear_int8mp, inputs, grad_outputs) + + def snr(ref, actual): + error = actual - ref + return 20 * torch.log10(ref.norm() / error.norm()) + + assert snr(outputs_ref, outputs_int8mp) > 20 + assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20 + assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20 _FSDP_WORLD_SIZE = 2 From 352685cd19869b0e6d3571220033dca6211333c9 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:53:38 -0700 Subject: [PATCH 34/41] Add executorch parallel Differential Revision: D62711909 Pull Request resolved: https://github.com/pytorch/ao/pull/953 --- torchao/experimental/CMakeLists.txt | 6 +++- torchao/experimental/Utils.cmake | 8 ++++++ torchao/experimental/build_torchao_ops.sh | 4 +-- .../kernels/cpu/aarch64/CMakeLists.txt | 5 ++++ .../experimental/ops/linear/CMakeLists.txt | 5 ++++ ...bit_activation_groupwise_lowbit_weight.cpp | 2 +- .../linear/linear_a8wxdq_op/CMakeLists.txt | 6 ++-- .../ops/parallel-executorch-impl.h | 28 +++++++++++++++++++ torchao/experimental/ops/parallel.h | 3 +- 9 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 torchao/experimental/ops/parallel-executorch-impl.h diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index 198e9ebd44..db2054c3a8 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -24,11 +24,15 @@ if(NOT TORCHAO_INCLUDE_DIRS) set(TORCHAO_INCLUDE_DIRS ${TORCHAO_ROOT}/../..) endif() +if (NOT TORCHAO_OP_TARGET) + message(FATAL_ERROR "TORCHAO_OP_TARGET is not set. Set it to ATEN or EXECUTORCH.") +endif() + if (NOT TORCHAO_PARALLEL_BACKEND) if (TORCHAO_OP_TARGET STREQUAL "ATEN") set(TORCHAO_PARALLEL_BACKEND "ATEN_OPENMP") elseif(TORCHAO_OP_TARGET STREQUAL "EXECUTORCH") - set(TORCHAO_PARALLEL_BACKEND "PTHREADPOOL") + set(TORCHAO_PARALLEL_BACKEND "EXECUTORCH") else() message(TORCHAO_PARALLEL_BACKEND "TORCHAO_PARALLEL_BACKEND is not set. Please set it directly or set TORCHAO_OP_TARGET to get a default.") endif() diff --git a/torchao/experimental/Utils.cmake b/torchao/experimental/Utils.cmake index 592f9366fc..d6e6254de7 100644 --- a/torchao/experimental/Utils.cmake +++ b/torchao/experimental/Utils.cmake @@ -23,6 +23,14 @@ function(target_link_torchao_parallel_backend target_name torchao_parallel_backe target_compile_definitions(${target_name} PRIVATE TORCHAO_PARALLEL_ATEN=1 AT_PARALLEL_OPENMP=1 INTRA_OP_PARALLEL=1) target_link_libraries(${target_name} PRIVATE ${TORCH_INSTALL_PREFIX}/lib/libomp${CMAKE_SHARED_LIBRARY_SUFFIX}) + elseif(TORCHAO_PARALLEL_BACKEND_TOUPPER STREQUAL "EXECUTORCH") + message(STATUS "Building with TORCHAO_PARALLEL_BACKEND=TORCHAO_PARALLEL_EXECUTORCH") + message(STATUS "EXECUTORCH_INCLUDE_DIRS: ${EXECUTORCH_INCLUDE_DIRS}") + message(STATUS "EXECUTORCH_LIBRARIES: ${EXECUTORCH_LIBRARIES}") + target_include_directories(${target_name} PRIVATE "${EXECUTORCH_INCLUDE_DIRS}") + target_link_libraries(${target_name} PRIVATE "${EXECUTORCH_LIBRARIES}") + target_compile_definitions(${target_name} PRIVATE TORCHAO_PARALLEL_EXECUTORCH=1) + elseif(TORCHAO_PARALLEL_BACKEND_TOUPPER STREQUAL "OPENMP") message(STATUS "Building with TORCHAO_PARALLEL_BACKEND=OPENMP. You must set the CMake variable OpenMP_ROOT to the OMP library location before compiling. 
Do not use this option if Torch was built with OPENMP; use ATEN_OPENMP instead.") find_package(OpenMP REQUIRED) diff --git a/torchao/experimental/build_torchao_ops.sh b/torchao/experimental/build_torchao_ops.sh index de6d8e17d8..2cb7201588 100644 --- a/torchao/experimental/build_torchao_ops.sh +++ b/torchao/experimental/build_torchao_ops.sh @@ -11,8 +11,8 @@ export CMAKE_OUT=/tmp/cmake-out/torchao cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ -DTORCHAO_OP_TARGET="$1" \ - -DEXECUTORCH_LIBRARIES=${EXECUTORCH_LIBRARIES} \ - -DEXECUTORCH_INCLUDE_DIRS=${EXECUTORCH_INCLUDE_DIRS} \ + -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \ + -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \ -S . \ -B ${CMAKE_OUT} cmake --build ${CMAKE_OUT} --target install --config Release diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index ec497a1871..4f36945f8a 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -13,3 +13,8 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp ) endif() + +install( + TARGETS torchao_kernels_aarch64 + DESTINATION lib +) diff --git a/torchao/experimental/ops/linear/CMakeLists.txt b/torchao/experimental/ops/linear/CMakeLists.txt index 2f7b91bbf9..087dfeb383 100644 --- a/torchao/experimental/ops/linear/CMakeLists.txt +++ b/torchao/experimental/ops/linear/CMakeLists.txt @@ -10,3 +10,8 @@ include(${TORCHAO_ROOT}/Utils.cmake) add_library(torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} STATIC channelwise_8bit_activation_groupwise_lowbit_weight.cpp) target_link_torchao_parallel_backend(torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} "${TORCHAO_PARALLEL_BACKEND}") + +install( + TARGETS torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} + DESTINATION lib +) diff --git a/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp index ae611d3ccc..02557b61fa 100644 --- a/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp +++ b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp @@ -93,7 +93,7 @@ LinearTilingParams get_default_linear_tiling_params( LinearTilingParams tiling_params; auto num_threads = torchao::get_num_threads(); - assert(num_threads >= 1); + TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1"); tiling_params.mc_by_mr = 1; int mc = tiling_params.mc_by_mr * ukernel_config.mr; diff --git a/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt b/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt index f69d884cd8..31a8320108 100644 --- a/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt @@ -19,7 +19,7 @@ if(TORCHAO_OP_TARGET STREQUAL "ATEN") target_compile_definitions(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE USE_ATEN=1) elseif(TORCHAO_OP_TARGET STREQUAL "EXECUTORCH") message(STATUS "Building with TORCHAO_OP_TARGET=EXECUTORCH") - add_library(linear_a8wxdq_${TORCHAO_OP_TARGET} SHARED + add_library(linear_a8wxdq_${TORCHAO_OP_TARGET} STATIC linear_a8wxdq_executorch/w2s.cpp linear_a8wxdq_executorch/w2sz.cpp linear_a8wxdq_executorch/w3s.cpp @@ -29,9 +29,9 @@ elseif(TORCHAO_OP_TARGET STREQUAL "EXECUTORCH") 
linear_a8wxdq_executorch/w5s.cpp linear_a8wxdq_executorch/w5sz.cpp ) - target_include_directories(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE ${EXECUTORCH_INCLUDE_DIRS}) + target_include_directories(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE "${EXECUTORCH_INCLUDE_DIRS}") target_compile_definitions(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE USE_EXECUTORCH=1) - target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE ${EXECUTORCH_LIBRARIES}) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE "${EXECUTORCH_LIBRARIES}") target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_kernels_aarch64) target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) else() diff --git a/torchao/experimental/ops/parallel-executorch-impl.h b/torchao/experimental/ops/parallel-executorch-impl.h new file mode 100644 index 0000000000..233f7250d4 --- /dev/null +++ b/torchao/experimental/ops/parallel-executorch-impl.h @@ -0,0 +1,28 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +template +void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { + torch::executorch::threadpool::get_threadpool()->run( + [&](size_t i) { + int64_t idx = begin + i; + f(idx); + }, + end - begin); +} + +inline void torchao::set_num_threads(int num_threads) { + torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool( + num_threads); +} + +inline int torchao::get_num_threads() { + return torch::executorch::threadpool::get_threadpool()->get_thread_count(); +} diff --git a/torchao/experimental/ops/parallel.h b/torchao/experimental/ops/parallel.h index e3949b8551..5372c5a2dd 100644 --- a/torchao/experimental/ops/parallel.h +++ b/torchao/experimental/ops/parallel.h @@ -34,8 +34,7 @@ int get_num_threads(); #ifdef TORCHAO_PARALLEL_EXECUTORCH #pragma message( \ "TORCHAO_PARALLEL_EXECUTORCH is set. 
Using ExecuTorch parallel backend.") - -#error "TORCHAO_PARALLEL_EXECUTORCH is not implemented yet" +#include #else #ifdef TORCHAO_PARALLEL_PTHREADPOOL From 37e1479a70f63a397ec6ee6e94e22398263057bd Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Thu, 26 Sep 2024 11:48:46 -0700 Subject: [PATCH 35/41] test CI Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 60d8ffa57d..ba2d490335 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -600,7 +600,7 @@ def test_small_amax_float16(self, float8_dtype): def test_dynamic_scale_parity(self, dtype: torch.dtype): scaling_type_weight = ScalingType.DYNAMIC torch.manual_seed(0) - hp_tensor = torch.randn(32, 32, device="cuda", dtype=dtype) + hp_tensor = torch.randn(16, 16, device="cuda", dtype=dtype) float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) From 2efde4925c3c92a2db035903f2a315dd32966441 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Thu, 26 Sep 2024 12:13:37 -0700 Subject: [PATCH 36/41] better comment on why upcasting Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 14 ++++++++++---- torchao/float8/float8_tensor.py | 4 +++- torchao/float8/float8_utils.py | 3 ++- torchao/float8/fsdp_utils.py | 3 +++ 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index ba2d490335..bbe93b6969 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -600,25 +600,31 @@ def test_small_amax_float16(self, float8_dtype): def test_dynamic_scale_parity(self, dtype: torch.dtype): scaling_type_weight = ScalingType.DYNAMIC torch.manual_seed(0) - hp_tensor = torch.randn(16, 16, device="cuda", dtype=dtype) + hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype) + hp_tensor2 = hp_tensor1.detach().clone() float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) float8_eager = hp_tensor_to_float8_dynamic( - hp_tensor, + hp_tensor1, torch.float8_e4m3fn, float8_config, gemm_input_role=GemmInputRole.WEIGHT, ) + torch._dynamo.reset() float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( - hp_tensor, + hp_tensor2, torch.float8_e4m3fn, float8_config, gemm_input_role=GemmInputRole.WEIGHT, ) torch.set_printoptions(precision=10, threshold=2000) assert torch.equal(float8_eager._scale, float8_compile._scale) - assert torch.equal(float8_eager._data, float8_compile._data), f"{float8_eager._data=} vs {float8_compile._data=}" + torch.testing.assert_close( + float8_eager.to_original_precision(), + float8_compile.to_original_precision(), + msg=f"{float8_eager.to_original_precision()=} vs {float8_compile.to_original_precision()=}", + ) class TestFloat8LinearUtils(unittest.TestCase): diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index a584166107..6f9595bc24 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -163,7 +163,9 @@ def forward( DTensor Invariant: DTensor must always be the outer most tensor subclass """ - # scale is float32 thus upcasting tensor to match + # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically + # upcasted to `float32` to multiply with the scale + # In order to match numerics between eager and compile, we upcast manually here. 
tensor_scaled = tensor.to(torch.float32) * scale bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index d8ad315f16..00a53a4dd4 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -42,7 +42,8 @@ def amax_to_scale( float8_dtype: The float8 dtype. orig_dtype: The original dtype of the tensor. """ - # _scaled_mm requires float32 scale + # torch.compile and eager show different numerics for 1.0 / float32, + # upcast to float64 to ensure same numeric between compile and eager amax = amax.to(torch.float64) if float8_dtype in FP8_TYPES: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index d3c0f73c6c..eaaf2063b7 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -59,6 +59,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: return # inf-norm is equivalent to max(abs(w)) + # keep consistent with float8_utils.amax_to_scale + # torch.compile and eager show different numerics for 1.0 / float32, + # upcast to float64 to ensure same numeric between compile and eager max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float64) # Partial amax_tensor = torch.stack(max_weights) # Partial # clamp is dispatched through DTensor From 8c04f4fa0e2b5ce5bc188083abd631468a643758 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Thu, 26 Sep 2024 12:25:11 -0700 Subject: [PATCH 37/41] control seed Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index bbe93b6969..369e0206b7 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -599,7 +599,7 @@ def test_small_amax_float16(self, float8_dtype): ) def test_dynamic_scale_parity(self, dtype: torch.dtype): scaling_type_weight = ScalingType.DYNAMIC - torch.manual_seed(0) + torch.manual_seed(42) hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype) hp_tensor2 = hp_tensor1.detach().clone() float8_config = Float8LinearConfig( From 04b229b3710964d190b57d0edc8c55ef8369b4cf Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Thu, 26 Sep 2024 13:37:00 -0700 Subject: [PATCH 38/41] move unit test to test_compile Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 44 ---------------------------------- test/float8/test_compile.py | 47 +++++++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 46 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 369e0206b7..8fb3921f67 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -15,9 +15,6 @@ import torch import torch.nn as nn -from torchao.float8.float8_scaling_utils import ( - hp_tensor_to_float8_dynamic, -) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -584,47 +581,6 @@ def test_small_amax_float16(self, float8_dtype): x = torch.tensor([target_amax], dtype=torch.float16, device="cuda") scale = tensor_to_scale(x, float8_dtype) assert not torch.any(torch.isinf(scale)) - - @unittest.skipIf( - not is_cuda_8_9, - "CUDA not available", - ) - @pytest.mark.parametrize( - "dtype", - [ - torch.float32, - torch.bfloat16, - torch.float16, - ], - ) - def test_dynamic_scale_parity(self, dtype: torch.dtype): - scaling_type_weight = ScalingType.DYNAMIC - torch.manual_seed(42) - hp_tensor1 = torch.randn(16, 16, 
device="cuda", dtype=dtype) - hp_tensor2 = hp_tensor1.detach().clone() - float8_config = Float8LinearConfig( - cast_config_weight=CastConfig(scaling_type=scaling_type_weight), - ) - float8_eager = hp_tensor_to_float8_dynamic( - hp_tensor1, - torch.float8_e4m3fn, - float8_config, - gemm_input_role=GemmInputRole.WEIGHT, - ) - torch._dynamo.reset() - float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( - hp_tensor2, - torch.float8_e4m3fn, - float8_config, - gemm_input_role=GemmInputRole.WEIGHT, - ) - torch.set_printoptions(precision=10, threshold=2000) - assert torch.equal(float8_eager._scale, float8_compile._scale) - torch.testing.assert_close( - float8_eager.to_original_precision(), - float8_compile.to_original_precision(), - msg=f"{float8_eager.to_original_precision()=} vs {float8_compile.to_original_precision()=}", - ) class TestFloat8LinearUtils(unittest.TestCase): diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index bae62bf77d..4c7e9dccb1 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -25,8 +25,14 @@ get_float8_layers, sync_float8_amax_and_scale_history, ) -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed -from torchao.float8.float8_tensor import LinearMMConfig +from torchao.float8.float8_scaling_utils import ( + hp_tensor_to_float8_delayed, + hp_tensor_to_float8_dynamic, +) +from torchao.float8.float8_tensor import ( + LinearMMConfig, + GemmInputRole, +) from torchao.float8.float8_utils import e4m3_dtype from torch._dynamo.test_case import TestCase as DynamoTestCase @@ -353,5 +359,42 @@ def test_sync_amax_func_cuda_graph_success(): assert "skipping cudagraphs due to mutaton on input" not in stderr[0] +@unittest.skipIf( + not is_cuda_8_9, + "CUDA not available", + ) +@pytest.mark.parametrize( + "dtype", + [ + torch.float32, + torch.bfloat16, + torch.float16, + ], +) +def test_dynamic_scale_numeric_parity(dtype: torch.dtype): + scaling_type_weight = ScalingType.DYNAMIC + torch.manual_seed(42) + hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype) + hp_tensor2 = hp_tensor1.detach().clone() + float8_config = Float8LinearConfig( + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + ) + float8_eager = hp_tensor_to_float8_dynamic( + hp_tensor1, + torch.float8_e4m3fn, + float8_config, + gemm_input_role=GemmInputRole.WEIGHT, + ) + torch._dynamo.reset() + float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( + hp_tensor2, + torch.float8_e4m3fn, + float8_config, + gemm_input_role=GemmInputRole.WEIGHT, + ) + assert torch.equal(float8_eager._scale, float8_compile._scale) + assert torch.equal(float8_eager._data, float_compile._data) + + if __name__ == "__main__": pytest.main([__file__]) From 8b7c2ef517d7e5f1e001d543682e85745b160687 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Thu, 26 Sep 2024 13:44:44 -0700 Subject: [PATCH 39/41] fix typo Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 4c7e9dccb1..2af4875d9d 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -393,7 +393,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): gemm_input_role=GemmInputRole.WEIGHT, ) assert torch.equal(float8_eager._scale, float8_compile._scale) - assert torch.equal(float8_eager._data, float_compile._data) + assert torch.equal(float8_eager._data, float8_compile._data) if __name__ == "__main__": 
From 9346afd341bab0fdf56cf023ebc98c1fd3fe9b12 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Fri, 27 Sep 2024 14:29:09 -0700 Subject: [PATCH 40/41] float64 upcasting after allreduce Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/fsdp_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index eaaf2063b7..4bb22e1111 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -62,12 +62,12 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: # keep consistent with float8_utils.amax_to_scale # torch.compile and eager show different numerics for 1.0 / float32, # upcast to float64 to ensure same numeric between compile and eager - max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float64) # Partial + max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial amax_tensor = torch.stack(max_weights) # Partial # clamp is dispatched through DTensor # it will issue a single all-reduce amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate - scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate + scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor.to(torch.float64) # Replicate if amax_tensor.dtype is torch.float16: scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) local_scale_tensor = scale_tensor.to_local().to(torch.float32) From 3d0da208d16a536c5e4b72031ebdb23f64840fde Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Mon, 30 Sep 2024 15:25:24 -0700 Subject: [PATCH 41/41] use LinearMMConfig Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_compile.py | 28 ++++++++++++++++++++++++++-- torchao/float8/fsdp_utils.py | 12 +++++++----- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 2af4875d9d..5106bd7780 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -32,6 +32,7 @@ from torchao.float8.float8_tensor import ( LinearMMConfig, GemmInputRole, + ScaledMMConfig, ) from torchao.float8.float8_utils import e4m3_dtype @@ -379,17 +380,40 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), ) + linear_mm_config = LinearMMConfig( + # output + ScaledMMConfig( + False, + float8_config.gemm_config_output.use_fast_accum, + False, + float8_config.pad_inner_dim, + ), + # grad_input + ScaledMMConfig( + False, + float8_config.gemm_config_grad_input.use_fast_accum, + False, + float8_config.pad_inner_dim, + ), + # grad_weight + ScaledMMConfig( + False, + float8_config.gemm_config_grad_weight.use_fast_accum, + False, + float8_config.pad_inner_dim, + ), + ) float8_eager = hp_tensor_to_float8_dynamic( hp_tensor1, torch.float8_e4m3fn, - float8_config, + linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ) torch._dynamo.reset() float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( hp_tensor2, torch.float8_e4m3fn, - float8_config, + linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ) assert torch.equal(float8_eager._scale, float8_compile._scale) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 4bb22e1111..201e9fdfed 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -59,16 +59,18 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: return # inf-norm is 
equivalent to max(abs(w)) - # keep consistent with float8_utils.amax_to_scale - # torch.compile and eager show different numerics for 1.0 / float32, - # upcast to float64 to ensure same numeric between compile and eager max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial amax_tensor = torch.stack(max_weights) # Partial # clamp is dispatched through DTensor # it will issue a single all-reduce amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate - scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor.to(torch.float64) # Replicate - if amax_tensor.dtype is torch.float16: + # keep consistent with float8_utils.amax_to_scale + # torch.compile and eager show different numerics for 1.0 / float32, + # upcast to float64 to ensure same numeric between compile and eager + origin_dtype = amax_tensor.dtype + amax_tensor = amax_tensor.to(torch.float64) + scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate + if origin_dtype is torch.float16: scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) local_scale_tensor = scale_tensor.to_local().to(torch.float32) for i, float8_linear in enumerate(float8_linears):
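To close out, the net effect of the final two patches on the FSDP precompute path can be summarized with a single-process sketch that strips out the DTensor machinery (the EPS constant and the tensor shapes below are illustrative assumptions; in the real code the clamp on the stacked amax DTensor is what issues the single all-reduce):

import math
import torch

EPS = 1e-12  # assumed to mirror torchao.float8.float8_utils.EPS

def precompute_scales(weights):
    # inf-norm is equivalent to max(abs(w)); computed in the original dtype
    max_weights = torch._foreach_norm(weights, ord=math.inf)
    amax_tensor = torch.stack(max_weights)
    amax_tensor = torch.clamp(amax_tensor, EPS)  # all-reduce point in the real code
    origin_dtype = amax_tensor.dtype
    # upcast only after the reduced amax is final, so eager and torch.compile
    # perform the same float64 division
    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor.to(torch.float64)
    if origin_dtype is torch.float16:
        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
    return list(scale_tensor.to(torch.float32).unbind())

weights = [torch.randn(8, 8, dtype=torch.bfloat16) for _ in range(3)]
print(precompute_scales(weights))

The key ordering decision is that the amax is gathered and reduced in the weights' original dtype, and the float64 upcast happens only for the division that produces the scale, keeping the distributed reduction cheap while still matching torch.compile's numerics.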