From e76db70ec14ff1ff6fc9f1944c904d4247c05de9 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 7 Oct 2024 13:31:48 -0700 Subject: [PATCH] add axiswise scaling to Float8Linear (#920) Summary: This PR: support scaling of all arguments of all gemms to be axiswise, and ensure that training with axiswise scaling works e2e. Future PR: support more granular configurability and optimize performance, add docs Feel free to ignore the UX introduced in this PR, it's just an intermediate step. See next PR for the real UX. Test Plan: ``` // tests pass ./test/float8/test_everything.sh // sanity check on torchtitan with LLaMa 3 8B on 4 H100s with float8: // 1. verify performance does not regress with tensorwise scaling // 2. smoke test that axiswise scaling works and numerics are sane, performance isn't there though // logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f ``` Reviewers: Subscribers: Tasks: Tags: --- benchmarks/float8/bench_linear_float8.py | 32 +++- benchmarks/float8/bench_matmul.py | 13 +- benchmarks/float8/profile_linear_float8.py | 27 ++- test/float8/test_base.py | 33 +++- test/float8/test_compile.py | 105 +++++++++- test/float8/test_numerics_integration.py | 39 +++- torchao/float8/config.py | 33 +++- torchao/float8/float8_linear.py | 213 ++++++++++++++++++--- torchao/float8/float8_ops.py | 22 ++- 9 files changed, 462 insertions(+), 55 deletions(-) diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index e18006f0e..f92303c62 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -14,7 +14,12 @@ import torch import torch.utils.benchmark as benchmark -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( linear_requires_sync, @@ -107,6 +112,7 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", + scaling_granularity: str = "tensorwise", ): device = "cuda" print(f"Compile is set to | {compile}") @@ -114,28 +120,41 @@ def main( scaling_type_input = ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) + scaling_granularity = ScalingGranularity(scaling_granularity) if scaling_type_input is ScalingType.STATIC: cast_config_input=CastConfig( scaling_type=scaling_type_input, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_input=CastConfig(scaling_type=scaling_type_input) + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight=CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_weight=CastConfig(scaling_type=scaling_type_weight) + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output=CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -167,7 +186,7 @@ def main( copy.deepcopy(linear_ref), config=config, ) - scaling_repr = linear_float8.scaling_repr() + scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}" if fast_accum: linear_float8.forward_config = ScaledMMConfig(False, True, False) @@ -310,6 +329,7 @@ def invoke_main() -> None: parser.add_argument("--scaling_type_input", type=str, required=False) parser.add_argument("--scaling_type_weight", type=str, required=False) parser.add_argument("--scaling_type_grad_output", type=str, required=False) + parser.add_argument("--scaling_granularity", type=str, required=False) args = parser.parse_args() output_path = Path(args.output_path) if args.output_path is not None else None kwargs = {} @@ -327,6 +347,8 @@ def invoke_main() -> None: kwargs["scaling_type_weight"] = args.scaling_type_weight if args.scaling_type_grad_output is not None: kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output + if args.scaling_granularity is not None: + kwargs["scaling_granularity"] = args.scaling_granularity main( output_path, not args.disable_compile, diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index 6b816300c..e969846b2 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -13,6 +13,8 @@ import torch.nn as nn import torch.utils.benchmark as benchmark +from torchao.float8.config import ScalingGranularity + from utils import ( get_name_to_shapes_iter, profiler_output_to_filtered_time_by_kernel_name, @@ -75,6 +77,7 @@ def run( K: Optional[int] = None, N: Optional[int] = None, use_gpu_kernel_time: bool = False, + scaling_granularity: str = "tensorwise", ): device = "cuda" @@ -84,6 +87,7 @@ def run( dtype = torch.bfloat16 name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N) fast_accum_vals = [True, False] + scaling_granularity = ScalingGranularity(scaling_granularity) for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)): if n_limit is not None and idx >= n_limit: @@ -109,8 +113,13 @@ def run( d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype A = torch.zeros(M, K, device=device, dtype=d1) B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t() - scale_a = torch.tensor([1.0], device=device) - scale_b = torch.tensor([1.0], device=device) + if scaling_granularity == ScalingGranularity.TENSORWISE: + scale_a = torch.tensor([1.0], device=device) + scale_b = torch.tensor([1.0], device=device) + else: + assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported" + scale_a = torch.ones(M, 1, device=device) + scale_b = torch.ones(1, N, device=device) def do_matmul(A, B): nonlocal scale_a diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index c204d49b0..6afefa009 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -22,7 +22,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -252,6 +257,7 @@ def main( scaling_type_input: str = "dynamic", scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", + scaling_granularity: str = "tensorwise", model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, @@ -263,28 +269,41 @@ def main( scaling_type_input = ScalingType(scaling_type_input) scaling_type_weight = ScalingType(scaling_type_weight) scaling_type_grad_output = ScalingType(scaling_type_grad_output) + scaling_granularity = ScalingGranularity(scaling_granularity) if scaling_type_input is ScalingType.STATIC: cast_config_input=CastConfig( scaling_type=scaling_type_input, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_input=CastConfig(scaling_type=scaling_type_input) + cast_config_input=CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight=CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_weight=CastConfig(scaling_type=scaling_type_weight) + cast_config_weight=CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output=CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), + scaling_granularity=scaling_granularity, ) else: - cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output=CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index a14981732..0aab91f55 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -324,6 +324,10 @@ def _test_linear_impl( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) + @pytest.mark.parametrize( + "scaling_granularity", + [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -334,33 +338,56 @@ def test_linear( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, linear_dtype: torch.dtype, linear_bias: bool, ): + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + linear_dtype != torch.bfloat16 or + (not is_cuda_9_0) + ): + pytest.skip() + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 5106bd778..317743288 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -18,7 +18,12 @@ import torch import torch.nn as nn -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -39,6 +44,7 @@ from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend +# TODO(future PR): standardize IS_H100 with the rest of the codebase is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) @@ -67,6 +73,8 @@ def _test_compile_base( y_fp8.sum().backward() y_ref = m_ref(x) y_ref.sum().backward() + # TODO(future PR): can also test fp8 eager vs compile here with a tigher + # tolerance torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2) torch.testing.assert_close( m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1 @@ -77,29 +85,42 @@ def _get_config( scaling_type_input, scaling_type_weight, scaling_type_grad_output, + scaling_granularity, emulate, ): if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -110,6 +131,25 @@ def _get_config( return config +def is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, +) -> bool: + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + dtype != torch.bfloat16 or + (not is_H100) + ): + return False + return True + + @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] @@ -120,6 +160,9 @@ def _get_config( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -129,11 +172,25 @@ def test_eager_only( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "eager", @@ -154,6 +211,9 @@ def test_eager_only( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( @@ -162,11 +222,25 @@ def test_aot_eager( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "aot_eager", @@ -187,6 +261,9 @@ def test_aot_eager( @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) +@pytest.mark.parametrize( + "scaling_granularity", [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE] +) @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_inductor( @@ -195,11 +272,25 @@ def test_inductor( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, dtype: torch.dtype, ): + if not is_supported( + scaling_granularity, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + dtype, + ): + pytest.skip() + torch._dynamo.reset() config = _get_config( - scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, + scaling_granularity, + emulate, ) _test_compile_base( "inductor", diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 6db05dc56..07fcddaad 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -19,7 +19,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import ( + CastConfig, + Float8LinearConfig, + ScalingType, + ScalingGranularity, +) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, linear_requires_sync, @@ -28,6 +33,7 @@ from torchao.float8.float8_utils import compute_error, IS_ROCM is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) +is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) torch.manual_seed(0) @@ -90,6 +96,10 @@ class TestFloat8NumericsIntegrationTest: "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) + @pytest.mark.parametrize( + "scaling_granularity", + [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE], + ) @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw( @@ -97,10 +107,21 @@ def test_encoder_fw_bw( scaling_type_input: ScalingType, scaling_type_weight: ScalingType, scaling_type_grad_output: ScalingType, + scaling_granularity: ScalingGranularity, ): # TODO(later): maybe add float16 back if it becomes important data_dtype = torch.bfloat16 + if scaling_granularity is ScalingGranularity.AXISWISE: + if ( + scaling_type_input != ScalingType.DYNAMIC or + scaling_type_weight != ScalingType.DYNAMIC or + scaling_type_grad_output != ScalingType.DYNAMIC or + data_dtype != torch.bfloat16 or + (not is_cuda_9_0) + ): + pytest.skip() + # LLaMa 3 70B shapes model_ref = ( FeedForward( @@ -119,24 +140,34 @@ def test_encoder_fw_bw( if scaling_type_input is ScalingType.STATIC: cast_config_input = CastConfig( scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_input = CastConfig(scaling_type=scaling_type_input) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) if scaling_type_grad_output is ScalingType.STATIC: cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), ) else: - cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 16e638738..0ed6d2622 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -37,6 +37,13 @@ class ScalingGranularity(enum.Enum): # size 1. AXISWISE = "axiswise" + def short_str(self): + if self is ScalingGranularity.TENSORWISE: + return "ten" + else: + assert self is ScalingGranularity.AXISWISE + return "axs" + @dataclass(frozen=True) class CastConfig: @@ -45,6 +52,7 @@ class CastConfig: """ scaling_type: ScalingType = ScalingType.DYNAMIC + scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None def __post_init__(self): @@ -52,7 +60,9 @@ def __post_init__(self): assert ( self.static_scale is not None ), "static_scale must be specified for static scaling" - + if self.scaling_granularity is ScalingGranularity.AXISWISE: + assert self.scaling_type is ScalingType.DYNAMIC, \ + "only dynamic scaling type is supported for axiswise scaling granularity" @dataclass(frozen=True) class DelayedScalingConfig: @@ -163,6 +173,27 @@ class Float8LinearConfig: force_recompute_fp8_weight_in_bwd: bool = False + def __post_init__(self): + # float8 all-gather only supports tensorwise, in the future may support blockwise + if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE: + assert not self.enable_fsdp_float8_all_gather, \ + f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}" + + # for now, axiswise granularity is all-or-nothing + # TODO(future PR): enable more granular setting per-gemm-input + has_any_axiswise_scaling = ( + self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or + self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or + self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + has_all_axiswise_scaling = ( + self.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and + self.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and + self.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + if has_any_axiswise_scaling: + assert has_all_axiswise_scaling, \ + "for now, axiswise scaling must be enabled for either all casts or none of the casts" # If True, use 'fnuz' float8 types for calculations. # Currently, ROCm only supports fnuz variants. diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 4558695e3..9aaffa99c 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -16,7 +16,7 @@ import torch.utils.checkpoint as checkpoint -from torchao.float8.config import Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig, ScalingType, ScalingGranularity from torchao.float8.float8_scaling_utils import ( _maybe_initialize_amaxes_scales_for_float8_cast, @@ -50,11 +50,17 @@ ) -# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files @torch._dynamo.allow_in_graph -class manual_float8_matmul(torch.autograd.Function): +class manual_float8_matmul_with_args_in_float8(torch.autograd.Function): """ Like torch.matmul, but with the arguments in float8 + + Note: this function requires all arguments to already be Float8Tensor objects, + which only supports tensorwise scaling granularity. The reason we didn't just make this + function support axiswise scaling granularity is because that would need very + careful testing of delayed scaling, as delayed scaling modifies buffers inplace. + + In the future we'll probably have to unify, just postponing that until a future PR. """ @staticmethod @@ -105,6 +111,133 @@ def backward(ctx, grad_output_fp8): return grad_input, grad_weight.t() +@torch._dynamo.allow_in_graph +class manual_float8_matmul_with_args_in_hp(torch.autograd.Function): + """ + Like torch.matmul, but with the arguments in high precision and the cast to float8 + defined inside of this function. + + Note: this function currently only supports dynamic scaling type and + axiswise granularity. We will have to unify this with other scaling types + and other granularities in a separate PR. + """ + + # TODO(this PR): types of inputs + @staticmethod + def forward( + ctx, + input_hp: torch.Tensor, + weight_hp_t: torch.Tensor, + linear_mm_config: LinearMMConfig, + input_scaling_granularity: ScalingGranularity, + weight_scaling_granularity: ScalingGranularity, + grad_output_scaling_granularity: ScalingGranularity, + ): + ctx.save_for_backward(input_hp, weight_hp_t) + ctx.linear_mm_config = linear_mm_config + ctx.input_scaling_granularity = input_scaling_granularity + ctx.weight_scaling_granularity = weight_scaling_granularity + ctx.grad_output_scaling_granularity = grad_output_scaling_granularity + + input_fp8 = hp_tensor_to_float8_dynamic( + input_hp, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=input_scaling_granularity, + axiswise_dim=-1, + ) + + weight_fp8_t = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=weight_scaling_granularity, + axiswise_dim=0, + ) + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + orig_shape = input_fp8.shape + input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) + res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) + res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) + return res_bits + + @staticmethod + def backward(ctx, grad_output): + input_hp, weight_hp_t = ctx.saved_tensors + + # TODO scaling + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + grad_output_orig_shape = grad_output.shape + grad_output_reshaped = grad_output.reshape( + -1, grad_output_orig_shape[-1] + ) + + # + # calculate grad_input + # + + grad_output_reshaped_fp8_dim0 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=ctx.grad_output_scaling_granularity, + axiswise_dim=-1, + ) + weight_t_fp8_dim0 = hp_tensor_to_float8_dynamic( + weight_hp_t, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=ctx.weight_scaling_granularity, + axiswise_dim=-1, # will be transposed + ) + + grad_input = torch.mm( + grad_output_reshaped_fp8_dim0, + weight_t_fp8_dim0.t(), + ) + grad_input = grad_input.reshape( + *grad_output_orig_shape[:-1], grad_input.shape[-1] + ) + + input_hp_orig_shape = input_hp.shape + input_hp_reshaped = input_hp.reshape(-1, input_hp_orig_shape[-1]) + + # + # calculate grad_weight + # + + grad_output_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + grad_output_reshaped, + e5m2_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + scaling_granularity=ctx.grad_output_scaling_granularity, + axiswise_dim=0, # will be transposed + ) + input_reshaped_fp8_dim1 = hp_tensor_to_float8_dynamic( + input_hp_reshaped, + e4m3_dtype, + ctx.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=ctx.input_scaling_granularity, + axiswise_dim=0, + ) + + grad_weight = torch.mm( + grad_output_reshaped_fp8_dim1.t(), + input_reshaped_fp8_dim1, + ) + + return grad_input, grad_weight.t(), None, None, None, None + class Float8Linear(torch.nn.Linear): """ @@ -297,7 +430,10 @@ def cast_input_to_float8( ) elif self.scaling_type_input is ScalingType.DYNAMIC: input_fp8 = hp_tensor_to_float8_dynamic( - input, e4m3_dtype, self.linear_mm_config + input, + e4m3_dtype, + self.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, ) else: assert self.scaling_type_input is ScalingType.STATIC @@ -395,29 +531,48 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.has_any_delayed_scaling: self.float8_pre_forward(input) - input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) + # TODO(this PR): reuse with config, make a property + has_all_axiswise_scaling = ( + self.config.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE and + self.config.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE and + self.config.cast_config_grad_output.scaling_granularity is ScalingGranularity.AXISWISE + ) + + if not has_all_axiswise_scaling: + input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) + # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight, + # weight_scale should be saved. + weight_scale = self.get_weight_scale(self.weight) + + if self.config.force_recompute_fp8_weight_in_bwd: + weight_fp8_t = checkpoint.checkpoint( + self.cast_weight_to_float8_t, + self.weight, + self.is_amax_initialized, + weight_scale, + ) + else: + weight_fp8_t = self.cast_weight_to_float8_t( + self.weight, self.is_amax_initialized, weight_scale + ) + + output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t) - # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight, - # weight_scale should be saved. - weight_scale = self.get_weight_scale(self.weight) + # Cast grad_output to float8_e5m2 during backward + output = self.cast_output_to_float8_in_bw(output) - if self.config.force_recompute_fp8_weight_in_bwd: - weight_fp8_t = checkpoint.checkpoint( - self.cast_weight_to_float8_t, - self.weight, - self.is_amax_initialized, - weight_scale, - ) else: - weight_fp8_t = self.cast_weight_to_float8_t( - self.weight, self.is_amax_initialized, weight_scale + # for now, axiswise path is separate + # TODO(future PR): unify to support mix and match + output = manual_float8_matmul_with_args_in_hp.apply( + input, + self.weight.t(), + self.linear_mm_config, + self.config.cast_config_input.scaling_granularity, + self.config.cast_config_weight.scaling_granularity, + self.config.cast_config_grad_output.scaling_granularity, ) - output = manual_float8_matmul.apply(input_fp8, weight_fp8_t) - - # Cast grad_output to float8_e5m2 during backward - output = self.cast_output_to_float8_in_bw(output) - if self.bias is not None: output = output + self.bias.to(output.dtype) @@ -425,13 +580,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.float8_post_forward() return output - def scaling_repr(self): - # add scaling settings without using too many characters + def scaling_type_repr(self): + # add scaling type settings without using too many characters # example: "i:del,w:del,go:dyn" return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}" + def scaling_granularity_repr(self): + # add scaling granularity settings without using too many characters + # example: "i:ten,w:ten,g:ten" or "i:axs,w:axs,g:axs" + gi = self.config.cast_config_input.scaling_granularity.short_str() + gw = self.config.cast_config_weight.scaling_granularity.short_str() + ggo = self.config.cast_config_grad_output.scaling_granularity.short_str() + return f"i:{gi},w:{gw},go:{ggo}" + def extra_repr(self): - s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"' + s = f'{super().extra_repr()}, scaling_type="{self.scaling_type_repr()}", scaling_granularity="{self.scaling_granularity_repr()}"' return s @classmethod diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 1bf9faaa4..b97d03211 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -43,12 +43,9 @@ def decorator(func): [ aten.view.default, aten._unsafe_view.default, - aten.t.default, aten.as_strided.default, aten.clone.default, - aten.detach.default, aten.slice.Tensor, - aten.transpose.int, aten.fill_.Scalar, aten.reshape.default, ] @@ -65,13 +62,30 @@ def float8_desugar_op(aten_op, args, kwargs=None): ) +@implements( + [ + aten.detach.default, + ] +) +def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + return Float8Tensor( + new_data, + new_scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + ) + + @implements( [ aten.t.default, aten.transpose.int, ] ) -def float8_desugar_data_and_scale(aten_op, args, kwargs=None): +def float8_transpose(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)