From 367aea3b1e3883318b3039da026122f3b1b23aab Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 8 Jan 2025 06:40:24 -0800
Subject: [PATCH] add handling for batch dim in float8nocompile

ghstack-source-id: 47996349f565c237f81124afe8b500d221d1448b
ghstack-comment-id: 2574237839
Pull Request resolved: https://github.com/pytorch/ao/pull/1512
---
 .../float8nocompile/float8nocompile_linear.py | 25 ++++++-
 .../float8nocompile_linear_test.py            | 73 +++++++++++++++++++
 .../float8nocompile_linear_utils.py           |  9 ++-
 .../float8nocompile/test/train_test.py        |  7 +-
 4 files changed, 109 insertions(+), 5 deletions(-)
 create mode 100644 torchao/prototype/float8nocompile/float8nocompile_linear_test.py

diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear.py b/torchao/prototype/float8nocompile/float8nocompile_linear.py
index 37de7b852..59a44ef77 100644
--- a/torchao/prototype/float8nocompile/float8nocompile_linear.py
+++ b/torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -80,7 +80,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return output
 
     @classmethod
-    def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
+    def from_float(cls, mod, config: Float8LinearConfig, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
@@ -88,7 +88,6 @@ def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_M
             mod (torch.nn.Linear): nn.Linear to convert
             config (Optional[Float8LinearConfig]): configuration for conversion to float8
         """
-        config = Float8LinearConfig()
         with torch.device("meta"):
             new_mod = cls(
                 mod.in_features,
@@ -107,6 +106,10 @@ def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_M
 class matmul_with_args_in_hp(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
+        # reshape to be 2D for triton kernels
+        orig_input_shape = input_hp.shape
+        input_hp = input_hp.reshape(-1, input_hp.shape[-1])
+
         # output = input @ weight_t
         input_fp8_row_major, input_fp8_col_major = ToFP8RowAndColumnMajor.apply(
             input_hp,
@@ -130,12 +133,24 @@ def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
         ctx.linear_mm_config = linear_mm_config
         ctx.kernel_algo = kernel_algo
 
+        # reshape back to expected dims
+        output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
+        # grad_output may not be contiguous in cases like:
+        # output.sum().backward() where grad is all 1s, so the (M,N) view of the scalar "1"
+        # results in a non-contiguous tensor with stride (0,0).
+        if not grad_output.is_contiguous():
+            grad_output = grad_output.contiguous()
+
         input_fp8_col_major, weight_hp = ctx.saved_tensors
 
+        # reshape to be 2D for triton kernels
+        orig_grad_output_shape = grad_output.shape
+        grad_output = grad_output.reshape(-1, grad_output.shape[-1])
+
         # cast grad output to float8_e5m2 for backward
         grad_output_fp8_row_major, grad_output_t_row_major = (
             ToFP8RowMajorTAndNonT.apply(
@@ -162,4 +177,10 @@
         # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
         grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)
 
+        # reshape grad input to match original shape
+        grad_input = grad_input.reshape(
+            *orig_grad_output_shape[:-1], grad_input.shape[-1]
+        )
+
+        # no gradients for config, linear_mm_config, or kernel_algo
         return grad_input, grad_weight, None, None, None
diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear_test.py b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py
new file mode 100644
index 000000000..b269131bb
--- /dev/null
+++ b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py
@@ -0,0 +1,73 @@
+import pytest
+
+import torch
+from torch.autograd.function import FunctionCtx
+from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp
+from torchao.float8.config import Float8LinearConfig
+from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig
+from torch.autograd import gradcheck
+
+from torchao.prototype.float8nocompile.float8nocompile_linear import matmul_with_args_in_hp
+from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import KernelAlgorithm
+
+# unit test comparing the two implementations
+@pytest.mark.parametrize(
+    "input_shape",
+    [(32, 16), (1, 32, 16), (2, 32, 16)],
+)
+def test_matmul_with_args_in_hp(input_shape: tuple[int, ...]):
+    assert torch.cuda.is_available()
+    device = "cuda"
+
+    # high precision inputs
+    input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device, requires_grad=True)
+    x_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
+    y_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
+
+    # high precision weights
+    # nn.Linear stores weights in transposed form
+    weight_bf16 = torch.randn((32, input_bf16.shape[-1]), dtype=torch.bfloat16, device=device, requires_grad=True)
+    x_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
+    y_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
+
+    # default configs
+    config = Float8LinearConfig()
+    emulate = False
+    linear_mm_config = LinearMMConfig(
+        # output
+        ScaledMMConfig(
+            emulate,
+            config.gemm_config_output.use_fast_accum,
+            False,
+            config.pad_inner_dim,
+        ),
+        # grad_input
+        ScaledMMConfig(
+            emulate,
+            config.gemm_config_grad_input.use_fast_accum,
+            False,
+            config.pad_inner_dim,
+        ),
+        # grad_weight
+        ScaledMMConfig(
+            emulate,
+            config.gemm_config_grad_weight.use_fast_accum,
+            False,
+            config.pad_inner_dim,
+        ),
+    )
+
+    # prod forward. expects transposed weight.
+    out_prod = manual_float8_matmul_with_args_in_hp.apply(x_input_bf16, x_weight_bf16.t(), linear_mm_config, config)
+
+    # prototype forward. expects non-transposed weight
+    out_prototype = matmul_with_args_in_hp.apply(y_input_bf16, y_weight_bf16, config, linear_mm_config, KernelAlgorithm.ATOMIC_MAX)
+
+    # compare
+    assert torch.allclose(out_prod, out_prototype, atol=1e-3, rtol=1e-3)
+
+    out_prod.sum().backward()
+    out_prototype.sum().backward()
+
+    assert torch.allclose(x_input_bf16.grad, y_input_bf16.grad, atol=1e-3, rtol=1e-3)
+    assert torch.allclose(x_weight_bf16.grad, y_weight_bf16.grad, atol=1e-3, rtol=1e-3)
diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py
index 2ab707a4e..6739242f0 100644
--- a/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py
+++ b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py
@@ -8,6 +8,7 @@
 
 import torch.nn as nn
 
+from torchao.float8.config import Float8LinearConfig
 from torchao.float8.float8_linear_utils import swap_linear_layers
 from torchao.prototype.float8nocompile.float8nocompile_linear import (
     Float8LinearNoCompile,
@@ -23,6 +24,7 @@
 def convert_to_float8_nocompile_training(
     module: nn.Module,
     *,
+    config: Optional[Float8LinearConfig] = None,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
 ) -> nn.Module:
@@ -39,7 +41,12 @@ def convert_to_float8_nocompile_training(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    from_float = lambda m: Float8LinearNoCompile.from_float(m, kernel_algo=kernel_algo)
+    if config is None:
+        config = Float8LinearConfig()
+
+    from_float = lambda m: Float8LinearNoCompile.from_float(
+        m, config=config, kernel_algo=kernel_algo
+    )
     return swap_linear_layers(
         module,
         from_float,
diff --git a/torchao/prototype/float8nocompile/test/train_test.py b/torchao/prototype/float8nocompile/test/train_test.py
index 7e28ff52c..23165ce05 100644
--- a/torchao/prototype/float8nocompile/test/train_test.py
+++ b/torchao/prototype/float8nocompile/test/train_test.py
@@ -36,7 +36,10 @@ def model2():
     return TestModel()
 
 
-def test_model_weights_and_gradients(model1, model2):
+@pytest.mark.parametrize(
+    "input_shape", [(16, 32), (1, 16, 32), (2, 16, 32)]
+)
+def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, ...]):
     assert torch.cuda.is_available()
     device = torch.device("cuda")
 
@@ -48,7 +51,7 @@ def test_model_weights_and_gradients(model1, model2):
     convert_to_float8_nocompile_training(model1)
 
     input_tensor = torch.randn(
-        16, 32, requires_grad=True, dtype=torch.bfloat16, device=device
+        *input_shape, requires_grad=True, dtype=torch.bfloat16, device=device
    )
     input_copy1 = input_tensor.clone().detach().requires_grad_(True)
     input_copy2 = input_tensor.clone().detach().requires_grad_(True)
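
Usage sketch (illustrative, not part of the patch): converting a small model with an explicit Float8LinearConfig and pushing a batched 3D input through it, which exercises the reshape-to-2D handling added above. Assumes a CUDA device with the triton kernels available; the toy module and shapes are made up for illustration.

    import torch
    import torch.nn as nn

    from torchao.float8.config import Float8LinearConfig
    from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
        convert_to_float8_nocompile_training,
    )

    # toy model, moved to bf16 on GPU before conversion (mirrors the tests above)
    model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32)).to(torch.bfloat16).cuda()

    # config is optional; passing one explicitly exercises the new `config` parameter
    convert_to_float8_nocompile_training(model, config=Float8LinearConfig())

    # 3D (batch, seq, dim) input: the autograd function flattens leading dims to 2D
    # for the triton kernels and restores them on the output and grad_input
    x = torch.randn(2, 16, 32, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    out = model(x)
    # sum().backward() also hits the non-contiguous grad_output path handled in backward
    out.sum().backward()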