ModernBert: reuse GemmaRotaryEmbedding via modular + Integration tests (#35459)

* Introduce 5 integration tests for the 4 model classes + torch export

* ModernBert: reuse GemmaRotaryEmbedding via modular

* Revert #35589, keep rope_kwargs; rely on them in modular_modernbert

* Revert "Revert #35589, keep rope_kwargs; rely on them in modular_modernbert"

This reverts commit 11b44b9.

* Don't set rope_kwargs; override 'self.rope_init_fn' call instead
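The last bullet is the heart of the change: instead of stashing rope_kwargs on the config, the rotary embedding re-invokes the rope init function with explicit dim/base keyword arguments. A minimal, standalone sketch of that call (dim=64 and base=10000.0 are illustrative values, not taken from the commit):

from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# With config=None, the default init function reads `dim` and `base` from the
# keyword arguments instead of the model config -- the same call the new
# ModernBertRotaryEmbedding makes in the diff below.
rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
inv_freq, attention_scaling = rope_init_fn(None, None, dim=64, base=10000.0)

print(inv_freq.shape)     # torch.Size([32]) -- one frequency per pair of dims
print(attention_scaling)  # 1.0 for the default rope type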
tomaarsen authored Jan 10, 2025
1 parent 8de7b1b commit 6b73ee8
Showing 3 changed files with 178 additions and 48 deletions.
56 changes: 42 additions & 14 deletions src/transformers/models/modernbert/modeling_modernbert.py
@@ -31,6 +31,7 @@
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
@@ -241,30 +242,59 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


class ModernBertRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)

# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@@ -468,9 +498,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
)
else:
self.rotary_emb = ModernBertRotaryEmbedding(
dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
)
self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)

self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
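For orientation, a minimal usage sketch of the rewritten embedding above, assuming a transformers version that includes ModernBERT. The shapes are illustrative (1 sequence of 8 tokens, 2 heads of size 64) and are not part of the commit; apply_rotary_pos_emb is the Gemma helper that the modular file imports in the next diff.

import torch
from transformers import ModernBertConfig
from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb
from transformers.models.modernbert.modeling_modernbert import ModernBertRotaryEmbedding

# head_dim = hidden_size // num_attention_heads = 64, matching dim below.
config = ModernBertConfig(hidden_size=128, num_attention_heads=2)
rotary = ModernBertRotaryEmbedding(config=config, dim=64, base=10_000.0)

q = torch.randn(1, 2, 8, 64)  # [batch, heads, seq_len, head_dim]
k = torch.randn(1, 2, 8, 64)
position_ids = torch.arange(8).unsqueeze(0)  # [batch, seq_len]

cos, sin = rotary(q, position_ids)            # each [1, 8, 64], in q's dtype
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)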
36 changes: 6 additions & 30 deletions src/transformers/models/modernbert/modular_modernbert.py
@@ -41,7 +41,7 @@
logging,
)
from ...utils.import_utils import is_triton_available
from ..gemma.modeling_gemma import apply_rotary_pos_emb
from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb


if is_flash_attn_2_available():
@@ -504,32 +504,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.Wo(self.drop(self.act(input) * gate))


class ModernBertRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
super().__init__(self, config=config, device=device)
inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base)


def eager_attention_forward(
@@ -698,9 +676,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
)
else:
self.rotary_emb = ModernBertRotaryEmbedding(
dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
)
self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)

self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
134 changes: 130 additions & 4 deletions tests/models/modernbert/test_modeling_modernbert.py
@@ -16,8 +16,9 @@
import unittest

import pytest
from packaging import version

from transformers import ModernBertConfig, is_torch_available
from transformers import AutoTokenizer, ModernBertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
CaptureLogger,
@@ -362,6 +363,131 @@ def test_flash_attn_2_conversion(self):

@require_torch
class ModernBertModelIntegrationTest(unittest.TestCase):
"""
These still need to be written, once public models are available.
"""
@slow
def test_inference_masked_lm(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

model = ModernBertForMaskedLM.from_pretrained(
"answerdotai/ModernBERT-base", reference_compile=False, attn_implementation="sdpa"
)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

inputs = tokenizer("Hello World!", return_tensors="pt")
with torch.no_grad():
output = model(**inputs)[0]
expected_shape = torch.Size((1, 5, 50368))
self.assertEqual(output.shape, expected_shape)

# compare the actual values for a slice.
expected_slice = torch.tensor(
[[[3.8387, -0.2017, 12.2839], [3.6300, 0.6869, 14.7123], [-5.1137, -3.8122, 11.9874]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

@slow
def test_inference_no_head(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

model = ModernBertModel.from_pretrained(
"answerdotai/ModernBERT-base", reference_compile=False, attn_implementation="sdpa"
)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

inputs = tokenizer("Hello World!", return_tensors="pt")
with torch.no_grad():
output = model(**inputs)[0]
expected_shape = torch.Size((1, 5, 768))
self.assertEqual(output.shape, expected_shape)

# compare the actual values for a slice.
expected_slice = torch.tensor(
[[[0.3151, -0.6417, -0.7027], [-0.7834, -1.5810, 0.4576], [1.0614, -0.7268, -0.0871]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

@slow
def test_inference_token_classification(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

model = ModernBertForTokenClassification.from_pretrained(
"hf-internal-testing/tiny-random-ModernBertForTokenClassification",
reference_compile=False,
attn_implementation="sdpa",
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-ModernBertForTokenClassification")

inputs = tokenizer("Hello World!", return_tensors="pt")
with torch.no_grad():
output = model(**inputs)[0]
expected_shape = torch.Size((1, 5, 2))
self.assertEqual(output.shape, expected_shape)

expected = torch.tensor(
[[[2.0159, 4.6569], [-0.9430, 3.1595], [-3.8770, 3.2653], [1.5752, 4.5167], [-1.6939, 1.2524]]]
)
self.assertTrue(torch.allclose(output, expected, atol=1e-4))

@slow
def test_inference_sequence_classification(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

model = ModernBertForSequenceClassification.from_pretrained(
"hf-internal-testing/tiny-random-ModernBertForSequenceClassification",
reference_compile=False,
attn_implementation="sdpa",
)
tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/tiny-random-ModernBertForSequenceClassification"
)

inputs = tokenizer("Hello World!", return_tensors="pt")
with torch.no_grad():
output = model(**inputs)[0]
expected_shape = torch.Size((1, 2))
self.assertEqual(output.shape, expected_shape)

expected = torch.tensor([[1.6466, 4.5662]])
self.assertTrue(torch.allclose(output, expected, atol=1e-4))

@slow
def test_export(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

bert_model = "answerdotai/ModernBERT-base"
device = "cpu"
attn_implementation = "sdpa"
max_length = 512

tokenizer = AutoTokenizer.from_pretrained(bert_model)
inputs = tokenizer(
"the man worked as a [MASK].",
return_tensors="pt",
padding="max_length",
max_length=max_length,
)

model = ModernBertForMaskedLM.from_pretrained(
bert_model,
device_map=device,
attn_implementation=attn_implementation,
)

logits = model(**inputs).logits
eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
self.assertEqual(eg_predicted_mask.split(), ["lawyer", "mechanic", "teacher", "doctor", "waiter"])

exported_program = torch.export.export(
model,
args=(inputs["input_ids"],),
kwargs={"attention_mask": inputs["attention_mask"]},
strict=True,
)

result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
self.assertEqual(eg_predicted_mask, ep_predicted_mask)
