From ec6418283740c04f823229ed5eb79fceb2c0d30a Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 18 Dec 2024 08:41:31 -0800 Subject: [PATCH] [float8nocompile] Simplified Float8Linear implementation which only supports dynamic tensorwise scaling (#1429) * float8nocompile: add simplified implementation of float8linear which only supports dynamic tensorwise scaling * address comments --------- Co-authored-by: Daniel Vega-Myhre --- torchao/prototype/float8nocompile/.gitignore | 1 - .../float8nocompile/examples/example.py | 33 +++++ .../float8nocompile/float8nocompile_linear.py | 128 +++++++++++++++++- .../float8nocompile_linear_utils.py | 12 +- .../float8nocompile_scaling_utils.py | 61 +++++++++ 5 files changed, 224 insertions(+), 11 deletions(-) create mode 100644 torchao/prototype/float8nocompile/examples/example.py create mode 100644 torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py diff --git a/torchao/prototype/float8nocompile/.gitignore b/torchao/prototype/float8nocompile/.gitignore index b7bd8d0f1..174f6de30 100644 --- a/torchao/prototype/float8nocompile/.gitignore +++ b/torchao/prototype/float8nocompile/.gitignore @@ -1,2 +1 @@ -examples/ kernels/ diff --git a/torchao/prototype/float8nocompile/examples/example.py b/torchao/prototype/float8nocompile/examples/example.py new file mode 100644 index 000000000..7da145c4d --- /dev/null +++ b/torchao/prototype/float8nocompile/examples/example.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn + +from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( + convert_to_float8_nocompile_training, +) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + +if not TORCH_VERSION_AT_LEAST_2_5: + raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") + +# create model and sample input +m = ( + nn.Sequential( + nn.Linear(32, 32), + ) + .bfloat16() + .cuda() +) +x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16) +optimizer = torch.optim.SGD(m.parameters(), lr=0.1) + +# convert specified `torch.nn.Linear` modules to `Float8Linear` +print("calling convert_to_float8_nocompile_training") +convert_to_float8_nocompile_training(m) +print("finished convert_to_float8_nocompile_training") + +for i in range(10): + print(f"step {i}") + optimizer.zero_grad() + y = m(x) + y.sum().backward() + optimizer.step() diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear.py b/torchao/prototype/float8nocompile/float8nocompile_linear.py index 2290fa207..d172a0e53 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear.py @@ -5,11 +5,28 @@ # LICENSE file in the root directory of this source tree. """ A simple module swap UX for a float8 version of `torch.nn.Linear` which -does not require `torch.compile` to be performant.. +does not require `torch.compile` to be performant. 
""" +from typing import Optional import torch +from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType +from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 +from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_float8 +from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic +from torchao.float8.float8_tensor import ( + GemmInputRole, + hp_tensor_and_scale_to_float8, + LinearMMConfig, + ScaledMMConfig, +) +from torchao.float8.float8_utils import tensor_to_scale + +from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import ( + hp_tensor_to_float8nocompile_dynamic, +) + class Float8LinearNoCompile(torch.nn.Linear): """ @@ -19,4 +36,111 @@ class Float8LinearNoCompile(torch.nn.Linear): Note: this is **prototype** and not suitable for production use. """ - pass + def __init__(self, *args, **kwargs): + """ + Additional arguments on top of `torch.nn.Linear`'s arguments: + * `config`: Float8LinearConfig + """ + config = kwargs.pop("config") + emulate = config.emulate + super().__init__(*args, **kwargs) + + self.config = config + + self.linear_mm_config = LinearMMConfig( + # output + ScaledMMConfig( + emulate, + self.config.gemm_config_output.use_fast_accum, + False, + self.config.pad_inner_dim, + ), + # grad_input + ScaledMMConfig( + emulate, + self.config.gemm_config_grad_input.use_fast_accum, + False, + self.config.pad_inner_dim, + ), + # grad_weight + ScaledMMConfig( + emulate, + self.config.gemm_config_grad_weight.use_fast_accum, + False, + self.config.pad_inner_dim, + ), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # TODO(danielvegamyhre): replace conversions with triton kernels + # TODO(danielvegamyhre): support for FSDP once dependencies are implemented + input_fp8 = self.cast_input_to_float8(input) + weight_fp8_t = self.cast_weight_to_float8_t(self.weight) + + # compute fp8 matmul + output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t) + + # cast grad_output to float8_e5m2 during backward + return self.cast_output_to_float8_in_bw(output) + + def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor: + # Duplicate the autocast logic for F.linear, so that the output + # of our module has the right original precision + if torch.is_autocast_enabled(): + # For now, hardcode to GPU's autocast dtype + # if we need CPU support in the future, we can add it + autocast_dtype = torch.get_autocast_gpu_dtype() + input = input.to(autocast_dtype) + + # TODO(danielvegamyhre): implement this fn in scaling_utils with call to triton kernel + return hp_tensor_to_float8nocompile_dynamic( + input, + self.config.cast_config_input.target_dtype, + self.linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + ) + + def cast_weight_to_float8_t( + self, + weight: torch.Tensor, + ) -> torch.Tensor: + # TODO(danielvegamyhre): replace conversion with triton kernel + weight_fp8 = hp_tensor_to_float8nocompile_dynamic( + weight, + self.config.cast_config_weight.target_dtype, + self.linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + ) + return weight_fp8.t() + + def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: + # casts grad_output to float8_e5m2 for backward + # TODO(danielvegamyhre): replace conversion with triton kernel + return NoopFwToFloat8BwDynamic.apply( + output, + self.linear_mm_config, + self.config.cast_config_grad_output.target_dtype, + ) + + @classmethod + def from_float(cls, mod): + """ + Create an nn.Linear 
with fp8 compute from a regular nn.Linear + + Args: + mod (torch.nn.Linear): nn.Linear to convert + config (Optional[Float8LinearConfig]): configuration for conversion to float8 + """ + config = Float8LinearConfig() + with torch.device("meta"): + new_mod = cls( + mod.in_features, + mod.out_features, + bias=False, + config=config, + ) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + + # TODO(danielvegamyhre): support for FSDP once dependencies are implemented + return new_mod diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py index d706c02d5..257a7563b 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py @@ -12,7 +12,9 @@ from torchao.float8.config import Float8LinearConfig from torchao.float8.float8_linear_utils import swap_linear_layers -from torchao.prototype.float8nocompile.float8_linear import Float8LinearNoCompile +from torchao.prototype.float8nocompile.float8nocompile_linear import ( + Float8LinearNoCompile, +) log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) @@ -22,7 +24,6 @@ def convert_to_float8_nocompile_training( module: nn.Module, *, module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, - config: Float8LinearConfig = None, ) -> nn.Module: """ Swaps `torch.nn.Linear` in `module` with `Float8LinearNoCompile`. @@ -37,12 +38,7 @@ def convert_to_float8_nocompile_training( Returns: nn.Module: The modified module with swapped linear layers. """ - if config is None: - config = Float8LinearConfig() - from_float = lambda m: Float8LinearNoCompile.from_float( - m, - config=config, - ) + from_float = lambda m: Float8LinearNoCompile.from_float(m) return swap_linear_layers( module, from_float, diff --git a/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py new file mode 100644 index 000000000..4087a1ec1 --- /dev/null +++ b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for scaling high precision tensors to float8. +""" + +from typing import Optional + +import torch + +from torchao.float8.config import ScalingGranularity +from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 +from torchao.float8.float8_tensor import ( + _ToFloat8ConstrFunc, + Float8Tensor, + GemmInputRole, + LinearMMConfig, +) +from torchao.float8.float8_utils import tensor_to_scale + +# avoid division by zero when calculating scale +# TODO: align this value with NVIDIA's assumptions (current value is a guess) +EPS = 1e-12 + + +def hp_tensor_to_float8nocompile_dynamic( + hp_tensor: torch.Tensor, + float8_dtype: torch.dtype, + linear_mm_config: LinearMMConfig, + gemm_input_role: GemmInputRole = GemmInputRole.INPUT, +) -> Float8Tensor: + """ + Given a high precision tensor `hp_tensor`, + scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result. 
+ + Args: + hp_tensor: the tensor to convert + float8_dtype: the float8 dtype to use + linear_mm_config: Defines the configuration for the scaled_mm for + the 3 fwd/bwd gemms of linear + gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in + the 3 fwd/bwd gemms of linear + """ + # TODO(danielvegamyhre): replace this torch implementation with custom triton kernel + # torch.compile and eager show different numerics for 1.0 / float32, + # upcast to float64 to ensure same numeric between compile and eager + amax = torch.max(torch.abs(hp_tensor)).to(torch.float64) + scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) + scale = scale.to(torch.float32) # scale must be fp32 + return _ToFloat8ConstrFunc.apply( + hp_tensor, + scale, + float8_dtype, + linear_mm_config, + gemm_input_role, + None, + )
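
For reference, the tensorwise dynamic scaling performed by `hp_tensor_to_float8nocompile_dynamic` can be reproduced with plain `torch` ops. The sketch below walks through the same amax -> scale -> cast steps outside the `Float8Tensor` machinery; the helper name `tensorwise_scale_to_float8` and the saturating clamp before the cast are illustrative assumptions, not code from this patch.

```python
import torch

EPS = 1e-12  # mirrors the division-by-zero guard in float8nocompile_scaling_utils.py

def tensorwise_scale_to_float8(hp_tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # compute amax in float64 so eager and compiled runs agree numerically
    amax = torch.max(torch.abs(hp_tensor)).to(torch.float64)
    # the scale maps the observed amax onto the max representable float8 value
    scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    scale = scale.to(torch.float32)  # scale must be fp32
    # quantize: scale, saturate to the float8 range, then cast (illustrative only;
    # the patch delegates this step to _ToFloat8ConstrFunc / Float8Tensor)
    fp8_max = torch.finfo(float8_dtype).max
    data = (hp_tensor.to(torch.float32) * scale).clamp(min=-fp8_max, max=fp8_max)
    return data.to(float8_dtype), scale

x = torch.randn(32, 32, dtype=torch.bfloat16)
x_fp8, x_scale = tensorwise_scale_to_float8(x)
# dequantize to inspect the rounding error introduced by the float8 cast
x_dq = x_fp8.to(torch.float32) / x_scale
print((x_dq - x.to(torch.float32)).abs().max())
```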
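
`convert_to_float8_nocompile_training` no longer takes a `config`, but it still forwards the optional `module_filter_fn` to `swap_linear_layers`, so callers can control which linears get swapped. A minimal sketch, assuming a CUDA device; the size-based predicate is hypothetical and not part of this patch:

```python
import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

m = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 64)).bfloat16().cuda()

def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # hypothetical predicate: only swap linears large enough to benefit from float8
    return isinstance(mod, nn.Linear) and mod.in_features >= 64

m = convert_to_float8_nocompile_training(m, module_filter_fn=module_filter_fn)
print(m)  # only the second linear is replaced with Float8LinearNoCompile
```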