Modernbert Release Fixes (#35344)
* fix ForSequenceClassification

* unmodularize rope layer

* fix linting warning

* Avoid complex PoolingHead, only one prediction head needed

---------

Co-authored-by: Tom Aarsen <[email protected]>
warner-benjamin and tomaarsen authored Dec 19, 2024
1 parent 1fa807f commit 0ade1ca
Showing 2 changed files with 55 additions and 53 deletions.
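For orientation before reading the diffs: the net effect on ModernBertForSequenceClassification is that pooling (CLS token or masked mean) now happens in the model's forward pass, followed by the shared ModernBertPredictionHead, a dropout, and the classifier. The sketch below restates that flow with illustrative sizes and a stand-in head (the real head uses the config's activation, norm, and bias settings); it is not the library code itself.

import torch
import torch.nn as nn

# Minimal sketch of the new classification flow (illustrative sizes, stand-in head).
hidden_size, num_labels = 768, 2
last_hidden_state = torch.randn(4, 16, hidden_size)   # [batch, seq, hidden]
attention_mask = torch.ones(4, 16)

classifier_pooling = "mean"                            # or "cls"
if classifier_pooling == "cls":
    pooled = last_hidden_state[:, 0]
else:                                                  # masked mean over the sequence
    pooled = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
        dim=1, keepdim=True
    )

head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.LayerNorm(hidden_size))
drop = nn.Dropout(p=0.0)
classifier = nn.Linear(hidden_size, num_labels)
logits = classifier(drop(head(pooled)))                # [batch, num_labels]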
39 changes: 14 additions & 25 deletions src/transformers/models/modernbert/modeling_modernbert.py
@@ -610,8 +610,6 @@ def init_weight(module: nn.Module, std: float):
             init_weight(module.Wqkv, stds["in"])
             init_weight(module.Wo, stds["out"])
         elif isinstance(module, ModernBertPredictionHead):
-            init_weight(module.dense, stds["in"])
-        elif isinstance(module, ModernBertPoolingHead):
             init_weight(module.dense, stds["out"])
         elif isinstance(module, ModernBertForMaskedLM):
             init_weight(module.decoder, stds["out"])
@@ -1109,26 +1107,6 @@ def forward(
         )


-class ModernBertPoolingHead(nn.Module):
-    def __init__(self, config: ModernBertConfig):
-        super().__init__()
-        self.config = config
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
-        self.act = ACT2FN[config.classifier_activation]
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-        self.drop = torch.nn.Dropout(config.classifier_dropout)
-
-    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        if self.config.classifier_pooling == "cls":
-            hidden_states = hidden_states[:, 0]
-        elif self.config.classifier_pooling == "mean":
-            hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
-                dim=1, keepdim=True
-            )
-
-        return self.drop(self.norm(self.act(self.dense(hidden_states))))
-
-
 @add_start_docstrings(
     "The ModernBert Model with a sequence classification head on top that performs pooling.",
     MODERNBERT_START_DOCSTRING,
@@ -1140,7 +1118,8 @@ def __init__(self, config: ModernBertConfig):
         self.config = config

         self.model = ModernBertModel(config)
-        self.head = ModernBertPoolingHead(config)
+        self.head = ModernBertPredictionHead(config)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1194,7 +1173,15 @@ def forward(
         )
         last_hidden_state = outputs[0]

-        pooled_output = self.head(last_hidden_state, attention_mask)
+        if self.config.classifier_pooling == "cls":
+            last_hidden_state = last_hidden_state[:, 0]
+        elif self.config.classifier_pooling == "mean":
+            last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+                dim=1, keepdim=True
+            )
+
+        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.drop(pooled_output)
         logits = self.classifier(pooled_output)

         loss = None
@@ -1242,7 +1229,8 @@ def __init__(self, config: ModernBertConfig):
         self.num_labels = config.num_labels

         self.model = ModernBertModel(config)
-        self.drop = nn.Dropout(config.classifier_dropout)
+        self.head = ModernBertPredictionHead(config)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1293,6 +1281,7 @@ def forward(
         )
         last_hidden_state = outputs[0]

+        last_hidden_state = self.head(last_hidden_state)
         last_hidden_state = self.drop(last_hidden_state)
         logits = self.classifier(last_hidden_state)

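The token-classification hunks above apply the same idea per token: the prediction head now runs on every position before dropout and the per-token classifier. A small stand-in sketch follows (sizes and modules are illustrative, not the library's):

import torch
import torch.nn as nn

# Stand-in for the updated ModernBertForTokenClassification path (illustrative only).
hidden_size, num_labels = 768, 9
last_hidden_state = torch.randn(4, 16, hidden_size)    # [batch, seq, hidden]

head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.LayerNorm(hidden_size))
drop = nn.Dropout(p=0.1)
classifier = nn.Linear(hidden_size, num_labels)

logits = classifier(drop(head(last_hidden_state)))      # [batch, seq, num_labels]
print(logits.shape)                                     # torch.Size([4, 16, 9])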
69 changes: 41 additions & 28 deletions src/transformers/models/modernbert/modular_modernbert.py
@@ -40,7 +40,7 @@
     logging,
 )
 from ...utils.import_utils import is_triton_available
-from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
+from ..gemma.modeling_gemma import apply_rotary_pos_emb


 if is_flash_attn_2_available():
@@ -493,8 +493,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.Wo(self.drop(self.act(input) * gate))


-class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
-    pass
+class ModernBertRotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        self.inv_freq.to(x.device)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


 def eager_attention_forward(
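For reference, the cos/sin tensors produced by the unmodularized rotary embedding above can be re-derived standalone; this toy-sized check (dim, batch, and sequence length are made up) mirrors the math without needing the model class:

import torch

# Toy re-derivation of the rotary cos/sin computation shown in the diff above.
dim, base = 64, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

batch, seq_len = 2, 16
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)
inv_freq_expanded = inv_freq[None, :, None].float().expand(batch, -1, 1)   # [batch, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float()                   # [batch, 1, seq_len]

freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)        # [batch, seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                                    # [batch, seq_len, dim]
print(emb.cos().shape, emb.sin().shape)                                    # torch.Size([2, 16, 64]) each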
@@ -811,8 +835,6 @@ def init_weight(module: nn.Module, std: float):
             init_weight(module.Wqkv, stds["in"])
             init_weight(module.Wo, stds["out"])
         elif isinstance(module, ModernBertPredictionHead):
-            init_weight(module.dense, stds["in"])
-        elif isinstance(module, ModernBertPoolingHead):
             init_weight(module.dense, stds["out"])
         elif isinstance(module, ModernBertForMaskedLM):
             init_weight(module.decoder, stds["out"])
@@ -1238,26 +1260,6 @@ def forward(
         )


-class ModernBertPoolingHead(nn.Module):
-    def __init__(self, config: ModernBertConfig):
-        super().__init__()
-        self.config = config
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
-        self.act = ACT2FN[config.classifier_activation]
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-        self.drop = torch.nn.Dropout(config.classifier_dropout)
-
-    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        if self.config.classifier_pooling == "cls":
-            hidden_states = hidden_states[:, 0]
-        elif self.config.classifier_pooling == "mean":
-            hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
-                dim=1, keepdim=True
-            )
-
-        return self.drop(self.norm(self.act(self.dense(hidden_states))))
-
-
 @add_start_docstrings(
     "The ModernBert Model with a sequence classification head on top that performs pooling.",
     MODERNBERT_START_DOCSTRING,
@@ -1269,7 +1271,8 @@ def __init__(self, config: ModernBertConfig):
         self.config = config

         self.model = ModernBertModel(config)
-        self.head = ModernBertPoolingHead(config)
+        self.head = ModernBertPredictionHead(config)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1323,7 +1326,15 @@ def forward(
         )
         last_hidden_state = outputs[0]

-        pooled_output = self.head(last_hidden_state, attention_mask)
+        if self.config.classifier_pooling == "cls":
+            last_hidden_state = last_hidden_state[:, 0]
+        elif self.config.classifier_pooling == "mean":
+            last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+                dim=1, keepdim=True
+            )
+
+        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.drop(pooled_output)
         logits = self.classifier(pooled_output)

         loss = None
@@ -1371,7 +1382,8 @@ def __init__(self, config: ModernBertConfig):
         self.num_labels = config.num_labels

         self.model = ModernBertModel(config)
-        self.drop = nn.Dropout(config.classifier_dropout)
+        self.head = ModernBertPredictionHead(config)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1422,6 +1434,7 @@ def forward(
         )
         last_hidden_state = outputs[0]

+        last_hidden_state = self.head(last_hidden_state)
         last_hidden_state = self.drop(last_hidden_state)
         logits = self.classifier(last_hidden_state)

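Finally, a hedged end-to-end usage sketch for once these fixes are in a transformers release; the checkpoint name and label count are assumptions for illustration, not taken from this commit:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "answerdotai/ModernBERT-base"   # assumed checkpoint name, not part of this diff
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)

inputs = tokenizer("ModernBERT pools with the CLS token or a masked mean before classifying.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)                         # torch.Size([1, 2])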