Add test for Llama3.1 ScaledRoPE (#1216)
joecummings committed Jul 25, 2024
1 parent 403c7f3 commit 83557aa
Showing 3 changed files with 146 additions and 12 deletions.
140 changes: 140 additions & 0 deletions tests/torchtune/models/llama3_1/test_position_embeddings.py
@@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from tests.test_utils import assert_expected
from torch import tensor

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchtune.utils.seed import set_seed


@pytest.fixture(autouse=True)
def random():
set_seed(0)


class TestLlama3ScaledRoPE:
"""
Class for testing our scaled RoPE implementation for Llama3.1.
The expected tensors are computed from the reference implementation here:
https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L272
The expected values were generated with the following script:
https://gist.github.com/joecummings/4f1331a9c1e5aa15bad1641acb74fe0e
"""

EXPECTED_FREQS_CIS_MEAN = tensor(0.1738)
EXPECTED_FREQS_CIS_SUM = tensor(91141.7656)
EXPECTED_FREQS_CIS_MAX = tensor(1.0)

EXPECTED_X_OUT_MEAN = tensor(-2.4781e-06)
EXPECTED_X_OUT_SUM = tensor(-83.1523)
EXPECTED_X_OUT_MAX = tensor(5.4625)

@pytest.fixture
def input_params(self):
bsz = 4
num_heads = 32
embed_dim = 4096
head_dim = embed_dim // num_heads
seq_len = 2048
max_seq_len = 4096
return bsz, num_heads, head_dim, seq_len, max_seq_len

@pytest.fixture
def input(self, input_params) -> tensor:
bsz, num_heads, head_dim, seq_len, _ = input_params
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope(self, input_params) -> Llama3ScaledRoPE:
_, _, head_dim, _, max_seq_len = input_params
return Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len)

def test_cache_equality(self, input, rope) -> None:
# Have to explicitly call _rope_init() to initialize theta matrix
rope._rope_init()
cache = rope.cache

assert_expected(cache.mean(), self.EXPECTED_FREQS_CIS_MEAN, atol=1e-4)
assert_expected(cache.sum(), self.EXPECTED_FREQS_CIS_SUM, atol=1e-4)
assert_expected(cache.max(), self.EXPECTED_FREQS_CIS_MAX)

def test_forward(self, input, rope) -> None:
x_out = rope(input)

# check the numerics of the computed tensor
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

def test_forward_with_curr_pos(self, input, rope) -> None:
(
_,
seq_len,
_,
_,
) = input.shape
x_out = rope(input, input_pos=torch.arange(seq_len))

# these values should be exactly the same as test_forward
# since in this case input_pos covers the entire input
# sequence. This tests that input_pos works as expected i.e.
# extracts the embeddings for the relevant positions
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

def test_forward_with_2d_pos_ids(self, input, rope) -> None:
"""
Use input_pos to indicate positions of each token relative to its sequence
when sample is packed.
"""
(
bsz,
seq_len,
_,
_,
) = input.shape
x_out = rope(
input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
)

# these values should be exactly the same as test_forward
# AND test_forward_with_curr_pos. In this case input_pos
# covers the entire batch dim and is defined for each sample separately.
# This tests that input_pos works as expected i.e.
# extracts the embeddings for the relevant positions for each sample
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

def test_rope_init_meta_device(self, input_params):
_, _, head_dim, _, max_seq_len = input_params
rope_on_device = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len)
with torch.device("meta"):
meta_rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len)

meta_rope._rope_init()
for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()):
torch.testing.assert_close(p1, p2)

# Assert meta_rope cache is no longer on meta device
assert meta_rope.cache.device != torch.device("meta")
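
For readers skimming the new test file above, here is a minimal usage sketch (not part of the commit) that mirrors its fixtures; it assumes only the Llama3ScaledRoPE constructor and forward signature exercised by the tests:

import torch

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

bsz, num_heads, embed_dim, seq_len, max_seq_len = 4, 32, 4096, 2048, 4096
head_dim = embed_dim // num_heads  # 128, as in the input_params fixture

rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len)
x = torch.randn(bsz, seq_len, num_heads, head_dim)

# Default call: positions are taken as 0..seq_len-1, as in test_forward.
out = rope(x)
assert out.shape == x.shape

# Packed-sample call: a [bsz, seq_len] tensor of per-sample positions,
# as in test_forward_with_2d_pos_ids.
pos = torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
out_packed = rope(x, input_pos=pos)
assert out_packed.shape == x.shape
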
6 changes: 3 additions & 3 deletions torchtune/models/llama3_1/_component_builders.py
@@ -10,7 +10,7 @@
from torch import nn

from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp
from torchtune.models.llama3_1._position_embeddings import Llama31RotaryPositionalEmbeddings
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchtune.modules import (
CausalSelfAttention,
@@ -81,7 +81,7 @@ def llama3_1(
"""
head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
rope = Llama31RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
self_attn = CausalSelfAttention(
embed_dim=embed_dim,
num_heads=num_heads,
@@ -358,7 +358,7 @@ def lora_llama3_1_self_attention(
if "output_proj" in lora_modules
else nn.Linear(embed_dim, embed_dim, bias=False)
)
rope = Llama31RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
self_attn = CausalSelfAttention(
embed_dim=embed_dim,
num_heads=num_heads,
12 changes: 3 additions & 9 deletions torchtune/models/llama3_1/_position_embeddings.py
@@ -12,14 +12,11 @@
from torch import nn, Tensor


class Llama31RotaryPositionalEmbeddings(nn.Module):
class Llama3ScaledRoPE(nn.Module):
"""
This class implements Rotary Positional Embeddings (RoPE)
proposed in https://arxiv.org/abs/2104.09864.
Reference implementation (used for correctness verification)
can be found here:
https://github.com/meta-llama/llama/blob/main/llama/model.py#L80
proposed in https://arxiv.org/abs/2104.09864 with additional
scaling from https://github.com/meta-llama/llama-models/blob/dc42f22a3b05502e7296402b019a51f57fa045c9/models/llama3_1.
In this implementation we cache the embeddings for each position up to
``max_seq_len`` by computing this during init.
@@ -120,9 +117,6 @@ def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor:
- s: sequence length
- n_h: num heads
- h_d: head dim
TODO: The implementation below can be made more efficient
for inference.
"""
# TODO: Remove this hack for handling scaling for Meta device
if not self.is_cache_built:
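
The new docstring above points to the frequency scaling in the llama-models reference rather than spelling it out. For context, a paraphrased sketch of that scaling follows; it is written from memory of the referenced repository, not taken from this diff, the function name is illustrative, and the constants (scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192) are assumptions to verify against the linked source:

import math

import torch


def apply_llama3_1_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Assumed Llama 3.1 defaults; confirm against the llama-models repo.
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original pre-training context length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency (short-wavelength) components are kept as-is.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency components are slowed down by scale_factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Smoothly interpolate between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.stack(new_freqs)

In torchtune this scaling is presumably folded into Llama3ScaledRoPE's theta computation during init, which is why the tests above call _rope_init() explicitly before inspecting the cache.
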
