From ef56618e85333b19ee4824d9570f02e5e8cd84d8 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 23 Sep 2024 12:25:42 -0700 Subject: [PATCH] add axiswise scaling to Float8Linear Summary: This PR: support scaling of all arguments of all gemms to be axiswise, and ensure that training with axiswise scaling works e2e. Future PR: support more granular configurability and optimize performance, add docs Test Plan: ``` // tests pass ./test/float8/test_everything.sh // sanity check on torchtitan with LLaMa 3 8B on 4 H100s with float8: // 1. verify performance does not regress with tensorwise scaling // 2. smoke test that axiswise scaling works and numerics are sane, performance isn't there though // logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: af334fd3f9f0b10e2f0a7cf1e38513741d1b45f7 ghstack-comment-id: 2368837904 Pull Request resolved: https://github.com/pytorch/ao/pull/920 --- benchmarks/float8/bench_linear_float8.py | 32 +++- benchmarks/float8/bench_matmul.py | 13 +- benchmarks/float8/profile_linear_float8.py | 27 ++- test/float8/test_base.py | 33 +++- test/float8/test_compile.py | 105 +++++++++++- test/float8/test_numerics_integration.py | 39 ++++- torchao/float8/config.py | 32 ++++ torchao/float8/float8_linear.py | 188 +++++++++++++++++++-- torchao/float8/float8_ops.py | 22 ++- 9 files changed, 450 insertions(+), 41 deletions(-) diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index e18006f0e..f92303c62 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -14,7 +14,12 @@ import torch import torch.utils.benchmark as benchmark -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( linear_requires_sync, @@ -107,6 +112,7 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", + scaling_granularity: str = "tensorwise", ): device = "cuda" print(f"Compile is set to | {compile}") @@ -114,28 +120,41 @@ def main( scaling_type_input = ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) + scaling_granularity = ScalingGranularity(scaling_granularity) if scaling_type_input is ScalingType.STATIC: cast_config_input=CastConfig( scaling_type=scaling_type_input, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_input=CastConfig(scaling_type=scaling_type_input) + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight=CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_weight=CastConfig(scaling_type=scaling_type_weight) + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output=CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -167,7 +186,7 @@ def main( copy.deepcopy(linear_ref), config=config, ) - scaling_repr = linear_float8.scaling_repr() + scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}" if fast_accum: linear_float8.forward_config = ScaledMMConfig(False, True, False) @@ -310,6 +329,7 @@ def invoke_main() -> None: parser.add_argument("--scaling_type_input", type=str, required=False) parser.add_argument("--scaling_type_weight", type=str, required=False) parser.add_argument("--scaling_type_grad_output", type=str, required=False) + parser.add_argument("--scaling_granularity", type=str, required=False) args = parser.parse_args() output_path = Path(args.output_path) if args.output_path is not None else None kwargs = {} @@ -327,6 +347,8 @@ def invoke_main() -> None: kwargs["scaling_type_weight"] = args.scaling_type_weight if args.scaling_type_grad_output is not None: kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output + if args.scaling_granularity is not None: + kwargs["scaling_granularity"] = args.scaling_granularity main( output_path, not args.disable_compile, diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index 6b816300c..e969846b2 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -13,6 +13,8 @@ import torch.nn as nn import torch.utils.benchmark as benchmark +from torchao.float8.config import ScalingGranularity + from utils import ( get_name_to_shapes_iter, profiler_output_to_filtered_time_by_kernel_name, @@ -75,6 +77,7 @@ def run( K: Optional[int] = None, N: Optional[int] = None, use_gpu_kernel_time: bool = False, + scaling_granularity: str = "tensorwise", ): device = "cuda" @@ -84,6 +87,7 @@ def run( dtype = torch.bfloat16 name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N) fast_accum_vals = [True, False] + scaling_granularity = ScalingGranularity(scaling_granularity) for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)): if n_limit is not None and idx >= n_limit: @@ -109,8 +113,13 @@ def run( d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype A = torch.zeros(M, K, device=device, dtype=d1) B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t() - scale_a = torch.tensor([1.0], device=device) - scale_b = torch.tensor([1.0], device=device) + if scaling_granularity == ScalingGranularity.TENSORWISE: + scale_a = torch.tensor([1.0], device=device) + scale_b = torch.tensor([1.0], device=device) + else: + assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported" + scale_a = torch.ones(M, 1, device=device) + scale_b = torch.ones(1, N, device=device) def do_matmul(A, B): nonlocal scale_a diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index c204d49b0..6afefa009 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -22,7 +22,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -252,6 +257,7 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", + scaling_granularity: str = "tensorwise", model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, @@ -263,28 +269,41 @@ def main( scaling_type_input = ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) + scaling_granularity = ScalingGranularity(scaling_granularity) if scaling_type_input is ScalingType.STATIC: cast_config_input=CastConfig( scaling_type=scaling_type_input, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_input=CastConfig(scaling_type=scaling_type_input) + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight=CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_weight=CastConfig(scaling_type=scaling_type_weight) + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output=CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index ebc33f037..f0e0ac0a9 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -330,6 +330,10 @@ def _test_linear_impl( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) + @pytest.mark.parametrize( + "scaling_granularity", + [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -340,6 +344,7 @@ def test_linear( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, linear_dtype: torch.dtype, linear_bias: bool, ): @@ -352,30 +357,52 @@ def test_linear( f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" ) pytest.skip() + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + linear_dtype != torch.bfloat16 or + (not is_cuda_9_0) + ): + pytest.skip() + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 8a0458bec..74cc6faa5 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -18,7 +18,12 @@ import torch import torch.nn as nn -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -32,6 +37,7 @@ from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend +# TODO(future PR): standardize IS_H100 with the rest of the codebase is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) @@ -60,6 +66,8 @@ def _test_compile_base( y_fp8.sum().backward() y_ref = m_ref(x) y_ref.sum().backward() + # TODO(future PR): can also test fp8 eager vs compile here with a tigher + # tolerance torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2) torch.testing.assert_close( m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1 @@ -70,29 +78,42 @@ def _get_config( scaling_type_input, scaling_type_weight, scaling_type_grad_output, + scaling_granularity, emulate, ): if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -103,6 +124,25 @@ def _get_config( return config +def is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, +) -> bool: + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + dtype != torch.bfloat16 or + (not IS_H100) + ): + return False + return True + + @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] @@ -113,6 +153,9 @@ def _get_config( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -122,11 +165,25 @@ def test_eager_only( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "eager", @@ -147,6 +204,9 @@ def test_eager_only( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( @@ -155,11 +215,25 @@ def test_aot_eager( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "aot_eager", @@ -180,6 +254,9 @@ def test_aot_eager( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_inductor( @@ -188,11 +265,25 @@ def test_inductor( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "inductor", diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 6db05dc56..07fcddaad 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -19,7 +19,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -28,6 +33,7 @@ from torchao.float8.float8_utils import compute_error, IS_ROCM is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) +is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) torch.manual_seed(0) @@ -90,6 +96,10 @@ class TestFloat8NumericsIntegrationTest: "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) + @pytest.mark.parametrize( + "scaling_granularity", + [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + ) @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw( @@ -97,10 +107,21 @@ def test_encoder_fw_bw( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, ): # TODO(later): maybe add float16 back if it becomes important data_dtype = torch.bfloat16 + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + data_dtype != torch.bfloat16 or + (not is_cuda_9_0) + ): + pytest.skip() + # LLaMa 3 70B shapes model_ref = ( FeedForward( @@ -119,24 +140,34 @@ def test_encoder_fw_bw( if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/torchao/float8/config.py b/torchao/float8/config.py index b24b5ba74..4d82bd111 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -37,6 +37,13 @@ class ScalingGranularity(enum.Enum): # size 1. AXISWISE = "axiswise" + def short_str(self): + if self is ScalingGranularity.TENSORWISE: + return "ten" + else: + assert self is ScalingGranularity.AXISWISE + return "axs" + @dataclass(frozen=True) class CastConfig: @@ -45,12 +52,16 @@ class CastConfig: """ scaling_type: ScalingType = ScalingType.DYNAMIC + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None def __post_init__(self): if self.scaling_type is ScalingType.STATIC: assert self.static_scale is not None, \ "static_scale must be specified for static scaling" + if self.scaling_granularity is ScalingGranularity.AXISWISE: + assert self.scaling_type is ScalingType.DYNAMIC, \ + "only dynamic scaling type is supported for axiswise scaling granularity" @dataclass(frozen=True) class DelayedScalingConfig: @@ -144,6 +155,27 @@ class Float8LinearConfig: # configuration, this field may move to per-tensor configs. delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() + def __post_init__(self): + # float8 all-gather only supports tensorwise, in the future may support blockwise + if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE: + assert not self.enable_fsdp_float8_all_gather, \ + f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}" + + # for now, axiswise granularity is all-or-nothing + # TODO(future PR): enable more granular setting per-gemm-input + has_any_axiswise_scaling = ( + self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or + self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or + self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + has_all_axiswise_scaling = ( + self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and + self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and + self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + if has_any_axiswise_scaling: + assert has_all_axiswise_scaling, \ + "for now, axiswise scaling must be enabled for either all casts or none of the casts" # If True, use 'fnuz' float8 types for calculations. # Currently, ROCm only supports fnuz variants. diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index cb0ff7afb..5f87e82fe 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -14,7 +14,7 @@ import torch -from torchao.float8.config import Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig, ScalingType, ScalingGranularity from torchao.float8.float8_scaling_utils import ( _maybe_initialize_amaxes_scales_for_float8_cast, @@ -42,11 +42,17 @@ ) -# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files @torch._dynamo.allow_in_graph -class manual_float8_matmul(torch.autograd.Function): +class manual_float8_matmul_with_args_in_float8(torch.autograd.Function): """ Like torch.matmul, but with the arguments in float8 + + Note: this function requires all arguments to already be Float8Tensor objects, + which only supports tensorwise scaling granularity. The reason we didn't just make this + function support axiswise scaling granularity is because that would need very + careful testing of delayed scaling, as delayed scaling modifies buffers inplace. + + In the future we'll probably have to unify, just postponing that until a future PR. """ @staticmethod @@ -97,6 +103,133 @@ def backward(ctx, grad_output_fp8): return grad_input, grad_weight.t() +@torch._dynamo.allow_in_graph +class manual_float8_matmul_with_args_in_hp(torch.autograd.Function): + """ + Like torch.matmul, but with the arguments in high precision and the cast to float8 + defined inside of this function. + + Note: this function currently only supports dynamic scaling type and + axiswise granularity. We will have to unify this with other scaling types + and other granularities in a separate PR. + """ + + # TODO(this PR): types of inputs + @staticmethod + def forward( + ctx, + input_hp: torch.Tensor, + weight_hp_t: torch.Tensor, + linear_mm_config: LinearMMConfig, + input_scaling_granularity: ScalingGranularity, + weight_scaling_granularity: ScalingGranularity, + grad_output_scaling_granularity: ScalingGranularity, + ): + ctx.save_for_backward(input_hp, weight_hp_t) + ctx.linear_mm_config = linear_mm_config + ctx.input_scaling_granularity = input_scaling_granularity + ctx.weight_scaling_granularity = weight_scaling_granularity + ctx.grad_output_scaling_granularity = grad_output_scaling_granularity + + input_fp8 = hp_tensor_to_float8_dynamic( + input_hp, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=input_scaling_granularity, + axiswise_dim=-1, + ) + + weight_fp8_t = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=weight_scaling_granularity, + axiswise_dim=0, + ) + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + orig_shape = input_fp8.shape + input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) + res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) + res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) + return res_bits + + @staticmethod + def backward(ctx, grad_output): + input_hp, weight_hp_t = ctx.saved_tensors + + # TODO scaling + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + grad_output_orig_shape = grad_output.shape + grad_output_reshaped = grad_output.reshape( + -1, grad_output_orig_shape[-1] + ) + + # + # calculate grad_input + # + + grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=ctx.grad_output_scaling_granularity, + axiswise_dim=-1, + ) + weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=ctx.weight_scaling_granularity, + axiswise_dim=1, # will be transposed + ) + + grad_input = torch.mm( + grad_output_reshaped_fp8_dim0, + weight_t_fp8_dim0.t(), + ) + grad_input = grad_input.reshape( + *grad_output_orig_shape[:-1], grad_input.shape[-1] + ) + + input_hp_orig_shape = input_hp.shape + input_hp_reshaped = input_hp.reshape(-1, input_hp_orig_shape[-1]) + + # + # calculate grad_weight + # + + grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=ctx.grad_output_scaling_granularity, + axiswise_dim=0, # will be transposed + ) + input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + input_hp_reshaped, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ctx.input_scaling_granularity, + axiswise_dim=0, + ) + + grad_weight = torch.mm( + grad_output_reshaped_fp8_dim1.t(), + input_reshaped_fp8_dim1, + ) + + return grad_input, grad_weight.t(), None, None, None, None + class Float8Linear(torch.nn.Linear): """ @@ -289,7 +422,10 @@ def cast_input_to_float8( ) elif self.scaling_type_input is ScalingType.DYNAMIC: input_fp8 = hp_tensor_to_float8_dynamic( - input, e4m3_dtype, self.linear_mm_config + input, + e4m3_dtype, + self.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, ) else: assert self.scaling_type_input is ScalingType.STATIC @@ -395,13 +531,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.has_any_delayed_scaling: self.float8_pre_forward(input) - input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) - weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized) + # TODO(this PR): reuse with config, make a property + has_all_axiswise_scaling = ( + self.config.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and + self.config.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and + self.config.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + + if not has_all_axiswise_scaling: + input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) + weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized) - output = manual_float8_matmul.apply(input_fp8, weight_fp8.t()) + output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8.t()) - # Cast grad_output to float8_e5m2 during backward - output = self.cast_output_to_float8_in_bw(output) + # Cast grad_output to float8_e5m2 during backward + output = self.cast_output_to_float8_in_bw(output) + + else: + # for now, axiswise path is separate + # TODO(future PR): unify to support mix and match + output = manual_float8_matmul_with_args_in_hp.apply( + input, + self.weight.t(), + self.linear_mm_config, + self.config.cast_config_input.scaling_granularity, + self.config.cast_config_weight.scaling_granularity, + self.config.cast_config_grad_output.scaling_granularity, + ) if self.bias is not None: output = output + self.bias.to(output.dtype) @@ -410,13 +566,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.float8_post_forward() return output - def scaling_repr(self): - # add scaling settings without using too many characters + def scaling_type_repr(self): + # add scaling type settings without using too many characters # example: "i:del,w:del,go:dyn" return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}" + def scaling_granularity_repr(self): + # add scaling granularity settings without using too many characters + # example: "i:ten,w:ten,g:ten" or "i:axs,w:axs,g:axs" + gi = self.config.cast_config_input.scaling_granularity.short_str() + gw = self.config.cast_config_weight.scaling_granularity.short_str() + ggo = self.config.cast_config_grad_output.scaling_granularity.short_str() + return f"i:{gi},w:{gw},go:{ggo}" + def extra_repr(self): - s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"' + s = f'{super().extra_repr()}, scaling_type="{self.scaling_type_repr()}", scaling_granularity="{self.scaling_granularity_repr()}"' return s @classmethod diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 1bf9faaa4..b97d03211 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -43,12 +43,9 @@ def decorator(func): [ aten.view.default, aten._unsafe_view.default, - aten.t.default, aten.as_strided.default, aten.clone.default, - aten.detach.default, aten.slice.Tensor, - aten.transpose.int, aten.fill_.Scalar, aten.reshape.default, ] @@ -65,13 +62,30 @@ def float8_desugar_op(aten_op, args, kwargs=None): ) +@implements( + [ + aten.detach.default, + ] +) +def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + return Float8Tensor( + new_data, + new_scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + @implements( [ aten.t.default, aten.transpose.int, ] ) -def float8_desugar_data_and_scale(aten_op, args, kwargs=None): +def float8_transpose(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)