Skip to content

Commit

Permalink
Triton RMSNorm (#1050)
Browse files Browse the repository at this point in the history
* Triton RMS Norm

* fix

* refactor

* Fix test

* no bias

* Update tests/models/test_model.py

Co-authored-by: Daniel King <[email protected]>

* fixes

* Unittest

* fix

* fix2

* fix3

* fix3

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
josejg and dakinggg authored Apr 2, 2024
1 parent d8ea2c5 commit b765b47
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 0 deletions.
42 changes: 42 additions & 0 deletions llmfoundry/models/layers/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,51 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.eps).to(dtype=x.dtype)


class TritonRMSNorm(torch.nn.Module):

def __init__(
self,
normalized_shape: Union[int, List[int], torch.Size],
eps: float = 1e-5,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.eps = eps

try:
from flash_attn.ops.triton.layer_norm import rms_norm_fn
except ImportError:
raise ImportError(
'triton_rms_norm requires Flash Attention to be installed. ' +
'Please pip install flash-attn.')

if not isinstance(normalized_shape, int):
raise ValueError('TritonRMSNorm only supports 1D tensors')

self.rms_norm_fn = rms_norm_fn

self.weight = torch.nn.Parameter(
torch.ones(normalized_shape, device=device, dtype=dtype))

def forward(self, x: torch.Tensor):
# Flash Attention expect a flat tensor
return self.rms_norm_fn(
x,
self.weight,
None, # no bias
residual=None,
eps=self.eps,
dropout_p=0.0, # no dropout by default
prenorm=False,
residual_in_fp32=False,
)


NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
'layernorm': torch.nn.LayerNorm,
'low_precision_layernorm': LPLayerNorm,
'rmsnorm': RMSNorm,
'low_precision_rmsnorm': LPRMSNorm,
'triton_rmsnorm': TritonRMSNorm,
}
5 changes: 5 additions & 0 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,11 @@ def test_lora_id():
def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
expansion_ratio: Union[int, float], ffn_hidden_size: int,
ffn_act_fn: dict):
if norm_type == 'triton_rmsnorm' and not is_flash_v2_installed():
pytest.skip(
f'norm_type=triton_rmsnorm requires flash Attention to be installed'
)

# Test that the config constructs the model as expected.
hf_config = MPTConfig(
init_device='cpu',
Expand Down
75 changes: 75 additions & 0 deletions tests/models/test_rmsnorm_triton_vs_eager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List, Union

import pytest
import torch
from composer.core.precision import get_precision_context

from llmfoundry.models.layers.attention import is_flash_v2_installed


@pytest.mark.gpu
@pytest.mark.parametrize('normalized_shape', [32, 128, 4096])
def test_rmsnorm_triton_vs_eager(normalized_shape: Union[int, List[int]],
device: str = 'cuda'):
# Compare Triton and PyTorch Eager implementations of RMSNorm
if not is_flash_v2_installed():
pytest.skip(
'triton implementation of rmsnorm requires flash attention 2.')

from llmfoundry.models.layers import norm

batch_size = 2

cfg = {
'normalized_shape': normalized_shape,
'device': device,
}

eager_rmsnorm = norm.NORM_CLASS_REGISTRY['rmsnorm'](**cfg)
triton_rmsnorm = norm.NORM_CLASS_REGISTRY['triton_rmsnorm'](**cfg)

triton_rmsnorm.load_state_dict(eager_rmsnorm.state_dict())

if isinstance(normalized_shape, int):
input_shape = [batch_size, normalized_shape]
else:
input_shape = tuple([batch_size, *normalized_shape])

x0 = torch.randn(size=input_shape, device=device)
x1 = x0.clone().detach()
x0.requires_grad = True
x1.requires_grad = True

with get_precision_context('amp_bf16'):
y0 = eager_rmsnorm(x0)
y1 = triton_rmsnorm(x1)

loss0 = y0.sum()
loss1 = y1.sum()

loss0.backward()
loss1.backward()

rtol = 1e-6
atol = 1e-6

torch.testing.assert_close(y0, y1, rtol=rtol, atol=atol)

p0 = eager_rmsnorm.weight
p1 = triton_rmsnorm.weight

# weight check
torch.testing.assert_close(p0, p1, rtol=rtol, atol=atol)
# weight gradient check
assert p0.grad is not None
assert p1.grad is not None
assert torch.norm(p0.grad - p1.grad) <= atol + rtol * torch.norm(p0.grad)

# input gradient check
assert x0.grad is not None
assert x1.grad is not None
# Relaxed to a l2-norm based check.
assert torch.norm(x0.grad - x1.grad) <= atol + rtol * torch.norm(x0.grad)

0 comments on commit b765b47

Please sign in to comment.