diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index db8d98893f96fe..237fba6f645fa5 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -610,8 +610,6 @@ def init_weight(module: nn.Module, std: float):
             init_weight(module.Wqkv, stds["in"])
             init_weight(module.Wo, stds["out"])
         elif isinstance(module, ModernBertPredictionHead):
-            init_weight(module.dense, stds["in"])
-        elif isinstance(module, ModernBertPoolingHead):
             init_weight(module.dense, stds["out"])
         elif isinstance(module, ModernBertForMaskedLM):
             init_weight(module.decoder, stds["out"])
@@ -1109,26 +1107,6 @@ def forward(
         )

-class ModernBertPoolingHead(nn.Module):
-    def __init__(self, config: ModernBertConfig):
-        super().__init__()
-        self.config = config
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
-        self.act = ACT2FN[config.classifier_activation]
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-        self.drop = torch.nn.Dropout(config.classifier_dropout)
-
-    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        if self.config.classifier_pooling == "cls":
-            hidden_states = hidden_states[:, 0]
-        elif self.config.classifier_pooling == "mean":
-            hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
-                dim=1, keepdim=True
-            )
-
-        return self.drop(self.norm(self.act(self.dense(hidden_states))))
-
-
 @add_start_docstrings(
     "The ModernBert Model with a sequence classification head on top that performs pooling.",
     MODERNBERT_START_DOCSTRING,
 )
@@ -1140,7 +1118,8 @@ def __init__(self, config: ModernBertConfig):
         self.config = config

         self.model = ModernBertModel(config)
-        self.head = ModernBertPoolingHead(config)
+        self.head = ModernBertPredictionHead(config)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1194,7 +1173,15 @@ def forward(
         )
         last_hidden_state = outputs[0]

-        pooled_output = self.head(last_hidden_state, attention_mask)
+        if self.config.classifier_pooling == "cls":
+            last_hidden_state = last_hidden_state[:, 0]
+        elif self.config.classifier_pooling == "mean":
+            last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+                dim=1, keepdim=True
+            )
+
+        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.drop(pooled_output)
         logits = self.classifier(pooled_output)

         loss = None
@@ -1242,7 +1229,8 @@ def __init__(self, config: ModernBertConfig):
         self.num_labels = config.num_labels

         self.model = ModernBertModel(config)
-        self.drop = nn.Dropout(config.classifier_dropout)
+        self.head = ModernBertPredictionHead(config)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1293,6 +1281,7 @@ def forward(
         )
         last_hidden_state = outputs[0]

+        last_hidden_state = self.head(last_hidden_state)
         last_hidden_state = self.drop(last_hidden_state)
         logits = self.classifier(last_hidden_state)

diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index 3c23f9178b1b51..dac356146f3015 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -40,7 +40,7 @@
     logging,
 )
 from ...utils.import_utils import is_triton_available
-from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
+from ..gemma.modeling_gemma import apply_rotary_pos_emb

 if is_flash_attn_2_available():
@@ -493,8 +493,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.Wo(self.drop(self.act(input) * gate))

-class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
-    pass
+class ModernBertRotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        self.inv_freq.to(x.device)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

 def eager_attention_forward(
@@ -811,8 +835,6 @@ def init_weight(module: nn.Module, std: float):
             init_weight(module.Wqkv, stds["in"])
             init_weight(module.Wo, stds["out"])
         elif isinstance(module, ModernBertPredictionHead):
-            init_weight(module.dense, stds["in"])
-        elif isinstance(module, ModernBertPoolingHead):
             init_weight(module.dense, stds["out"])
         elif isinstance(module, ModernBertForMaskedLM):
             init_weight(module.decoder, stds["out"])
@@ -1238,26 +1260,6 @@ def forward(
         )

-class ModernBertPoolingHead(nn.Module):
-    def __init__(self, config: ModernBertConfig):
-        super().__init__()
-        self.config = config
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
-        self.act = ACT2FN[config.classifier_activation]
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-        self.drop = torch.nn.Dropout(config.classifier_dropout)
-
-    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        if self.config.classifier_pooling == "cls":
-            hidden_states = hidden_states[:, 0]
-        elif self.config.classifier_pooling == "mean":
-            hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
-                dim=1, keepdim=True
-            )
-
-        return self.drop(self.norm(self.act(self.dense(hidden_states))))
-
-
 @add_start_docstrings(
     "The ModernBert Model with a sequence classification head on top that performs pooling.",
     MODERNBERT_START_DOCSTRING,
 )
@@ -1269,7 +1271,8 @@ def __init__(self, config: ModernBertConfig):
         self.config = config

         self.model = ModernBertModel(config)
-        self.head = ModernBertPoolingHead(config)
+        self.head = ModernBertPredictionHead(config)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1323,7 +1326,15 @@ def forward(
         )
         last_hidden_state = outputs[0]

-        pooled_output = self.head(last_hidden_state, attention_mask)
+        if self.config.classifier_pooling == "cls":
+            last_hidden_state = last_hidden_state[:, 0]
+        elif self.config.classifier_pooling == "mean":
+            last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+                dim=1, keepdim=True
+            )
+
+        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.drop(pooled_output)
         logits = self.classifier(pooled_output)

         loss = None
@@ -1371,7 +1382,8 @@ def __init__(self, config: ModernBertConfig):
         self.num_labels = config.num_labels

         self.model = ModernBertModel(config)
-        self.drop = nn.Dropout(config.classifier_dropout)
+        self.head = ModernBertPredictionHead(config)
+        self.drop = torch.nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1422,6 +1434,7 @@ def forward(
         )
         last_hidden_state = outputs[0]

+        last_hidden_state = self.head(last_hidden_state)
         last_hidden_state = self.drop(last_hidden_state)
         logits = self.classifier(last_hidden_state)
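
For orientation, the sequence-classification path after this change pools the hidden states first and only then applies the shared prediction head, dropout, and the classifier. The sketch below mirrors that order in plain PyTorch as a minimal, self-contained example; `TinyPredictionHead`, `classify`, and the GELU/LayerNorm layout are illustrative stand-ins, not the actual `transformers` classes.

```python
# Illustrative sketch (not the transformers classes): pool -> head -> dropout -> classifier,
# matching the order this diff introduces for ModernBertForSequenceClassification.
import torch
import torch.nn as nn


class TinyPredictionHead(nn.Module):
    """Stand-in for a prediction head: dense -> activation -> norm, no dropout inside."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(self.act(self.dense(hidden_states)))


def classify(last_hidden_state, attention_mask, head, drop, classifier, pooling="mean"):
    # Pooling now happens in the model's forward, before the head.
    if pooling == "cls":
        pooled = last_hidden_state[:, 0]
    else:  # "mean": mask-aware average over the sequence dimension
        pooled = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
            dim=1, keepdim=True
        )
    return classifier(drop(head(pooled)))


hidden_size, num_labels = 32, 3
head = TinyPredictionHead(hidden_size)
drop = nn.Dropout(0.1)
classifier = nn.Linear(hidden_size, num_labels)

x = torch.randn(2, 7, hidden_size)  # [batch, seq_len, hidden]
mask = torch.ones(2, 7)             # all tokens valid
print(classify(x, mask, head, drop, classifier).shape)  # torch.Size([2, 3])
```

Note that the "mean" branch divides by the per-sequence mask sum, so padded positions neither contribute to nor dilute the pooled representation; dropout is applied once, after the head, which is why the standalone `self.drop` module appears in both classification heads above.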
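The inlined `ModernBertRotaryEmbedding` returns float32-derived cos/sin tensors that are consumed by `apply_rotary_pos_emb` (still imported from the Gemma module). As a quick shape sanity check, the standalone sketch below reproduces the same `inv_freq` construction and applies a standard rotate-half rotation; `RotaryEmbedding`, `rotate_half`, and `apply_rotary` here are local illustrative helpers, not the `transformers` implementations.

```python
# Shape sanity check for the rotary embedding pattern used above (illustrative only).
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    """Mirrors the inlined ModernBertRotaryEmbedding: inv_freq buffer, float32 cos/sin."""

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # freqs: [batch, seq_len, dim/2], duplicated so cos/sin cover the full head dim.
        inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        freqs = (inv_freq @ position_ids[:, None, :].float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


def rotate_half(x):
    # Standard rotate-half trick used by rotary position embeddings.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(q, k, cos, sin):
    # unsqueeze(1) broadcasts cos/sin over the head dimension: [batch, 1, seq_len, head_dim].
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


batch, heads, seq_len, head_dim = 2, 4, 16, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

rope = RotaryEmbedding(head_dim)
cos, sin = rope(q, position_ids)
q_rot, k_rot = apply_rotary(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 4, 16, 8]) twice
```

The frequencies are computed in float32 inside a disabled-autocast region (in the diff) precisely so that long-context positions do not lose precision under bfloat16; the outputs are cast back to the activation dtype at the end.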