diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py
index 2ff4eaed0c..be4f50f521 100644
--- a/llmfoundry/models/layers/norm.py
+++ b/llmfoundry/models/layers/norm.py
@@ -111,9 +111,51 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                             self.eps).to(dtype=x.dtype)
 
 
+class TritonRMSNorm(torch.nn.Module):
+
+    def __init__(
+        self,
+        normalized_shape: Union[int, List[int], torch.Size],
+        eps: float = 1e-5,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.eps = eps
+
+        try:
+            from flash_attn.ops.triton.layer_norm import rms_norm_fn
+        except ImportError:
+            raise ImportError(
+                'triton_rmsnorm requires Flash Attention to be installed. ' +
+                'Please pip install flash-attn.')
+
+        if not isinstance(normalized_shape, int):
+            raise ValueError('TritonRMSNorm only supports 1D normalized_shape')
+
+        self.rms_norm_fn = rms_norm_fn
+
+        self.weight = torch.nn.Parameter(
+            torch.ones(normalized_shape, device=device, dtype=dtype))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Flash Attention expects a flat tensor
+        return self.rms_norm_fn(
+            x,
+            self.weight,
+            None,  # no bias
+            residual=None,
+            eps=self.eps,
+            dropout_p=0.0,  # no dropout by default
+            prenorm=False,
+            residual_in_fp32=False,
+        )
+
+
 NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
     'layernorm': torch.nn.LayerNorm,
     'low_precision_layernorm': LPLayerNorm,
     'rmsnorm': RMSNorm,
     'low_precision_rmsnorm': LPRMSNorm,
+    'triton_rmsnorm': TritonRMSNorm,
 }
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 7244ddc8c2..464423f512 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -616,6 +616,11 @@ def test_lora_id():
 def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
                       expansion_ratio: Union[int, float], ffn_hidden_size: int,
                       ffn_act_fn: dict):
+    if norm_type == 'triton_rmsnorm' and not is_flash_v2_installed():
+        pytest.skip(
+            'norm_type=triton_rmsnorm requires Flash Attention to be installed'
+        )
+
     # Test that the config constructs the model as expected.
     hf_config = MPTConfig(
         init_device='cpu',
diff --git a/tests/models/test_rmsnorm_triton_vs_eager.py b/tests/models/test_rmsnorm_triton_vs_eager.py
new file mode 100644
index 0000000000..1902f46d78
--- /dev/null
+++ b/tests/models/test_rmsnorm_triton_vs_eager.py
@@ -0,0 +1,75 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List, Union
+
+import pytest
+import torch
+from composer.core.precision import get_precision_context
+
+from llmfoundry.models.layers.attention import is_flash_v2_installed
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('normalized_shape', [32, 128, 4096])
+def test_rmsnorm_triton_vs_eager(normalized_shape: Union[int, List[int]],
+                                 device: str = 'cuda'):
+    # Compare Triton and PyTorch Eager implementations of RMSNorm
+    if not is_flash_v2_installed():
+        pytest.skip(
+            'Triton implementation of rmsnorm requires Flash Attention 2.')
+
+    from llmfoundry.models.layers import norm
+
+    batch_size = 2
+
+    cfg = {
+        'normalized_shape': normalized_shape,
+        'device': device,
+    }
+
+    eager_rmsnorm = norm.NORM_CLASS_REGISTRY['rmsnorm'](**cfg)
+    triton_rmsnorm = norm.NORM_CLASS_REGISTRY['triton_rmsnorm'](**cfg)
+
+    triton_rmsnorm.load_state_dict(eager_rmsnorm.state_dict())
+
+    if isinstance(normalized_shape, int):
+        input_shape = [batch_size, normalized_shape]
+    else:
+        input_shape = tuple([batch_size, *normalized_shape])
+
+    x0 = torch.randn(size=input_shape, device=device)
+    x1 = x0.clone().detach()
+    x0.requires_grad = True
+    x1.requires_grad = True
+
+    with get_precision_context('amp_bf16'):
+        y0 = eager_rmsnorm(x0)
+        y1 = triton_rmsnorm(x1)
+
+        loss0 = y0.sum()
+        loss1 = y1.sum()
+
+        loss0.backward()
+        loss1.backward()
+
+        rtol = 1e-6
+        atol = 1e-6
+
+        torch.testing.assert_close(y0, y1, rtol=rtol, atol=atol)
+
+        p0 = eager_rmsnorm.weight
+        p1 = triton_rmsnorm.weight
+
+        # weight check
+        torch.testing.assert_close(p0, p1, rtol=rtol, atol=atol)
+        # weight gradient check
+        assert p0.grad is not None
+        assert p1.grad is not None
+        assert torch.norm(p0.grad - p1.grad) <= atol + rtol * torch.norm(p0.grad)
+
+        # input gradient check
+        assert x0.grad is not None
+        assert x1.grad is not None
+        # Relaxed to an l2-norm based check.
+        assert torch.norm(x0.grad - x1.grad) <= atol + rtol * torch.norm(x0.grad)
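
A minimal usage sketch of the new 'triton_rmsnorm' registry entry (not part of the diff): it assumes flash-attn 2.x with its Triton kernels and a CUDA device are available, and the hidden size and tensor shapes below are purely illustrative.

    # Usage sketch only; assumes flash-attn 2.x and a CUDA device are present.
    import torch

    from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

    norm_cls = NORM_CLASS_REGISTRY['triton_rmsnorm']  # resolves to TritonRMSNorm
    rmsnorm = norm_cls(normalized_shape=4096, device=torch.device('cuda'))

    x = torch.randn(2, 16, 4096, device='cuda')  # (batch, seq_len, d_model), illustrative
    y = rmsnorm(x)  # RMS-normalized over the last dimension; same shape as x

In an MPT config this implementation is selected through the norm_type field (norm_type: triton_rmsnorm), the same way the existing 'rmsnorm' and 'low_precision_rmsnorm' registry keys are chosen.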