Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modernbert Release Fixes #35344

Merged
merged 4 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/transformers/models/modernbert/modeling_modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,13 +1118,16 @@ def __init__(self, config: ModernBertConfig):
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
self.drop = torch.nn.Dropout(config.classifier_dropout)

def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.config.classifier_pooling == "cls":
hidden_states = hidden_states[:, 0]
elif self.config.classifier_pooling == "mean":
hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
dim=1, keepdim=True
)
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, apply_pooling: bool = True
) -> torch.Tensor:
if apply_pooling:
if self.config.classifier_pooling == "cls":
hidden_states = hidden_states[:, 0]
elif self.config.classifier_pooling == "mean":
hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
dim=1, keepdim=True
)

return self.drop(self.norm(self.act(self.dense(hidden_states))))

Expand Down Expand Up @@ -1242,7 +1245,7 @@ def __init__(self, config: ModernBertConfig):
self.num_labels = config.num_labels

self.model = ModernBertModel(config)
self.drop = nn.Dropout(config.classifier_dropout)
self.head = ModernBertPoolingHead(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
Expand Down Expand Up @@ -1293,7 +1296,7 @@ def forward(
)
last_hidden_state = outputs[0]

last_hidden_state = self.drop(last_hidden_state)
last_hidden_state = self.head(last_hidden_state, attention_mask, apply_pooling=False)
logits = self.classifier(last_hidden_state)

loss = None
Expand Down
51 changes: 39 additions & 12 deletions src/transformers/models/modernbert/modular_modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
logging,
)
from ...utils.import_utils import is_triton_available
from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
from ..gemma.modeling_gemma import apply_rotary_pos_emb


if is_flash_attn_2_available():
Expand Down Expand Up @@ -493,8 +493,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.Wo(self.drop(self.act(input) * gate))


class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
pass
class ModernBertRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def eager_attention_forward(
Expand Down Expand Up @@ -1247,13 +1271,16 @@ def __init__(self, config: ModernBertConfig):
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
self.drop = torch.nn.Dropout(config.classifier_dropout)

def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.config.classifier_pooling == "cls":
hidden_states = hidden_states[:, 0]
elif self.config.classifier_pooling == "mean":
hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
dim=1, keepdim=True
)
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, apply_pooling: bool = True
) -> torch.Tensor:
if apply_pooling:
if self.config.classifier_pooling == "cls":
hidden_states = hidden_states[:, 0]
elif self.config.classifier_pooling == "mean":
hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
dim=1, keepdim=True
)
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved

return self.drop(self.norm(self.act(self.dense(hidden_states))))

Expand Down Expand Up @@ -1371,7 +1398,7 @@ def __init__(self, config: ModernBertConfig):
self.num_labels = config.num_labels

self.model = ModernBertModel(config)
self.drop = nn.Dropout(config.classifier_dropout)
self.head = ModernBertPoolingHead(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
Expand Down Expand Up @@ -1422,7 +1449,7 @@ def forward(
)
last_hidden_state = outputs[0]

last_hidden_state = self.drop(last_hidden_state)
last_hidden_state = self.head(last_hidden_state, attention_mask, apply_pooling=False)
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
logits = self.classifier(last_hidden_state)

loss = None
Expand Down