From 753653c7081aa8c043688afaf44433094bc1388a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 27 Sep 2024 11:08:34 -0700 Subject: [PATCH] [wip] make scaling configurable by gemm-argument Summary: My brain hurts from so many long identifiers... Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 4b97519d234404f72ed56f2608417ff9ca43edf9 ghstack-comment-id: 2372563439 Pull Request resolved: https://github.com/pytorch/ao/pull/940 --- benchmarks/float8/profile_linear_float8.py | 92 ++++++----- test/float8/test_base.py | 129 ++++++++------- test/float8/test_compile.py | 176 ++++++++++---------- torchao/float8/config.py | 71 ++++++-- torchao/float8/float8_linear.py | 182 ++++++++++++--------- torchao/float8/float8_ops.py | 14 ++ torchao/float8/float8_scaling_utils.py | 15 ++ torchao/float8/float8_utils.py | 60 ++++++- torchao/testing/float8/test_utils.py | 131 +++++++++++++++ 9 files changed, 595 insertions(+), 275 deletions(-) create mode 100644 torchao/testing/float8/test_utils.py diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 6afefa009..912f4a1c1 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -33,6 +33,10 @@ linear_requires_sync, sync_float8_amax_and_scale_history, ) +from torchao.testing.float8.test_utils import ( + scaling_granularities_by_gemm_lcw_recipe, + get_test_float8_linear_config, +) from torch.profiler import profile, ProfilerActivity, record_function from utils import ( kernel_name_to_category, @@ -258,6 +262,8 @@ def main( scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", scaling_granularity: str = "tensorwise", + # TODO(future PR): clean up the override, it's confusing + recipe_override: Optional[str] = None, model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, @@ -271,45 +277,57 @@ def main( scaling_type_grad_output = 
ScalingType(scaling_type_grad_output) scaling_granularity = ScalingGranularity(scaling_granularity) - if scaling_type_input is ScalingType.STATIC: - cast_config_input=CastConfig( - scaling_type=scaling_type_input, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_input=CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight=CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_weight=CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output=CastConfig( - scaling_type=scaling_type_grad_output, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_grad_output=CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, + if recipe_override is None: + + if scaling_type_input is ScalingType.STATIC: + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, + ) + else: + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) + if scaling_type_weight is ScalingType.STATIC: + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, + ) + else: + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) + if scaling_type_grad_output is ScalingType.STATIC: + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + 
static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, + ) + else: + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) + + config = Float8LinearConfig( + cast_config_input=cast_config_input, + cast_config_weight=cast_config_weight, + cast_config_grad_output=cast_config_grad_output, ) - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - ) + elif recipe_override == "lcw": + scaling_granularities_by_gemm = scaling_granularities_by_gemm_lcw_recipe + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, + False, # emulate + ) scaling_repr = "_".join( [ diff --git a/test/float8/test_base.py b/test/float8/test_base.py index f0e0ac0a9..9967de69b 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -8,6 +8,7 @@ import itertools import random import re +from typing import List, Tuple import unittest import warnings @@ -51,6 +52,10 @@ FP8_TYPES, tensor_to_scale, ) +from torchao.testing.float8.test_utils import ( + scaling_granularities_by_gemm, + get_test_float8_linear_config, +) random.seed(0) torch.manual_seed(0) @@ -59,6 +64,8 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._data == b._data).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" @@ -211,31 +218,52 @@ def test_axiswise_reshape(self): a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) + @pytest.mark.parametrize( + "a_granularity,b_granularity", + [ + 
(ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE), + (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE), + (ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE), + ] + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0") - def test_axiswise_gemm(self, a_shape): + def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") linear_mm_config = LinearMMConfig() + if a_granularity is ScalingGranularity.AXISWISE: + a_axiswise_dim = -1 + else: + assert a_granularity is ScalingGranularity.TENSORWISE + a_axiswise_dim = None a_fp8 = hp_tensor_to_float8_dynamic( a, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=-1, + scaling_granularity=a_granularity, + axiswise_dim=a_axiswise_dim, ) a_fp8 = a_fp8.reshape(-1, a_shape[-1]) + + b_axiswise_dim = 1 if b_granularity is ScalingGranularity.AXISWISE else None + if b_granularity is ScalingGranularity.AXISWISE: + b_axiswise_dim = 1 # will be transposed + else: + assert b_granularity is ScalingGranularity.TENSORWISE + b_axiswise_dim = None b_fp8 = hp_tensor_to_float8_dynamic( b, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=1, # will be transposed + scaling_granularity=b_granularity, + axiswise_dim=b_axiswise_dim, ) + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) a = a.reshape(-1, a_shape[-1]) c_ref = torch.mm(a, b.t()) @@ -316,26 +344,33 @@ def _test_linear_impl( # verify initialization flags got updated assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) - @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 
16)]) + # @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize("emulate", [False] if is_cuda_8_9 else [True]) + # @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) + @pytest.mark.parametrize("x_shape", [(16, 16),]) @pytest.mark.parametrize( "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + [ScalingType.DYNAMIC] ) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + [ScalingType.DYNAMIC] ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + # [ScalingType.DELAYED, ScalingType.DYNAMIC], + [ScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_granularity", - [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + "scaling_granularities_by_gemm", + scaling_granularities_by_gemm ) - @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) - @pytest.mark.parametrize("linear_bias", [False, True]) + # @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, ]) + # @pytest.mark.parametrize("linear_bias", [False, True]) + @pytest.mark.parametrize("linear_bias", [False, ]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_linear( self, @@ -344,7 +379,7 @@ def test_linear( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, + scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], linear_dtype: torch.dtype, linear_bias: bool, ): @@ -357,7 +392,23 @@ def test_linear( f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" ) 
pytest.skip() - if scaling_granularity is ScalingGranularity.AXISWISE: + + ( + (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight), + (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input), + (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight), + ) = scaling_granularities_by_gemm + + has_any_axiswise_scaling = ( + scaling_granularity_input is ScalingGranularity.AXISWISE or + scaling_granularity_weight is ScalingGranularity.AXISWISE or + scaling_granularity_grad_output is ScalingGranularity.AXISWISE or + scaling_granularity_input_for_grad_weight is ScalingGranularity.AXISWISE or + scaling_granularity_weight_for_grad_input is ScalingGranularity.AXISWISE or + scaling_granularity_grad_output_for_grad_weight is ScalingGranularity.AXISWISE + ) + + if has_any_axiswise_scaling: if ( scaling_type_input != ScalingType.DYNAMIC or scaling_type_weight != ScalingType.DYNAMIC or @@ -370,46 +421,14 @@ def test_linear( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if 
scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) - - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - emulate=emulate, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, + emulate, ) + self._test_linear_impl( x, m_ref, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index eacd317b1..5a41e2972 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy import random +from typing import List, Tuple import sys import unittest from io import StringIO @@ -33,6 +34,10 @@ from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed from torchao.float8.float8_tensor import LinearMMConfig from torchao.float8.float8_utils import e4m3_dtype +from torchao.testing.float8.test_utils import ( + scaling_granularities_by_gemm, + get_test_float8_linear_config, +) from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend @@ -52,7 +57,8 @@ def _test_compile_base( x_shape = (16, 16) linear_dtype = torch.bfloat16 - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() + x_ref = copy.deepcopy(x) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) m_fp8 = Float8Linear.from_float( @@ -64,7 +70,7 @@ def _test_compile_base( m_ref = torch.compile(m_ref, backend=backend, 
fullgraph=fullgraph) y_fp8 = m_fp8(x) y_fp8.sum().backward() - y_ref = m_ref(x) + y_ref = m_ref(x_ref) y_ref.sum().backward() # TODO(future PR): can also test fp8 eager vs compile here with a tigher # tolerance @@ -73,65 +79,33 @@ def _test_compile_base( m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1 ) torch.testing.assert_close(m_fp8.bias.grad, m_ref.bias.grad, atol=8e-2, rtol=8e-2) - -def _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, - emulate, -): - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) - - config = Float8LinearConfig( - cast_config_input=cast_config_input, - cast_config_weight=cast_config_weight, - cast_config_grad_output=cast_config_grad_output, - emulate=emulate, - ) - return config + torch.testing.assert_close(x.grad, x_ref.grad, atol=8e-2, rtol=8e-2) def is_supported( - scaling_granularity, + scaling_granularities_by_gemm, scaling_type_input, scaling_type_weight, scaling_type_grad_output, 
dtype, ) -> bool: - if scaling_granularity is ScalingGranularity.AXISWISE: + + ( + (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight), + (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input), + (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight), + ) = scaling_granularities_by_gemm + + has_any_axiswise_scaling = ( + scaling_granularity_input is ScalingGranularity.AXISWISE or + scaling_granularity_weight is ScalingGranularity.AXISWISE or + scaling_granularity_grad_output is ScalingGranularity.AXISWISE or + scaling_granularity_input_for_grad_weight is ScalingGranularity.AXISWISE or + scaling_granularity_weight_for_grad_input is ScalingGranularity.AXISWISE or + scaling_granularity_grad_output_for_grad_weight is ScalingGranularity.AXISWISE + ) + + if has_any_axiswise_scaling: if ( scaling_type_input != ScalingType.DYNAMIC or scaling_type_weight != ScalingType.DYNAMIC or @@ -145,19 +119,28 @@ def is_supported( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_input", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_weight", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_grad_output", [ScalingType.DYNAMIC,] ) +# 
@pytest.mark.parametrize( +# "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +# ) @pytest.mark.parametrize( - "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] + "scaling_granularities_by_gemm", + scaling_granularities_by_gemm ) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +# @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False, ]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, ]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( fullgraph, @@ -165,11 +148,11 @@ def test_eager_only( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, + scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], dtype: torch.dtype, ): if not is_supported( - scaling_granularity, + scaling_granularities_by_gemm, scaling_type_input, scaling_type_weight, scaling_type_grad_output, @@ -178,11 +161,11 @@ def test_eager_only( pytest.skip() torch._dynamo.reset() - config = _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, emulate, ) _test_compile_base( @@ -194,20 +177,26 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +# @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False,]) @pytest.mark.parametrize( - 
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_input", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_weight", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_grad_output", [ScalingType.DYNAMIC,] ) @pytest.mark.parametrize( - "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] + "scaling_granularities_by_gemm", + scaling_granularities_by_gemm ) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16,]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( fullgraph, @@ -215,11 +204,11 @@ def test_aot_eager( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, + scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], dtype: torch.dtype, ): if not is_supported( - scaling_granularity, + scaling_granularities_by_gemm, scaling_type_input, scaling_type_weight, scaling_type_grad_output, @@ -228,11 +217,11 @@ def test_aot_eager( pytest.skip() torch._dynamo.reset() - config = _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + 
scaling_type_grad_output, + scaling_granularities_by_gemm, emulate, ) _test_compile_base( @@ -246,30 +235,35 @@ def test_aot_eager( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) @pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_input", [ScalingType.DYNAMIC, ] ) @pytest.mark.parametrize( - "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_weight", [ScalingType.DYNAMIC, ] ) @pytest.mark.parametrize( - "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + # "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_grad_output", [ScalingType.DYNAMIC, ] ) @pytest.mark.parametrize( - "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] + "scaling_granularities_by_gemm", + scaling_granularities_by_gemm ) @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16,]) def test_inductor( fullgraph, emulate: bool, scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, - scaling_granularity: ScalingGranularity, + scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]], dtype: torch.dtype, ): if not is_supported( - scaling_granularity, + scaling_granularities_by_gemm, scaling_type_input, scaling_type_weight, scaling_type_grad_output, @@ -278,11 +272,11 @@ def 
test_inductor( pytest.skip() torch._dynamo.reset() - config = _get_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - scaling_granularity, + config = get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularities_by_gemm, emulate, ) _test_compile_base( diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 4d82bd111..1470d0793 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -48,12 +48,16 @@ def short_str(self): @dataclass(frozen=True) class CastConfig: """ - Configuration for casting a single tensor to float8 + Configuration for maybe casting a single tensor to float8 """ scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None + # If True, this tensor is not scaled to float8 and left in its original + # precision. + # TODO(ideally before this PR lands): a better name for this + keep_in_original_precision: bool = False def __post_init__(self): if self.scaling_type is ScalingType.STATIC: @@ -98,7 +102,7 @@ class Float8GemmConfig: use_fast_accum: bool = False -@dataclass(frozen=True) +@dataclass(frozen=False) class Float8LinearConfig: """ Configuration for converting a `torch.nn.Linear` module to float8 @@ -112,9 +116,22 @@ class Float8LinearConfig: cast_config_weight: CastConfig = CastConfig() cast_config_grad_output: CastConfig = CastConfig() + # + # Optional per-tensor configuration for `input`, `weight`, `grad_output` to + # calculate `grad_weight`, `grad_input`, and `grad_weight` respectively. + # If not specified, then the corresponding base configuration (`cast_config_input`, `cast_config_weight`, `cast_config_grad_output`) is reused. + # TODO(future PR): maybe rename `cast_config_input` to + # `cast_config_input_for_output`, etc, to make the names consistent, + # will be BC-breaking.
+ # + cast_config_input_for_grad_weight: Optional[CastConfig] = None + cast_config_weight_for_grad_input: Optional[CastConfig] = None + cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None + # # Per-gemm configuration for gemms calculating `output`, `grad_input` and # `grad_weight` + # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling # gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() @@ -156,26 +173,46 @@ class Float8LinearConfig: delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() def __post_init__(self): + # populate the additional cast overrides, if the user did not specify them + if self.cast_config_input_for_grad_weight is None: + self.cast_config_input_for_grad_weight = self.cast_config_input + if self.cast_config_weight_for_grad_input is None: + self.cast_config_weight_for_grad_input = self.cast_config_weight + if self.cast_config_grad_output_for_grad_weight is None: + self.cast_config_grad_output_for_grad_weight = self.cast_config_grad_output + # float8 all-gather only supports tensorwise, in the future may support blockwise if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE: assert not self.enable_fsdp_float8_all_gather, \ f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}" - # for now, axiswise granularity is all-or-nothing - # TODO(future PR): enable more granular setting per-gemm-input - has_any_axiswise_scaling = ( - self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or - self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or - self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE - ) - has_all_axiswise_scaling = ( - self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and - 
self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and - self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE - ) - if has_any_axiswise_scaling: - assert has_all_axiswise_scaling, \ - "for now, axiswise scaling must be enabled for either all casts or none of the casts" + # save some characters in the compatibility checks below + cc_i = self.cast_config_input + cc_w = self.cast_config_weight + cc_go = self.cast_config_grad_output + cc_i_gw = self.cast_config_input_for_grad_weight + cc_w_gi = self.cast_config_weight_for_grad_input + cc_go_gw = self.cast_config_grad_output_for_grad_weight + + # for now, we only have gemm kernels where both operands are scaled with the same + # granularity. In the future this may be relaxed. + assert cc_i.scaling_granularity == cc_w.scaling_granularity, \ + "incompatible scaling granularity for output" + # assert cc_go.scaling_granularity == cc_w_gi.scaling_granularity, \ + # "incompatible scaling granularity for grad_input" + assert cc_i_gw.scaling_granularity == cc_go_gw.scaling_granularity, \ + "incompatible scaling granularity for grad_weight" + + # for now, we only have gemm kernels where both operands are either both + # in high precision, or both in float8. In the future, this may be relaxed. + # TODO(future): make the float8 check more precise with the specific dtypes. + assert cc_i.keep_in_original_precision == cc_w.keep_in_original_precision, \ + "incompatible operand precision for output" + assert cc_go.keep_in_original_precision == cc_w_gi.keep_in_original_precision, \ + "incompatible operand precision for grad_input" + assert cc_i_gw.keep_in_original_precision == cc_go_gw.keep_in_original_precision, \ + "incompatible operand precision for grad_weight" + # If True, use 'fnuz' float8 types for calculations. # Currently, ROCm only supports fnuz variants. 
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 5f87e82fe..2767b7d1e 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -21,6 +21,7 @@ hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, hp_tensor_to_float8_static, + get_maybe_axiswise_dim, NoopFwToFloat8E5M2BwDelayed, NoopFwToFloat8E5M2BwDynamic, NoopFwToFloat8E5M2BwStatic, @@ -33,7 +34,13 @@ ScaledMMConfig, ) -from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax +from torchao.float8.float8_utils import ( + e4m3_dtype, + e5m2_dtype, + tensor_to_amax, + float8_linear_config_to_concise_casts_config, + Float8LinearConciseCastsConfig, +) from torchao.float8.fsdp_utils import ( WeightWithDelayedFloat8CastTensor, @@ -114,54 +121,56 @@ class manual_float8_matmul_with_args_in_hp(torch.autograd.Function): and other granularities in a separate PR. """ - # TODO(this PR): types of inputs @staticmethod def forward( ctx, input_hp: torch.Tensor, weight_hp_t: torch.Tensor, linear_mm_config: LinearMMConfig, - input_scaling_granularity: ScalingGranularity, - weight_scaling_granularity: ScalingGranularity, - grad_output_scaling_granularity: ScalingGranularity, + concise_casts_config: Float8LinearConciseCastsConfig, ): ctx.save_for_backward(input_hp, weight_hp_t) ctx.linear_mm_config = linear_mm_config - ctx.input_scaling_granularity = input_scaling_granularity - ctx.weight_scaling_granularity = weight_scaling_granularity - ctx.grad_output_scaling_granularity = grad_output_scaling_granularity - - input_fp8 = hp_tensor_to_float8_dynamic( - input_hp, - e4m3_dtype, - linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=input_scaling_granularity, - axiswise_dim=-1, - ) + ctx.concise_casts_config = concise_casts_config - weight_fp8_t = hp_tensor_to_float8_dynamic( - weight_hp_t, - e4m3_dtype, - linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=weight_scaling_granularity, - 
axiswise_dim=0, - ) + c = concise_casts_config + + if c.cc_i.orig_prec: + input_maybe_fp8 = input_hp + else: + input_maybe_fp8 = hp_tensor_to_float8_dynamic( + input_hp, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=c.cc_i.sc_gr, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_i.sc_gr), + ) + + if c.cc_w.orig_prec: + weight_maybe_fp8_t = weight_hp_t + else: + weight_maybe_fp8_t = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=c.cc_w.sc_gr, + axiswise_dim=get_maybe_axiswise_dim(0, c.cc_w.sc_gr), + ) # the reshapes are needed in order to make the shapes compatible with # torch.mm - orig_shape = input_fp8.shape - input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) - res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) + orig_shape = input_maybe_fp8.shape + input_maybe_fp8_reshaped = input_maybe_fp8.reshape(-1, orig_shape[-1]) + res_bits = torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t) res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) return res_bits @staticmethod def backward(ctx, grad_output): input_hp, weight_hp_t = ctx.saved_tensors - - # TODO scaling + c = ctx.concise_casts_config # the reshapes are needed in order to make the shapes compatible with # torch.mm @@ -174,26 +183,37 @@ def backward(ctx, grad_output): # calculate grad_input # - grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( - grad_output_reshaped, - e5m2_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=ctx.grad_output_scaling_granularity, - axiswise_dim=-1, - ) - weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( - weight_hp_t, - e4m3_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - scaling_granularity=ctx.weight_scaling_granularity, - axiswise_dim=1, # will be transposed - ) + if c.cc_go.orig_prec: + grad_output_reshaped_maybe_fp8_dim0 = 
grad_output_reshaped + else: + grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=c.cc_go.sc_gr, + axiswise_dim=get_maybe_axiswise_dim(-1, c.cc_go.sc_gr), + ) + + if c.cc_w_gi.orig_prec: + weight_t_maybe_fp8_dim0 = weight_hp_t + else: + # Note: we need https://github.com/pytorch/pytorch/issues/136267 + # to be solved to have a chance to reuse max(abs(weight, dim=...)) + # from the forward to get max(abs(weight)) here without reading + # the entire tensor. + weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=c.cc_w_gi.sc_gr, + axiswise_dim=get_maybe_axiswise_dim(1, c.cc_w_gi.sc_gr), + ) grad_input = torch.mm( - grad_output_reshaped_fp8_dim0, - weight_t_fp8_dim0.t(), + grad_output_reshaped_maybe_fp8_dim0, + weight_t_maybe_fp8_dim0.t(), ) grad_input = grad_input.reshape( *grad_output_orig_shape[:-1], grad_input.shape[-1] @@ -206,29 +226,38 @@ def backward(ctx, grad_output): # calculate grad_weight # - grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( - grad_output_reshaped, - e5m2_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.GRAD_OUTPUT, - scaling_granularity=ctx.grad_output_scaling_granularity, - axiswise_dim=0, # will be transposed - ) - input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( - input_hp_reshaped, - e4m3_dtype, - ctx.linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - scaling_granularity=ctx.input_scaling_granularity, - axiswise_dim=0, - ) + if c.cc_go_gw.orig_prec: + grad_output_reshaped_maybe_fp8_dim1 = grad_output_reshaped + else: + grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=c.cc_go_gw.sc_gr, + 
axiswise_dim=get_maybe_axiswise_dim(0, c.cc_go_gw.sc_gr), + ) + + if c.cc_i_gw.orig_prec: + input_reshaped_maybe_fp8_dim1 = input_hp_reshaped + else: + input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( + input_hp_reshaped, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=c.cc_i_gw.sc_gr, + axiswise_dim=get_maybe_axiswise_dim(0, c.cc_i_gw.sc_gr), + ) grad_weight = torch.mm( - grad_output_reshaped_fp8_dim1.t(), - input_reshaped_fp8_dim1, + grad_output_reshaped_maybe_fp8_dim1.t(), + input_reshaped_maybe_fp8_dim1, ) - return grad_input, grad_weight.t(), None, None, None, None + empty_grads = None, None + + return grad_input, grad_weight.t(), *empty_grads class Float8Linear(torch.nn.Linear): @@ -313,6 +342,9 @@ def __init__(self, *args, **kwargs): # would be initialized in every iteration. self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward + self.concise_casts_config: Float8LinearConciseCastsConfig = \ + float8_linear_config_to_concise_casts_config(self.config) + def create_buffers(self): # Default values for history buffers, see above TODO history_len = self.config.delayed_scaling_config.history_len @@ -554,9 +586,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input, self.weight.t(), self.linear_mm_config, - self.config.cast_config_input.scaling_granularity, - self.config.cast_config_weight.scaling_granularity, - self.config.cast_config_grad_output.scaling_granularity, + self.concise_casts_config, ) if self.bias is not None: @@ -574,10 +604,14 @@ def scaling_type_repr(self): def scaling_granularity_repr(self): # add scaling granularity settings without using too many characters # example: "i:ten,w:ten,g:ten" or "i:axs,w:axs,g:axs" - gi = self.config.cast_config_input.scaling_granularity.short_str() - gw = self.config.cast_config_weight.scaling_granularity.short_str() - ggo = self.config.cast_config_grad_output.scaling_granularity.short_str() - return 
f"i:{gi},w:{gw},go:{ggo}" + c = self.config + gi = c.cast_config_input.scaling_granularity.short_str() + gw = c.cast_config_weight.scaling_granularity.short_str() + ggo = c.cast_config_grad_output.scaling_granularity.short_str() + gi2 = c.cast_config_input_for_grad_weight.scaling_granularity.short_str() + gw2 = c.cast_config_weight_for_grad_input.scaling_granularity.short_str() + ggo2 = c.cast_config_grad_output_for_grad_weight.scaling_granularity.short_str() + return f"i:{gi},w:{gw},go:{ggo},i2:{gi2},w2:{gw2},go2:{ggo2}" def extra_repr(self): s = f'{super().extra_repr()}, scaling_type="{self.scaling_type_repr()}", scaling_granularity="{self.scaling_granularity_repr()}"' diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index b97d03211..8f5bc768e 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -251,6 +251,20 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): if is_row_major(b_data.stride()): b_data = b_data.t().contiguous().t() b_scale = b._scale + + # Today, torch._scaled_mm only supports both operands using the + # same granularity. The code below checks for cases where one + # operand is scaled axiswise and one tensorwise. If this case is found, + # we reshape the tensorwise scale to be repeated along the needed axis, + # so that torch._scaled_mm can call the axiswise-axiswise kernel. + # Note: using shape/size info does not work with compile here, which is + # why we are inferring the scaling type from the presence of + # axiswise_dim.
+ if a._axiswise_dim is None and b._axiswise_dim is not None: + a_scale = a_scale.repeat(a_data.shape[0]).reshape(-1, 1) + elif a._axiswise_dim is not None and b._axiswise_dim is None: + b_scale = b_scale.repeat(b_data.shape[1]).reshape(1, -1) + return a_data, a_scale, b_data, b_scale diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index f46293d61..a8ee92f28 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -141,6 +141,21 @@ def hp_tensor_to_float8_static( ) +def get_maybe_axiswise_dim( + axiswise_dim: int, + scaling_granularity: ScalingGranularity, +) -> Optional[int]: + """ + Convenience function which takes in an axiswise dim which is only relevant + for axiswise scaling, and a scaling granularity. The output is pass-through + if the scaling granularity is axiswise, and None otherwise. This is done to keep the + logic for choosing the axiswise dim out of the scaling function. + """ + if scaling_granularity is ScalingGranularity.AXISWISE: + return axiswise_dim + return None + + def _maybe_initialize_amaxes_scales_for_float8_cast( x, cur_amax, diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index e79cf27d8..665373c91 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Optional, Tuple, Union +from typing import Iterable, Literal, NamedTuple, Optional, Tuple, Union import torchao.float8.config as config @@ -258,3 +258,61 @@ def pad_tensor_for_matmul( pad_dim2 = dim2_aligned - dim2 return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) + + +# The code below introduces a bit of duplication with Float8LinearConfig in +# order to improve readability of the implementation of how Float8Linear +# uses the config.
Specifically, we do two things: +# 1. wrap the relevant parts of configs in a namedtuple, so we can pass +# them around in compile-friendly code. +# 2. make the tuple key names more brief, to make the implementation +# code less verbose (the code was so verbose that I felt the need +# to add this workaround). +# As I was writing this, it became less and less clear why we should not just have +# a namedtuple as a top level config. Punting that to a future PR as +# that might be BC-breaking, but probably worth exploring. +# Note: I also think the code below is pretty hacky, it's good enough to unblock +# further prototyping, but IMO pretty important to clean up sooner rather +# than later. + + +class ConciseCastConfig(NamedTuple): + sc_tp: config.ScalingType + sc_gr: config.ScalingGranularity + st_sc: Optional[torch.Tensor] + orig_prec: bool + + @classmethod + def from_cast_config(cls, c: config.CastConfig): + return cls( + sc_tp=c.scaling_type, + sc_gr=c.scaling_granularity, + st_sc=c.static_scale, + orig_prec=c.keep_in_original_precision, + ) + + +class Float8LinearConciseCastsConfig(NamedTuple): + cc_i: ConciseCastConfig + cc_w: ConciseCastConfig + cc_go: ConciseCastConfig + cc_i_gw: ConciseCastConfig + cc_w_gi: ConciseCastConfig + cc_go_gw: ConciseCastConfig + + +def float8_linear_config_to_concise_casts_config( + c: config.Float8LinearConfig, +) -> Float8LinearConciseCastsConfig: + concise_config = Float8LinearConciseCastsConfig( + cc_i=ConciseCastConfig.from_cast_config(c.cast_config_input), + cc_w=ConciseCastConfig.from_cast_config(c.cast_config_weight), + cc_go=ConciseCastConfig.from_cast_config(c.cast_config_grad_output), + cc_i_gw=ConciseCastConfig.from_cast_config(c.cast_config_input_for_grad_weight), + cc_w_gi=ConciseCastConfig.from_cast_config(c.cast_config_weight_for_grad_input), + cc_go_gw=ConciseCastConfig.from_cast_config( + c.cast_config_grad_output_for_grad_weight + ), + ) + + return concise_config diff --git a/torchao/testing/float8/test_utils.py
b/torchao/testing/float8/test_utils.py new file mode 100644 index 000000000..6aa5e78e9 --- /dev/null +++ b/torchao/testing/float8/test_utils.py @@ -0,0 +1,131 @@ +import torch +from torchao.float8.config import ( + ScalingGranularity, + ScalingType, + CastConfig, + Float8LinearConfig, + Float8GemmConfig, +) + +scaling_granularities_by_gemm_lcw_recipe = [ + # @lcw's recipe + # output = input @ weight_t + # input: axiswise + # weight_t: axiswise + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + # grad_input = grad_output @ weight + # grad_output: axiswise + # weight: tensorwise (but that can be computed from axiswise done in the forward) + (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE, False, False), + # grad_weight = input_t @ grad_output, in high precision (bfloat16) + # input_t: high precision + # grad_output: high precision + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, True, True), +] + +scaling_granularities_by_gemm_all_tensorwise = [ + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), + (ScalingGranularity.TENSORWISE, ScalingGranularity.TENSORWISE, False, False), +] + +scaling_granularities_by_gemm_all_axiswise = [ + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), + (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE, False, False), +] + +# scaling granularity and keep_in_original_precision to test by gemm arguments in this +# order: output, grad_input, grad_weight +scaling_granularities_by_gemm = [ + # TODO(before land): move this last + scaling_granularities_by_gemm_lcw_recipe, + # scaling_granularities_by_gemm_all_tensorwise, + # scaling_granularities_by_gemm_all_axiswise, +] + +def get_test_float8_linear_config( + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + 
scaling_granularities_by_gemm, + emulate: bool, +): + ( + (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight), + (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input), + (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight), + ) = scaling_granularities_by_gemm + + static_scale_one = torch.tensor([1.0], device="cuda") + + if scaling_type_input is ScalingType.STATIC: + static_scale_input = static_scale_one + else: + static_scale_input = None + if scaling_type_weight is ScalingType.STATIC: + static_scale_weight = static_scale_one + else: + static_scale_weight = None + if scaling_type_grad_output is ScalingType.STATIC: + static_scale_grad_output = static_scale_one + else: + static_scale_grad_output = None + + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity_input, + static_scale=static_scale_input, + keep_in_original_precision=original_prec_input, + ) + cast_config_input_for_grad_weight = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity_input_for_grad_weight, + static_scale=static_scale_input, + keep_in_original_precision=original_prec_input_for_grad_weight, + ) + + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity_weight, + static_scale=static_scale_weight, + keep_in_original_precision=original_prec_weight, + ) + cast_config_weight_for_grad_input = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity_weight_for_grad_input, + static_scale=static_scale_weight, + keep_in_original_precision=original_prec_weight_for_grad_input, + ) + + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + 
scaling_granularity=scaling_granularity_grad_output, + static_scale=static_scale_grad_output, + keep_in_original_precision=original_prec_grad_output, + ) + cast_config_grad_output_for_grad_weight = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity_grad_output_for_grad_weight, + static_scale=static_scale_grad_output, + keep_in_original_precision=original_prec_grad_output_for_grad_weight, + ) + + gemm_config_output = Float8GemmConfig(use_fast_accum=True) + # TODO(this PR): toggle fast accum by axiswise scaling presence + gemm_config_grad_input = Float8GemmConfig(use_fast_accum=True) + gemm_config_grad_weight = Float8GemmConfig(use_fast_accum=True) + + config = Float8LinearConfig( + cast_config_input=cast_config_input, + cast_config_weight=cast_config_weight, + cast_config_grad_output=cast_config_grad_output, + cast_config_input_for_grad_weight=cast_config_input_for_grad_weight, + cast_config_weight_for_grad_input=cast_config_weight_for_grad_input, + cast_config_grad_output_for_grad_weight=cast_config_grad_output_for_grad_weight, + gemm_config_output=gemm_config_output, + gemm_config_grad_input=gemm_config_grad_input, + gemm_config_grad_weight=gemm_config_grad_weight, + emulate=emulate, + ) + return config