T5 Encoder (pytorch#2069)
calvinpelletier authored Jan 7, 2025
1 parent 213f386 commit 27fd3a1
Showing 13 changed files with 711 additions and 2 deletions.
Binary file added tests/assets/sentencepiece.model
5 changes: 5 additions & 0 deletions tests/torchtune/models/t5/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
81 changes: 81 additions & 0 deletions tests/torchtune/models/t5/test_t5_encoder.py
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.t5._component_builders import t5_encoder
from torchtune.training.seed import set_seed

VOCAB_SIZE = 512
MAX_SEQ_LEN = 8
BSZ = 2
EMBED_DIM = 2


@pytest.fixture(autouse=True)
def random():
    set_seed(0)


class TestT5Encoder:
    @pytest.fixture
    def model(self):
        model = t5_encoder(
            embed_dim=EMBED_DIM,
            mlp_dim=4,
            num_heads=2,
            head_dim=EMBED_DIM // 2,
            num_layers=2,
            rel_pos_num_buckets=4,
            rel_pos_max_dist=4,
            vocab_size=VOCAB_SIZE,
            norm_eps=1e-6,
            max_seq_len=MAX_SEQ_LEN,
        )

        for param in model.parameters():
            param.data.uniform_(0, 1)

        return model

    @pytest.fixture
    def inputs(self):
        return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN))

    def test_forward(self, model, inputs):
        actual = model(inputs)
        expected = torch.tensor(
            [
                [
                    [0.3670, 0.2938],
                    [0.3692, 0.2921],
                    [0.3611, 0.2984],
                    [0.4207, 0.2437],
                    [0.3447, 0.3106],
                    [0.3383, 0.3150],
                    [0.3727, 0.2892],
                    [0.3996, 0.2653],
                ],
                [
                    [0.3855, 0.2783],
                    [0.2627, 0.3581],
                    [0.3601, 0.2992],
                    [0.3473, 0.3087],
                    [0.3549, 0.3032],
                    [0.2871, 0.3459],
                    [0.2753, 0.3520],
                    [0.2285, 0.3728],
                ],
            ]
        )
        assert actual.shape == (BSZ, MAX_SEQ_LEN, EMBED_DIM)
        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_backward(self, model, inputs):
        y = model(inputs)
        loss = y.mean()
        loss.backward()
39 changes: 39 additions & 0 deletions tests/torchtune/models/t5/test_t5_tokenizer.py
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from tests.common import ASSETS
from torchtune.models.t5._model_builders import t5_tokenizer


class TestT5Tokenizer:
    @pytest.fixture
    def tokenizer(self):
        return t5_tokenizer(str(ASSETS / "sentencepiece.model"))

    def test_encoding(self, tokenizer):
        texts = [
            "a cow jumping over the moon",
            "a helpful AI assistant",
        ]
        correct_tokens = [
            [3, 9, 9321, 15539, 147, 8, 8114, 1],
            [3, 9, 2690, 7833, 6165, 1],
        ]
        for text, correct in zip(texts, correct_tokens):
            tokens = tokenizer.encode(text)
            print(tokens)
            assert tokens == correct

    def test_decoding(self, tokenizer):
        text = "this is torchtune"
        assert text == tokenizer.decode(tokenizer.encode(text))

    def test_call(self, tokenizer):
        sample = {"text": "hello world"}
        sample = tokenizer(sample)
        assert "text" not in sample
        assert "tokens" in sample
14 changes: 14 additions & 0 deletions torchtune/models/t5/__init__.py
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._component_builders import t5_encoder
from ._model_builders import t5_tokenizer, t5_v1_1_xxl_encoder

__all__ = [
    "t5_encoder",
    "t5_tokenizer",
    "t5_v1_1_xxl_encoder",
]
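
A minimal usage sketch of the public entry points exported above. The SentencePiece model path is a placeholder, and the example token IDs are the ones the tokenizer unit test earlier in this diff expects for the bundled test asset; a real T5 vocab will produce different IDs.

```python
from torchtune.models.t5 import t5_tokenizer

# Placeholder path: point this at a real SentencePiece model file.
tokenizer = t5_tokenizer("/path/to/sentencepiece.model")

# Encode/decode round trip.
tokens = tokenizer.encode("a cow jumping over the moon")
print(tokens)  # [3, 9, 9321, 15539, 147, 8, 8114, 1] with tests/assets/sentencepiece.model
print(tokenizer.decode(tokens))

# Dict-style call: replaces the "text" field with "tokens" (per the unit test).
sample = tokenizer({"text": "hello world"})
print(sample["tokens"])
```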
89 changes: 89 additions & 0 deletions torchtune/models/t5/_component_builders.py
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn

from torchtune.models.t5._encoder import (
    T5Encoder,
    T5EncoderLayer,
    T5EncoderSelfAttention,
)
from torchtune.modules.feed_forward import FeedForward
from torchtune.modules.rms_norm import RMSNorm


def t5_encoder(
    embed_dim: int,
    mlp_dim: int,
    num_heads: int,
    head_dim: int,
    num_layers: int,
    rel_pos_num_buckets: int,
    rel_pos_max_dist: int,
    vocab_size: int,
    norm_eps: float,
    max_seq_len: int,
):
    """
    Builder for the T5 encoder.

    T5 paper: https://arxiv.org/abs/1910.10683

    Args:
        embed_dim (int): The model dimension.
        mlp_dim (int): The inner dimension of the feed forward layers.
        num_heads (int): The number of attention heads.
        head_dim (int): The dimension of the attention heads (should equal `embed_dim // num_heads`).
        num_layers (int): Number of encoder layers.
        rel_pos_num_buckets (int): Number of discrete buckets to divide the relative positions into.
            See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
        rel_pos_max_dist (int): Maximum distance for relative positions.
            Distances beyond this are grouped into the last bucket.
            See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
        vocab_size (int): Vocab size of the model's tokenizer.
        norm_eps (float): Small value added to the denominator for numerical stability.
        max_seq_len (int): The maximum sequence length (context length) of the model.

    Returns:
        T5Encoder
    """
    token_embedding = nn.Embedding(vocab_size, embed_dim)

    attn = T5EncoderSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        k_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        v_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
    )

    mlp = FeedForward(
        gate_proj=nn.Linear(embed_dim, mlp_dim, bias=False),
        down_proj=nn.Linear(mlp_dim, embed_dim, bias=False),
        up_proj=nn.Linear(embed_dim, mlp_dim, bias=False),
        activation=nn.GELU(),
    )

    layer = T5EncoderLayer(
        attn=attn,
        mlp=mlp,
        sa_norm=RMSNorm(embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(embed_dim, eps=norm_eps),
    )

    final_norm = RMSNorm(embed_dim, eps=norm_eps)

    return T5Encoder(
        token_embedding=token_embedding,
        layer=layer,
        final_norm=final_norm,
        num_layers=num_layers,
        num_heads=num_heads,
        rel_pos_num_buckets=rel_pos_num_buckets,
        rel_pos_max_dist=rel_pos_max_dist,
        max_seq_len=max_seq_len,
    )
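
A minimal sketch of how this builder can be exercised end to end, reusing the toy hyperparameters from the unit test above; the random token batch and printed shape are illustrative only.

```python
import torch

from torchtune.models.t5._component_builders import t5_encoder

# Tiny encoder with the same toy hyperparameters as the unit test.
model = t5_encoder(
    embed_dim=2,
    mlp_dim=4,
    num_heads=2,
    head_dim=1,  # embed_dim // num_heads
    num_layers=2,
    rel_pos_num_buckets=4,
    rel_pos_max_dist=4,
    vocab_size=512,
    norm_eps=1e-6,
    max_seq_len=8,
)

# Forward pass over a batch of token IDs: (batch, seq_len) -> (batch, seq_len, embed_dim).
tokens = torch.randint(0, 512, (2, 8))
out = model(tokens)
print(out.shape)  # torch.Size([2, 8, 2])
```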
49 changes: 49 additions & 0 deletions torchtune/models/t5/_convert_weights.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.models.convert_weights import get_mapped_key

# state dict key mappings from HF's format to torchtune's format
_FROM_HF = {
    # emb
    "encoder.embed_tokens.weight": "token_embedding.weight",
    "encoder.block.{}.layer._0.SelfAttention.relative_attention_bias.weight": "relative_position_bias.embedding.weight",
    # attn
    "encoder.block.{}.layer._0.SelfAttention.q.weight": "layers.{}.attn.q_proj.weight",
    "encoder.block.{}.layer._0.SelfAttention.k.weight": "layers.{}.attn.k_proj.weight",
    "encoder.block.{}.layer._0.SelfAttention.v.weight": "layers.{}.attn.v_proj.weight",
    "encoder.block.{}.layer._0.SelfAttention.o.weight": "layers.{}.attn.output_proj.weight",
    # ff
    "encoder.block.{}.layer._1.DenseReluDense.wi_0.weight": "layers.{}.mlp.w1.weight",
    "encoder.block.{}.layer._1.DenseReluDense.wo.weight": "layers.{}.mlp.w2.weight",
    "encoder.block.{}.layer._1.DenseReluDense.wi_1.weight": "layers.{}.mlp.w3.weight",
    # norm
    "encoder.block.{}.layer._0.layer_norm.weight": "layers.{}.sa_norm.scale",
    "encoder.block.{}.layer._1.layer_norm.weight": "layers.{}.mlp_norm.scale",
    "encoder.final_layer_norm.weight": "final_norm.scale",
}

_IGNORE = {
    "shared.weight",
    "lm_head.weight",
}


def t5_encoder_hf_to_tune(state_dict):
    converted_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("decoder.") or key in _IGNORE:
            continue

        # NOTE: HF's T5 has ".<integer>." parts that we do NOT want to be dynamically mapped
        # to corresponding ".<integer>." parts in our converted state dict.
        # This breaks the `get_mapped_key` implementation, so as a temporary hack,
        # we add leading underscores to these parts here and in the `_FROM_HF` map above.
        key = key.replace("layer.0.", "layer._0.").replace("layer.1.", "layer._1.")

        new_key = get_mapped_key(key, _FROM_HF)
        converted_state_dict[new_key] = value
    return converted_state_dict
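
A hedged usage sketch for this converter: it assumes the HF weights are obtained via transformers' T5EncoderModel and that a matching torchtune encoder already exists; neither step is part of this diff, and the checkpoint id is illustrative.

```python
# Sketch: convert Hugging Face T5 encoder weights into torchtune's key format.
# Assumes `transformers` is installed; the checkpoint id and the existence of a
# matching torchtune model are assumptions for illustration.
from transformers import T5EncoderModel

from torchtune.models.t5._convert_weights import t5_encoder_hf_to_tune

hf_state_dict = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl").state_dict()
converted = t5_encoder_hf_to_tune(hf_state_dict)

# Keys now follow torchtune's naming, e.g. "layers.0.attn.q_proj.weight",
# and can be loaded into a torchtune T5Encoder built with matching hyperparameters:
# model.load_state_dict(converted)
```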
(Diffs for the remaining changed files not shown.)
