From 6b5a8233c0ca00c7f461e6879f59a2e8bcba3c47 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Mon, 9 Dec 2024 01:12:08 -0600 Subject: [PATCH 01/88] initial cut of modernbert for transformers --- src/transformers/__init__.py | 18 + src/transformers/loss/loss_utils.py | 17 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 4 + .../models/auto/tokenization_auto.py | 1 + .../models/modernbert/__init__.py | 61 + .../modernbert/configuration_modernbert.py | 184 +++ .../models/modernbert/modeling_modernbert.py | 1099 +++++++++++++++ .../models/modernbert/modular_modernbert.py | 1249 +++++++++++++++++ tests/models/modernbert/__init__.py | 0 11 files changed, 2636 insertions(+) create mode 100644 src/transformers/models/modernbert/__init__.py create mode 100644 src/transformers/models/modernbert/configuration_modernbert.py create mode 100644 src/transformers/models/modernbert/modeling_modernbert.py create mode 100644 src/transformers/models/modernbert/modular_modernbert.py create mode 100644 tests/models/modernbert/__init__.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e1ca1956807318..b70544f6788774 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -450,6 +450,7 @@ "models.fuyu": ["FuyuConfig"], "models.gemma": ["GemmaConfig"], "models.gemma2": ["Gemma2Config"], + "models.modernbert": ["ModernBertConfig"], "models.git": [ "GitConfig", "GitProcessor", @@ -2299,6 +2300,15 @@ "Gemma2PreTrainedModel", ] ) + _import_structure["models.modernbert"].extend( + [ + "ModernBertForCausalLM", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", + "ModernBertModel", + "ModernBertPreTrainedModel", + ] + ) _import_structure["models.git"].extend( [ "GitForCausalLM", @@ -5328,6 +5338,7 @@ from .models.fuyu import FuyuConfig from .models.gemma import GemmaConfig from .models.gemma2 import Gemma2Config + from .models.modernbert import ModernBertConfig from .models.git import ( GitConfig, GitProcessor, @@ -7056,6 +7067,13 @@ Gemma2Model, Gemma2PreTrainedModel, ) + from .models.modernbert import ( + ModernBertForCausalLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ModernBertModel, + ModernBertPreTrainedModel, + ) from .models.git import ( GitForCausalLM, GitModel, diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index efa23d24e360b4..7f6aaaa44264ca 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -47,6 +47,22 @@ def ForCausalLMLoss( return loss +def ForMaskedLMLoss( + logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs +): + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + + # Flatten the tokens + logits = logits.view(-1, vocab_size) + labels = labels.view(-1) + # Enable model parallelism + + labels = labels.to(logits.device) + loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs) + return loss + + def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs): num_labels = config.num_labels if config.problem_type is None: @@ -101,6 +117,7 @@ def ForTokenClassification(logits, labels, config, **kwargs): LOSS_MAPPING = { "ForCausalLM": ForCausalLMLoss, + "ForMaskedLM": ForMaskedLMLoss, "ForQuestionAnswering": ForQuestionAnsweringLoss, "ForSequenceClassification": 
ForSequenceClassificationLoss, "ForTokenClassification": ForTokenClassification, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 2d2a3b41d4378b..2d295287b39de0 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -96,6 +96,7 @@ fuyu, gemma, gemma2, + modernbert, git, glm, glpn, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 4ab6d392282657..9c1374c5b9ccb6 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -113,6 +113,7 @@ ("fuyu", "FuyuConfig"), ("gemma", "GemmaConfig"), ("gemma2", "Gemma2Config"), + ("modernbert", "ModernBertConfig"), ("git", "GitConfig"), ("glm", "GlmConfig"), ("glpn", "GLPNConfig"), @@ -417,6 +418,7 @@ ("fuyu", "Fuyu"), ("gemma", "Gemma"), ("gemma2", "Gemma2"), + ("modernbert", "ModernBERT"), ("git", "GIT"), ("glm", "GLM"), ("glpn", "GLPN"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2c519a7dc42ca5..fb08610df495bd 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -110,6 +110,7 @@ ("funnel", ("FunnelModel", "FunnelBaseModel")), ("gemma", "GemmaModel"), ("gemma2", "Gemma2Model"), + ("modernbert", "ModernBertModel"), ("git", "GitModel"), ("glm", "GlmModel"), ("glpn", "GLPNModel"), @@ -487,6 +488,7 @@ ("fuyu", "FuyuForCausalLM"), ("gemma", "GemmaForCausalLM"), ("gemma2", "Gemma2ForCausalLM"), + ("modernbert", "ModernBertForCausalLM"), ("git", "GitForCausalLM"), ("glm", "GlmForCausalLM"), ("gpt-sw3", "GPT2LMHeadModel"), @@ -945,6 +947,7 @@ ("funnel", "FunnelForSequenceClassification"), ("gemma", "GemmaForSequenceClassification"), ("gemma2", "Gemma2ForSequenceClassification"), + ("modernbert", "ModernBertForSequenceClassification"), ("glm", "GlmForSequenceClassification"), ("gpt-sw3", "GPT2ForSequenceClassification"), ("gpt2", "GPT2ForSequenceClassification"), @@ -1136,6 +1139,7 @@ ("funnel", "FunnelForTokenClassification"), ("gemma", "GemmaForTokenClassification"), ("gemma2", "Gemma2ForTokenClassification"), + ("modernbert", "ModernBertForTokenClassification"), ("glm", "GlmForTokenClassification"), ("gpt-sw3", "GPT2ForTokenClassification"), ("gpt2", "GPT2ForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e246bf3094c9cb..cce23de13dbcfa 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -203,6 +203,7 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), + ("modernbert", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), diff --git a/src/transformers/models/modernbert/__init__.py b/src/transformers/models/modernbert/__init__.py new file mode 100644 index 00000000000000..40de068897fa00 --- /dev/null +++ b/src/transformers/models/modernbert/__init__.py @@ -0,0 +1,61 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
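The `ForMaskedLMLoss` added to `loss_utils.py` above mirrors `ForCausalLMLoss` without the one-token shift: logits are upcast to float, flattened to `(N, vocab_size)`, labels to `(N,)`, and positions labelled `-100` are ignored. A minimal sketch of that computation, using `torch.nn.functional.cross_entropy` as a stand-in for the `fixed_cross_entropy` helper (which is not shown in this diff); the shapes and label values are illustrative only:

```python
# Sketch of the masked-LM loss path: only labelled (masked) positions contribute.
import torch
import torch.nn.functional as F

vocab_size = 50368
logits = torch.randn(2, 8, vocab_size)                    # (batch, seq, vocab)
labels = torch.full((2, 8), -100, dtype=torch.long)       # -100 = ignored position
labels[0, 3] = 42                                         # only masked tokens carry labels
labels[1, 5] = 7

flat_logits = logits.float().view(-1, vocab_size)         # upcast + flatten, as above
flat_labels = labels.view(-1)
loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)
print(loss)  # scalar; averaged over the two labelled tokens only
```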
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_modernbert": ["ModernBertConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_modernbert"] = [ + "ModernBertForCausalLM", + "ModernBertModel", + "ModernBertPreTrainedModel", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", + ] + +if TYPE_CHECKING: + from .configuration_modernbert import ModernBertConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_modernbert import ( + ModernBertForCausalLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ModernBertModel, + ModernBertPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py new file mode 100644 index 00000000000000..f030f80c4d2403 --- /dev/null +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -0,0 +1,184 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_modernbert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig + + +class ModernBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ModernBert-7B. + e.g. [answerdotai/modernbert-base](https://huggingface.co/answerdotai/modernbert-base) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ModernBertModel`] + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): in ModernBert, every other layer uses sliding window attention. This is the + size of the sliding window. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. 
+ attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + + ```python + >>> from transformers import ModernBertModel, ModernBertConfig + >>> # Initializing a ModernBert modernbert-7b style configuration + >>> configuration = ModernBertConfig() + >>> # Initializing a model from the modernbert-7b style configuration + >>> model = ModernBertModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "modernbert" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50368, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=22, + num_attention_heads=12, + hidden_activation="gelu_python", + max_position_embeddings=8192, + initializer_range=0.02, + initalizer_cutoff_factor=2.0, + norm_eps=1e-5, + norm_bias=False, + pad_token_id=50283, + eos_token_id=50281, + bos_token_id=50282, + cls_token_id=50281, + sep_token_id=50282, + tie_word_embeddings=True, + global_rope_theta=160000.0, + attention_bias=False, + attention_dropout=0.0, + attn_out_dropout=0.1, + global_attn_every_n_layers=3, + local_attention=128, + local_rope_theta=10000.0, + skip_first_prenorm=True, + embedding_norm=True, + embedding_dropout=0.0, + mlp_bias=False, + mlp_dropout=0.0, + unpad_inputs=True, + unpad_no_grad=True, + decoder_bias=True, + classifier_dropout=0.0, + classifier_pooling="mean", + classifier_norm=True, + classifier_bias=True, + classifier_activation=None, + deterministic_flash_attn=False, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.initalizer_cutoff_factor = initalizer_cutoff_factor + self.norm_eps = norm_eps + self.norm_bias = norm_bias + self.global_rope_theta = global_rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.attn_out_dropout = attn_out_dropout + self.hidden_activation = hidden_activation + self.global_attn_every_n_layers = global_attn_every_n_layers + self.local_attention = local_attention + self.local_rope_theta = local_rope_theta + self.skip_first_prenorm = skip_first_prenorm + self.embedding_norm = embedding_norm + self.embedding_dropout = embedding_dropout + self.mlp_bias = mlp_bias + self.mlp_dropout = mlp_dropout + self.unpad_inputs = unpad_inputs + self.unpad_no_grad = unpad_no_grad + self.decoder_bias = decoder_bias + self.classifier_dropout = classifier_dropout + self.classifier_pooling = classifier_pooling + self.classifier_bias = classifier_bias + self.classifier_norm = classifier_norm + self.classifier_activation = classifier_activation if classifier_activation is not None else hidden_activation + self.deterministic_flash_attn = deterministic_flash_attn diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py new file mode 100644 index 00000000000000..d5cbd8c0a94c87 --- 
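The configuration above encodes ModernBERT's alternating attention schedule through `global_attn_every_n_layers` and `local_attention`. A standalone sketch of how the per-layer window is derived, mirroring the `layer_id % global_attn_every_n_layers != 0` check that appears later in this patch in `ModernBertAttention.__init__`; the numbers are the config defaults:

```python
# Which layers use sliding-window (local) attention vs. full (global) attention.
num_hidden_layers = 22
global_attn_every_n_layers = 3
local_attention = 128  # full window size, split symmetrically around each token

for layer_id in range(num_hidden_layers):
    if layer_id % global_attn_every_n_layers != 0:
        window = (local_attention // 2, local_attention // 2)  # (64, 64) local window
    else:
        window = (-1, -1)  # flash-attn convention for "no window", i.e. global attention
    print(layer_id, window)
```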
/dev/null +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -0,0 +1,1099 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_modernbert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from enum import Enum +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, MaskedLMOutput +from ...modeling_utils import PreTrainedModel +from ...utils import is_flash_attn_2_available, logging +from .configuration_modernbert import ModernBertConfig + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary + +logger = logging.get_logger(__name__) + + +class ModernBertModuleType(str, Enum): + in_module = "in" + out_module = "out" + embedding = "emb" + final_out = "final_out" + + +class ModernBertPoolingType(str, Enum): + cls = "cls" + mean = "mean" + max = "max" + + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, 3, nheads, headdim) + total_nnz, three, nheads, headdim = qkv.shape + assert three == 3 + if qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + else: + q, k = qkv[:, 0, :, :], qkv[:, 1, :, :] + apply_rotary( + q, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + apply_rotary( + k, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return qkv + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + if do.is_contiguous(): + total_nnz, three, nheads, headdim = do.shape + # Call 1 kernel instead of 2 kernels + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, we get the same tensor 
+ dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + else: + dq, dk = do[:, 0, :, :], do[:, 1, :, :] + apply_rotary( + dq, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + apply_rotary( + dk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (total_nnz, dim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """ + The rotary position embeddings applied directly to unpadded sequences. + """ + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache wll be recomputed during the forward pass. + """ + super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Apply rotary embedding *inplace* to qkv. + qkv: (total_nnz, 3, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + + qkv = apply_rotary_unpadded( + qkv, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return qkv + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + + +# Copyright 2023 OLMo Authors +# License: Apache-2.0 + + +def _init_modernbert_weights( + config: ModernBertConfig, + module: Union[nn.Linear, nn.Embedding], + module_type: ModernBertModuleType, +) -> None: + """ + Initialize weights of a linear or embedding module. + + :param config: The model config. + :param module: The linear or embedding submodule to initialize. 
+ """ + if module_type is None: + raise RuntimeError("When using the full megatron init, every module must have a type.") + + cutoff_factor = config.initalizer_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + if module_type == ModernBertModuleType.in_module: + std = config.initializer_range # for att_proj (same as QKV), ff_proj + elif module_type == ModernBertModuleType.out_module: + std = config.initializer_range / math.sqrt(2.0 * config.num_hidden_layers) # for attn_out, ff_out + elif module_type == ModernBertModuleType.embedding: + std = config.initializer_range # token embeddings (wte) + elif module_type == ModernBertModuleType.final_out: + std = config.hidden_size**-0.5 # final output (ff_out) + else: + raise RuntimeError(f"Unknown module type '{module_type}'") + + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + +class ModernBertEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = nn.Dropout(config.embedding_dropout) if config.embedding_dropout > 0.0 else nn.Identity() + + def _init_weights(self, reset_params: bool = False): + _init_modernbert_weights(self.config, self.tok_embeddings, module_type=ModernBertModuleType.embedding) + if reset_params: + self.norm.reset_parameters() + + def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + return self.drop(self.norm(self.tok_embeddings(input_ids))) + + +class ModernBertMLP(nn.Module): + """Applies the GLU at the end of each ModernBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. 
+ """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) + self.act = ACT2FN[config.hidden_activation] + self.drop = nn.Dropout(config.mlp_dropout) if config.mlp_dropout > 0.0 else nn.Identity() + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) + + def _init_weights(self, reset_params: bool = False): + _init_modernbert_weights(self.config, self.Wi, module_type=ModernBertModuleType.in_module) + _init_modernbert_weights(self.config, self.Wo, module_type=ModernBertModuleType.out_module) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class ModernBertRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# def eager_attention_forward( +# config: ModernBertConfig, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# mask: Optional[torch.Tensor], +# **_kwargs, +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + +# if config.attn_logit_softcapping is not None: +# attn_weights = attn_weights / config.attn_logit_softcapping +# attn_weights = torch.tanh(attn_weights) +# attn_weights = attn_weights * config.attn_logit_softcapping +# if mask is not None: # no matter the length, we just slice it +# causal_mask = mask[:, :, :, : key_states.shape[-2]] +# attn_weights = attn_weights + causal_mask + +# # upcast attention to fp32 +# attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) +# attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) +# attn_output = torch.matmul(attn_weights, value_states) +# attn_output = attn_output.transpose(1, 2).contiguous() +# return attn_output, attn_weights + + +def flash_attention_forward( + config: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + max_seqlen: int, + local_attention: Tuple[int, int], + bs: int, + dim: int, + target_dtype: torch.dtype = torch.bfloat16, + **_kwargs, +) -> torch.Tensor: + # (total_seqlen, 3, nheads, 
headdim) + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported) + orig_dtype = qkv.dtype + qkv = qkv.to(target_dtype) + + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=config.attention_dropout if config.training else 0.0, + deterministic=config.deterministic_flash_attn, + window_size=local_attention, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=config.attention_dropout if config.training else 0.0, + deterministic=config.deterministic_flash_attn, + window_size=local_attention, + ) + return attn.view(bs, dim) + + +# def flex_attention_forward( +# config: ModernBertConfig, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# mask: Optional[torch.Tensor], +# output_attentions: bool = False, +# **_kwargs, +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + +# attn_output = flex_attention( +# query, +# key, +# value, +# enable_gqa=True, +# scale=config.scaling, +# return_lse=output_attentions, +# ) +# if not output_attentions: +# attn_weights = None +# else: +# attn_output, attn_weights = attn_output + +# attn_output = attn_output.transpose(1, 2).contiguous() +# return attn_output, attn_weights + + +# def sdpa_attention_forward( +# config: ModernBertConfig, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# mask: Optional[torch.Tensor], +# **_kwargs, +# ) -> Tuple[torch.Tensor, None]: + +# causal_mask = mask +# if mask is not None: +# causal_mask = causal_mask[:, :, :, : key.shape[-2]] + +# # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, +# # Reference: https://github.com/pytorch/pytorch/issues/112577. +# if query.device.type == "cuda" and causal_mask is not None: +# query = query.contiguous() +# key = key.contiguous() +# value = value.contiguous() + +# attn_output = torch.nn.functional.scaled_dot_product_attention( +# query, +# key, +# value, +# attn_mask=causal_mask, +# dropout_p=config.attention_dropout if config.training else 0.0, +# is_causal=False, +# scale=config.scaling, +# ) +# attn_output = attn_output.transpose(1, 2).contiguous() +# return attn_output, None + + +MODERNBERT_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + # "flex_attention": flex_attention_forward, + # "eager": eager_attention_forward, + # "sdpa": sdpa_attention_forward, +} + + +class ModernBertAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences. + + If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. + If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, + which requires padding and unpadding inputs, adding some overhead. + + See `forward` method for additional details. 
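`ModernBertMLP` above replaces BERT's separate intermediate and output blocks with a single gated unit: `Wi` projects to twice the intermediate size, the result is chunked into an input and a gate, and `Wo` projects `act(input) * gate` back to the hidden size. A minimal standalone sketch of that GLU computation, using the default config sizes and `nn.GELU` as a stand-in for `ACT2FN["gelu_python"]`:

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 768, 1152

Wi = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
act = nn.GELU()  # stand-in for ACT2FN["gelu_python"]
Wo = nn.Linear(intermediate_size, hidden_size, bias=False)

hidden_states = torch.randn(4, hidden_size)       # unpadded tokens: (total_nnz, dim)
inp, gate = Wi(hidden_states).chunk(2, dim=-1)    # each (total_nnz, intermediate_size)
out = Wo(act(inp) * gate)                         # back to (total_nnz, hidden_size)
print(out.shape)                                  # torch.Size([4, 768])
```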
+ """ + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.deterministic_flash_attn = config.deterministic_flash_attn + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.head_dim * self.num_heads, bias=config.attention_bias) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + else: + self.local_attention = (-1, -1) + + rope_theta = config.global_rope_theta + max_position_embeddings = config.max_position_embeddings + if self.local_attention != (-1, -1): + if config.local_rope_theta is not None: + rope_theta = config.local_rope_theta + max_position_embeddings = config.local_attention + + if config._attn_implementation == "flash_attention_2": + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta + ) + else: + self.rotary_emb = ModernBertRotaryEmbedding( + dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta + ) + + self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.out_drop = nn.Dropout(config.attn_out_dropout) if config.attn_out_dropout > 0.0 else nn.Identity() + + def _init_weights(self, reset_params: bool = False): + _init_modernbert_weights(self.config, self.Wqkv, module_type=ModernBertModuleType.in_module) + _init_modernbert_weights(self.config, self.Wo, module_type=ModernBertModuleType.out_module) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + **kwargs, + ) -> torch.Tensor: + """Perform self-attention. + + There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2. + + The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the + Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute + attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not + sending pad tokens through ffs saves compute. 
+ + Args: + hidden_states: (total_nnz, dim) + cu_seqlens: (batch + 1,) + max_seqlen: int + indices: (total_nnz,) + attn_mask: (batch, max_seqlen) + + Returns: + attention: (total_nnz, dim) + """ + bs, dim = hidden_states.shape[:2] + qkv = self.Wqkv(hidden_states) + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + + attn = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + self, + qkv=qkv, + rotary_emb=self.rotary_emb, + position_ids=position_ids, + attention_mask=attention_mask, + local_attention=self.local_attention, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + bs=bs, + dim=dim, + ) + + return self.out_drop(self.Wo(attn)) + + +class ModernBertFlashAttention2(ModernBertAttention): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__(config, layer_id) + self.config._attn_implementation = "flash_attention_2" + logger.warning_once( + "The `ModernBertFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" + ) + + +class ModernBertSdpaAttention(ModernBertAttention): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__(config, layer_id) + self.config._attn_implementation = "sdpa" + logger.warning_once( + "The `ModernBertFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" + ) + + +class ModernBertEncoderLayer(nn.Module): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + if config.skip_first_prenorm and config.embedding_norm and layer_id == 0: + self.attn_norm = nn.Identity() + else: + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.mlp = ModernBertMLP(config) + + def _init_weights(self, reset_params: bool = False): + if reset_params: + self.attn_norm.reset_parameters() + self.mlp_norm.reset_parameters() + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ) -> torch.Tensor: + """Forward pass for a ModernBert layer, including both attention and MLP. 
+ + Args: + hidden_states: (total_nnz, dim) + position_ids: (total_nnz,) + attention_mask: (batch, max_seqlen) + cu_seqlens: (batch + 1,) + max_seqlen: int + """ + attn_out = hidden_states + self.attn( + self.attn_norm(hidden_states), + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + attention_mask=attention_mask, + ) + return attn_out + self.mlp(self.mlp_norm(attn_out)) + + +class ModernBertPredictionHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] if config.classifier_activation else nn.Identity() + self.norm = ( + nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + if config.classifier_norm + else nn.Identity() + ) + + def _init_weights(self, reset_params: bool = False): + if reset_params: + self.norm.reset_parameters() + _init_modernbert_weights(self.config, self.dense, module_type=ModernBertModuleType.in_module) + + def reset_parameters(self): + self._init_weights(reset_params=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +class ModernBertPoolingHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] if config.classifier_activation else nn.Identity() + self.norm = ( + nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + if config.classifier_norm + else nn.Identity() + ) + self.drop = torch.nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity() + self.pooling_type = ModernBertPoolingType(config.classifier_pooling) + + def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor: + if pool: + if self.pooling_type == ModernBertPoolingType.cls: + output = hidden_states[:, 0] + elif self.pooling_type == ModernBertPoolingType.mean: + output = hidden_states.mean(dim=1) + elif self.pooling_type == ModernBertPoolingType.max: + output = hidden_states.max(dim=1)[0] + else: + output = hidden_states + + return self.drop(self.norm(self.act(self.dense(output)))) + + def _init_weights(self, reset_params: bool = False): + _init_modernbert_weights(self.config, self.dense, module_type=ModernBertModuleType.out_module) + if reset_params and hasattr(self.norm, "reset_parameters"): + self.norm.reset_parameters() + + def reset_parameters(self): + self._init_weights(reset_params=True) + + +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. 
+ indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, + labels: Optional[torch.Tensor] = None, + ignore_index: int = -100, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + position_ids: (total_nnz) or None + labels: (total_nnz) or None + + Returns: + padded_inputs: (batch, seqlen, ...) or (batch, seqlen) + padded_labels: (batch, seqlen) or None + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + padded_labels = None + if labels is not None: + padded_labels = torch.full( + (batch * seqlen,), fill_value=ignore_index, dtype=labels.dtype, device=labels.device + ) + padded_labels[indices] = labels + padded_labels = padded_labels.view(batch, seqlen) + + return padded_inputs, padded_labels + + # Copyright (c) 2023, Tri Dao. 
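`_unpad_modernbert_input` and `_pad_modernbert_output` above implement the unpadded code path: non-padding tokens are gathered into one flat `(total_nnz, ...)` tensor, `cu_seqlens` records the sequence boundaries needed by the varlen flash-attention kernels, and the inverse scatter restores the `(batch, seqlen, ...)` shape. A small roundtrip sketch of the same index bookkeeping on toy tensors:

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])   # (batch=2, seqlen=4)
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)       # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen = int(seqlens.max())                               # 3
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                                             # tensor([0, 3, 5])

input_ids = torch.arange(8).view(2, 4)
unpadded = input_ids.flatten()[indices]                       # (total_nnz=5,)

# Inverse: scatter the flat tokens back into a padded (batch, seqlen) tensor.
padded = torch.zeros(2 * 4, dtype=unpadded.dtype)
padded[indices] = unpadded
print(padded.view(2, 4))
```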
+ # License: Apache-2.0 + + # if is_flash_attn_2_available(): + + +class ModernBertPreTrainedModel(PreTrainedModel): + config_class = ModernBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = False # TODO: Enable SDPA + + def _init_weights( + self, + module: Union[ModernBertEncoderLayer, ModernBertAttention, ModernBertMLP, ModernBertEmbeddings], + reset_params: bool = False, + ): + module._init_weights(reset_params) + + @torch.no_grad() + def _unpad_inputs_no_grad( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + return self._unpad_inputs( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + def _unpad_inputs( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + return _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + @torch.no_grad() + def _pad_outputs_no_grad( + self, + inputs: torch.Tensor, + indices: torch.Tensor, + batch_size: int, + seqlen: int, + labels: Optional[torch.Tensor] = None, + ignore_index: int = -100, + ): + return self._pad_outputs( + inputs=inputs, + indices=indices, + batch_size=batch_size, + seqlen=seqlen, + labels=labels, + ignore_index=ignore_index, + ) + + def _pad_outputs( + self, + inputs: torch.Tensor, + indices: torch.Tensor, + batch_size: int, + seqlen: int, + labels: Optional[torch.Tensor] = None, + ignore_index: int = -100, + ): + return _pad_modernbert_output( + inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index + ) + + +class ModernBertModel(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.layers = nn.ModuleList( + [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] + ) + self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + if module and hasattr(module, "_init_weights"): + super()._init_weights(module, reset_params) + elif isinstance(reset_params, bool): + self.embeddings._init_weights(reset_params=reset_params) + + if reset_params: + self.final_norm.reset_parameters() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + if batch_size is None and seq_len is None: + batch_size, seq_len = input_ids.shape[:2] + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), 
device=input_ids.device, dtype=torch.bool) + + repad = False + if self.config.unpad_inputs: + if indices is None and cu_seqlens is None and max_seqlen is None: + repad = True + if self.config.unpad_no_grad: + input_ids, indices, cu_seqlens, max_seqlen, *_ = self._unpad_inputs_no_grad( + input_ids, attention_mask + ) + else: + input_ids, indices, cu_seqlens, max_seqlen, *_ = self._unpad_inputs(input_ids, attention_mask) + elif position_ids is None: + position_ids = torch.arange(seq_len, device=input_ids.device) + + hidden_states = self.embeddings(input_ids) + + for encoder_layer in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + cu_seqlens, + max_seqlen, + ) + else: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = self.final_norm(hidden_states) + + if repad: + hidden_states = self._pad_outputs(hidden_states, indices, batch_size, seq_len) + + if not return_dict: + return hidden_states + return BaseModelOutput(last_hidden_state=hidden_states) + + +class ModernBertForMaskedLM(ModernBertPreTrainedModel): + def __init__( + self, + config: ModernBertConfig, + pad_logits: bool = True, + pad_logits_no_grad: Optional[bool] = None, + sparse_prediction: bool = True, + sparse_pred_ignore_index: int = -100, + ): + super().__init__(config) + self.config = config + self.bert = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + + if config.tie_word_embeddings: + decoder_weights = self.bert.embeddings.tok_embeddings.weight + else: + decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight + self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) + self.decoder.weight = decoder_weights + + self.pad_logits = pad_logits + self.pad_logits_no_grad = pad_logits_no_grad if pad_logits_no_grad is not None else self.config.unpad_no_grad + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index + # Initialize weights and apply final processing + self._init_weights(reset_params=False) + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + super()._init_weights(module) + else: + assert isinstance(reset_params, bool) + self.bert._init_weights(reset_params=reset_params) + self.head._init_weights(reset_params=reset_params) + + # Output weights. 
+ if not self.config.tie_word_embeddings: + _init_modernbert_weights( + self.config, self.decoder, self.config.hidden_size, module_type=ModernBertModuleType.out_module + ) + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear): + self.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.unpad_inputs: + if indices is None and cu_seqlens is None and max_seqlen is None: + batch_size, seq_len = input_ids.shape[:2] + if self.config.unpad_no_grad: + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self._unpad_inputs_no_grad( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + else: + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self._unpad_inputs( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + output = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + ) + + if self.sparse_prediction and labels is not None: + # flatten labels and output first + labels = labels.view(-1) + output = output.view(labels.shape[0], -1) + + # then filter out the non-masked tokens + mask_tokens = labels != self.sparse_pred_ignore_index + output = output[mask_tokens] + labels = labels[mask_tokens] + + logits = self.decoder(self.head(output)) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) + + if self.pad_logits: + if self.pad_logits_no_grad: + logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len)[0] + else: + logits = self._pad_outputs(logits, indices, batch_size, seq_len)[0] + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) + else: + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + labels=labels, + ) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py new file mode 100644 index 00000000000000..dc9ba007ef71cd --- /dev/null +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -0,0 +1,1249 @@ +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
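`ModernBertForMaskedLM` above uses "sparse prediction": when labels are provided, the encoder output is filtered down to the masked positions before the prediction head and decoder run, so the vocab-sized projection is only computed where a label exists. A minimal sketch of that filtering step on already-unpadded tensors; the single `nn.Linear` stands in for the head-plus-decoder stack:

```python
import torch

hidden_size, vocab_size = 768, 50368
sparse_pred_ignore_index = -100

output = torch.randn(5, hidden_size)                 # (total_nnz, hidden) from the encoder
labels = torch.tensor([-100, 42, -100, -100, 7])     # only masked tokens are labelled

mask_tokens = labels != sparse_pred_ignore_index
output = output[mask_tokens]                         # (num_masked, hidden) -> here (2, 768)
labels = labels[mask_tokens]                         # (2,)

decoder = torch.nn.Linear(hidden_size, vocab_size)   # stand-in for head + decoder
logits = decoder(output)                             # vocab projection only on masked tokens
print(logits.shape)                                  # torch.Size([2, 50368])
```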
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from enum import Enum +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...activations import ACT2FN +from ...configuration_utils import PretrainedConfig +from ...modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + MultipleChoiceModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal, + is_torch_greater_or_equal, + logging, +) +from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary + +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention + + +_CHECKPOINT_FOR_DOC = "answerdotai/modernbert-base" + +logger = logging.get_logger(__name__) + + +class ModernBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ModernBert-7B. + e.g. [answerdotai/modernbert-base](https://huggingface.co/answerdotai/modernbert-base) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ModernBertModel`] + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. 
`"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): in ModernBert, every other layer uses sliding window attention. This is the + size of the sliding window. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. 
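The sliding-window behaviour described above is controlled in this configuration by `global_attn_every_n_layers` and `local_attention` (see `__init__` below). A small sketch of the per-layer window selection, mirroring the layer-alternation logic in `ModernBertAttention.__init__` later in this file and using this patch's defaults; `(-1, -1)` is the Flash Attention convention for an unrestricted (global) window:

```python
# Per-layer attention window, derived from the config defaults in this patch:
# 22 layers, a global layer every 3rd layer, and a 128-token local window that
# is split evenly into (left, right) halves for the local layers.
num_hidden_layers = 22
global_attn_every_n_layers = 3
local_attention = 128

for layer_id in range(num_hidden_layers):
    if layer_id % global_attn_every_n_layers != 0:
        window = (local_attention // 2, local_attention // 2)  # local sliding window
    else:
        window = (-1, -1)  # global attention
    print(layer_id, window)
```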
+ + ```python + >>> from transformers import ModernBertModel, ModernBertConfig + >>> # Initializing a ModernBert modernbert-7b style configuration + >>> configuration = ModernBertConfig() + >>> # Initializing a model from the modernbert-7b style configuration + >>> model = ModernBertModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "modernbert" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50368, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=22, + num_attention_heads=12, + hidden_activation="gelu_python", + max_position_embeddings=8192, + initializer_range=0.02, + initalizer_cutoff_factor=2.0, + norm_eps=1e-5, + norm_bias=False, + pad_token_id=50283, + eos_token_id=50281, + bos_token_id=50282, + cls_token_id=50281, + sep_token_id=50282, + tie_word_embeddings=True, + global_rope_theta=160000.0, + attention_bias=False, + attention_dropout=0.0, + attn_out_dropout=0.1, + global_attn_every_n_layers=3, + local_attention=128, + local_rope_theta=10000.0, + skip_first_prenorm=True, + embedding_norm=True, + embedding_dropout=0.0, + mlp_bias=False, + mlp_dropout=0.0, + unpad_inputs=True, + unpad_no_grad=True, + decoder_bias=True, + classifier_dropout=0.0, + classifier_pooling="mean", + classifier_norm=True, + classifier_bias=True, + classifier_activation=None, + deterministic_flash_attn=False, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.initalizer_cutoff_factor = initalizer_cutoff_factor + self.norm_eps = norm_eps + self.norm_bias = norm_bias + self.global_rope_theta = global_rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.attn_out_dropout = attn_out_dropout + self.hidden_activation = hidden_activation + self.global_attn_every_n_layers = global_attn_every_n_layers + self.local_attention = local_attention + self.local_rope_theta = local_rope_theta + self.skip_first_prenorm = skip_first_prenorm + self.embedding_norm = embedding_norm + self.embedding_dropout = embedding_dropout + self.mlp_bias = mlp_bias + self.mlp_dropout = mlp_dropout + self.unpad_inputs = unpad_inputs + self.unpad_no_grad = unpad_no_grad + self.decoder_bias = decoder_bias + self.classifier_dropout = classifier_dropout + self.classifier_pooling = classifier_pooling + self.classifier_bias = classifier_bias + self.classifier_norm = classifier_norm + self.classifier_activation = classifier_activation if classifier_activation is not None else hidden_activation + self.deterministic_flash_attn = deterministic_flash_attn + + +class ModernBertModuleType(str, Enum): + in_module = "in" + out_module = "out" + embedding = "emb" + final_out = "final_out" + + +class ModernBertPoolingType(str, Enum): + cls = "cls" + mean = "mean" + max = "max" + + +# Copyright 2023 OLMo Authors +# License: Apache-2.0 + + +def _init_modernbert_weights( + config: ModernBertConfig, + module: Union[nn.Linear, nn.Embedding], + module_type: ModernBertModuleType, +) -> None: + """ + 
Initialize weights of a linear or embedding module. + + :param config: The model config. + :param module: The linear or embedding submodule to initialize. + """ + if module_type is None: + raise RuntimeError("When using the full megatron init, every module must have a type.") + + cutoff_factor = config.initalizer_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + if module_type == ModernBertModuleType.in_module: + std = config.initializer_range # for att_proj (same as QKV), ff_proj + elif module_type == ModernBertModuleType.out_module: + std = config.initializer_range / math.sqrt(2.0 * config.num_hidden_layers) # for attn_out, ff_out + elif module_type == ModernBertModuleType.embedding: + std = config.initializer_range # token embeddings (wte) + elif module_type == ModernBertModuleType.final_out: + std = config.hidden_size**-0.5 # final output (ff_out) + else: + raise RuntimeError(f"Unknown module type '{module_type}'") + + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, + labels: Optional[torch.Tensor] = None, + ignore_index: int = -100, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + position_ids: (total_nnz) or None + labels: (total_nnz) or None + + Returns: + padded_inputs: (batch, seqlen, ...) 
or (batch, seqlen) + padded_labels: (batch, seqlen) or None + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + padded_labels = None + if labels is not None: + padded_labels = torch.full( + (batch * seqlen,), fill_value=ignore_index, dtype=labels.dtype, device=labels.device + ) + padded_labels[indices] = labels + padded_labels = padded_labels.view(batch, seqlen) + + return padded_inputs, padded_labels + + # Copyright (c) 2023, Tri Dao. + # License: Apache-2.0 + + # if is_flash_attn_2_available(): + + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, 3, nheads, headdim) + total_nnz, three, nheads, headdim = qkv.shape + assert three == 3 + if qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + else: + q, k = qkv[:, 0, :, :], qkv[:, 1, :, :] + apply_rotary( + q, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + apply_rotary( + k, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return qkv + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + if do.is_contiguous(): + total_nnz, three, nheads, headdim = do.shape + # Call 1 kernel instead of 2 kernels + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, we get the same tensor + dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + else: + dq, dk = do[:, 0, :, :], do[:, 1, :, :] + apply_rotary( + dq, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + apply_rotary( + dk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. 
+ Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (total_nnz, dim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """ + The rotary position embeddings applied directly to unpadded sequences. + """ + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache wll be recomputed during the forward pass. + """ + super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Apply rotary embedding *inplace* to qkv. + qkv: (total_nnz, 3, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + + qkv = apply_rotary_unpadded( + qkv, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return qkv + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + + +class ModernBertEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = nn.Dropout(config.embedding_dropout) if config.embedding_dropout > 0.0 else nn.Identity() + + def _init_weights(self, reset_params: bool = False): + _init_modernbert_weights(self.config, self.tok_embeddings, module_type=ModernBertModuleType.embedding) + if reset_params: + self.norm.reset_parameters() + + def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + return self.drop(self.norm(self.tok_embeddings(input_ids))) + + +class ModernBertMLP(nn.Module): + """Applies the GLU at the end of each ModernBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. 
+ """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) + self.act = ACT2FN[config.hidden_activation] + self.drop = nn.Dropout(config.mlp_dropout) if config.mlp_dropout > 0.0 else nn.Identity() + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) + + def _init_weights(self, reset_params: bool = False): + _init_modernbert_weights(self.config, self.Wi, module_type=ModernBertModuleType.in_module) + _init_modernbert_weights(self.config, self.Wo, module_type=ModernBertModuleType.out_module) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): + pass + + +# def eager_attention_forward( +# config: ModernBertConfig, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# mask: Optional[torch.Tensor], +# **_kwargs, +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + +# if config.attn_logit_softcapping is not None: +# attn_weights = attn_weights / config.attn_logit_softcapping +# attn_weights = torch.tanh(attn_weights) +# attn_weights = attn_weights * config.attn_logit_softcapping +# if mask is not None: # no matter the length, we just slice it +# causal_mask = mask[:, :, :, : key_states.shape[-2]] +# attn_weights = attn_weights + causal_mask + +# # upcast attention to fp32 +# attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) +# attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) +# attn_output = torch.matmul(attn_weights, value_states) +# attn_output = attn_output.transpose(1, 2).contiguous() +# return attn_output, attn_weights + + +def flash_attention_forward( + config: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + max_seqlen: int, + local_attention: Tuple[int, int], + bs: int, + dim: int, + target_dtype: torch.dtype = torch.bfloat16, + **_kwargs, +) -> torch.Tensor: + # (total_seqlen, 3, nheads, headdim) + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. 
(Turing GPUs not supported) + orig_dtype = qkv.dtype + qkv = qkv.to(target_dtype) + + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=config.attention_dropout if config.training else 0.0, + deterministic=config.deterministic_flash_attn, + window_size=local_attention, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=config.attention_dropout if config.training else 0.0, + deterministic=config.deterministic_flash_attn, + window_size=local_attention, + ) + return attn.view(bs, dim) + + +# def flex_attention_forward( +# config: ModernBertConfig, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# mask: Optional[torch.Tensor], +# output_attentions: bool = False, +# **_kwargs, +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + +# attn_output = flex_attention( +# query, +# key, +# value, +# enable_gqa=True, +# scale=config.scaling, +# return_lse=output_attentions, +# ) +# if not output_attentions: +# attn_weights = None +# else: +# attn_output, attn_weights = attn_output + +# attn_output = attn_output.transpose(1, 2).contiguous() +# return attn_output, attn_weights + + +# def sdpa_attention_forward( +# config: ModernBertConfig, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# mask: Optional[torch.Tensor], +# **_kwargs, +# ) -> Tuple[torch.Tensor, None]: + +# causal_mask = mask +# if mask is not None: +# causal_mask = causal_mask[:, :, :, : key.shape[-2]] + +# # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, +# # Reference: https://github.com/pytorch/pytorch/issues/112577. +# if query.device.type == "cuda" and causal_mask is not None: +# query = query.contiguous() +# key = key.contiguous() +# value = value.contiguous() + +# attn_output = torch.nn.functional.scaled_dot_product_attention( +# query, +# key, +# value, +# attn_mask=causal_mask, +# dropout_p=config.attention_dropout if config.training else 0.0, +# is_causal=False, +# scale=config.scaling, +# ) +# attn_output = attn_output.transpose(1, 2).contiguous() +# return attn_output, None + + +MODERNBERT_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + # "flex_attention": flex_attention_forward, + # "eager": eager_attention_forward, + # "sdpa": sdpa_attention_forward, +} + + +class ModernBertAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences. + + If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. + If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, + which requires padding and unpadding inputs, adding some overhead. + + See `forward` method for additional details. 
+ """ + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.deterministic_flash_attn = config.deterministic_flash_attn + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.head_dim * self.num_heads, bias=config.attention_bias) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + else: + self.local_attention = (-1, -1) + + rope_theta = config.global_rope_theta + max_position_embeddings = config.max_position_embeddings + if self.local_attention != (-1, -1): + if config.local_rope_theta is not None: + rope_theta = config.local_rope_theta + max_position_embeddings = config.local_attention + + if config._attn_implementation == "flash_attention_2": + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta + ) + else: + self.rotary_emb = ModernBertRotaryEmbedding( + dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta + ) + + self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.out_drop = nn.Dropout(config.attn_out_dropout) if config.attn_out_dropout > 0.0 else nn.Identity() + + def _init_weights(self, reset_params: bool = False): + _init_modernbert_weights(self.config, self.Wqkv, module_type=ModernBertModuleType.in_module) + _init_modernbert_weights(self.config, self.Wo, module_type=ModernBertModuleType.out_module) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + **kwargs, + ) -> torch.Tensor: + """Perform self-attention. + + There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2. + + The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the + Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute + attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not + sending pad tokens through ffs saves compute. 
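The unpadded bookkeeping tensors referred to here (`indices`, `cu_seqlens`, `max_seqlen`) are produced by `_unpad_modernbert_input` above. A standalone sketch of the same computation for a toy attention mask, assuming two sequences of lengths 3 and 1 in a padded 2x4 batch:

```python
import torch
import torch.nn.functional as F

# Padded batch of 2 sequences with lengths 3 and 1 (1 = real token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 1])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # tensor([0, 1, 2, 4])
max_seqlen = int(seqlens_in_batch.max().item())                              # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 4])

# `indices` gathers the 4 real tokens out of the 8 padded slots, and `cu_seqlens`
# tells the varlen attention kernel where each sequence starts and ends.
print(indices, cu_seqlens, max_seqlen)
```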
+ + Args: + hidden_states: (total_nnz, dim) + cu_seqlens: (batch + 1,) + max_seqlen: int + indices: (total_nnz,) + attn_mask: (batch, max_seqlen) + + Returns: + attention: (total_nnz, dim) + """ + bs, dim = hidden_states.shape[:2] + qkv = self.Wqkv(hidden_states) + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + + attn = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + self, + qkv=qkv, + rotary_emb=self.rotary_emb, + position_ids=position_ids, + attention_mask=attention_mask, + local_attention=self.local_attention, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + bs=bs, + dim=dim, + ) + + return self.out_drop(self.Wo(attn)) + + +class ModernBertFlashAttention2(ModernBertAttention): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__(config, layer_id) + self.config._attn_implementation = "flash_attention_2" + logger.warning_once( + "The `ModernBertFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" + ) + + +class ModernBertSdpaAttention(ModernBertAttention): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__(config, layer_id) + self.config._attn_implementation = "sdpa" + logger.warning_once( + "The `ModernBertFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GemmaAttention` class! It will be removed in v4.48" + ) + + +class ModernBertEncoderLayer(nn.Module): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + if config.skip_first_prenorm and config.embedding_norm and layer_id == 0: + self.attn_norm = nn.Identity() + else: + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.mlp = ModernBertMLP(config) + + def _init_weights(self, reset_params: bool = False): + if reset_params: + self.attn_norm.reset_parameters() + self.mlp_norm.reset_parameters() + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ) -> torch.Tensor: + """Forward pass for a ModernBert layer, including both attention and MLP. 
+ + Args: + hidden_states: (total_nnz, dim) + position_ids: (total_nnz,) + attention_mask: (batch, max_seqlen) + cu_seqlens: (batch + 1,) + max_seqlen: int + """ + attn_out = hidden_states + self.attn( + self.attn_norm(hidden_states), + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + attention_mask=attention_mask, + ) + return attn_out + self.mlp(self.mlp_norm(attn_out)) + + +class ModernBertPredictionHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] if config.classifier_activation else nn.Identity() + self.norm = ( + nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + if config.classifier_norm + else nn.Identity() + ) + + def _init_weights(self, reset_params: bool = False): + if reset_params: + self.norm.reset_parameters() + _init_modernbert_weights(self.config, self.dense, module_type=ModernBertModuleType.in_module) + + def reset_parameters(self): + self._init_weights(reset_params=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +class ModernBertPoolingHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] if config.classifier_activation else nn.Identity() + self.norm = ( + nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + if config.classifier_norm + else nn.Identity() + ) + self.drop = torch.nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity() + self.pooling_type = ModernBertPoolingType(config.classifier_pooling) + + def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor: + if pool: + if self.pooling_type == ModernBertPoolingType.cls: + output = hidden_states[:, 0] + elif self.pooling_type == ModernBertPoolingType.mean: + output = hidden_states.mean(dim=1) + elif self.pooling_type == ModernBertPoolingType.max: + output = hidden_states.max(dim=1)[0] + else: + output = hidden_states + + return self.drop(self.norm(self.act(self.dense(output)))) + + def _init_weights(self, reset_params: bool = False): + _init_modernbert_weights(self.config, self.dense, module_type=ModernBertModuleType.out_module) + if reset_params and hasattr(self.norm, "reset_parameters"): + self.norm.reset_parameters() + + def reset_parameters(self): + self._init_weights(reset_params=True) + + +class ModernBertPreTrainedModel(PreTrainedModel): + config_class = ModernBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = False # TODO: Enable SDPA + + def _init_weights( + self, + module: Union[ModernBertEncoderLayer, ModernBertAttention, ModernBertMLP, ModernBertEmbeddings], + reset_params: bool = False, + ): + module._init_weights(reset_params) + + @torch.no_grad() + def _unpad_inputs_no_grad( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + return self._unpad_inputs( + input_ids=input_ids, attention_mask=attention_mask, 
position_ids=position_ids, labels=labels + ) + + def _unpad_inputs( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + return _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + @torch.no_grad() + def _pad_outputs_no_grad( + self, + inputs: torch.Tensor, + indices: torch.Tensor, + batch_size: int, + seqlen: int, + labels: Optional[torch.Tensor] = None, + ignore_index: int = -100, + ): + return self._pad_outputs( + inputs=inputs, + indices=indices, + batch_size=batch_size, + seqlen=seqlen, + labels=labels, + ignore_index=ignore_index, + ) + + def _pad_outputs( + self, + inputs: torch.Tensor, + indices: torch.Tensor, + batch_size: int, + seqlen: int, + labels: Optional[torch.Tensor] = None, + ignore_index: int = -100, + ): + return _pad_modernbert_output( + inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index + ) + + +class ModernBertModel(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.layers = nn.ModuleList( + [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] + ) + self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + if module and hasattr(module, "_init_weights"): + super()._init_weights(module, reset_params) + elif isinstance(reset_params, bool): + self.embeddings._init_weights(reset_params=reset_params) + + if reset_params: + self.final_norm.reset_parameters() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + if batch_size is None and seq_len is None: + batch_size, seq_len = input_ids.shape[:2] + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + + repad = False + if self.config.unpad_inputs: + if indices is None and cu_seqlens is None and max_seqlen is None: + repad = True + if self.config.unpad_no_grad: + input_ids, indices, cu_seqlens, max_seqlen, *_ = self._unpad_inputs_no_grad( + input_ids, attention_mask + ) + else: + input_ids, indices, cu_seqlens, max_seqlen, *_ = self._unpad_inputs(input_ids, attention_mask) + elif position_ids is None: + position_ids = torch.arange(seq_len, device=input_ids.device) + + hidden_states = self.embeddings(input_ids) + + for encoder_layer in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + cu_seqlens, + max_seqlen, + ) + else: + hidden_states = encoder_layer( + hidden_states, 
+ attention_mask=attention_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = self.final_norm(hidden_states) + + if repad: + hidden_states = self._pad_outputs(hidden_states, indices, batch_size, seq_len) + + if not return_dict: + return hidden_states + return BaseModelOutput(last_hidden_state=hidden_states) + + +class ModernBertForMaskedLM(ModernBertPreTrainedModel): + def __init__( + self, + config: ModernBertConfig, + pad_logits: bool = True, + pad_logits_no_grad: Optional[bool] = None, + sparse_prediction: bool = True, + sparse_pred_ignore_index: int = -100, + ): + super().__init__(config) + self.config = config + self.bert = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + + if config.tie_word_embeddings: + decoder_weights = self.bert.embeddings.tok_embeddings.weight + else: + decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight + self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) + self.decoder.weight = decoder_weights + + self.pad_logits = pad_logits + self.pad_logits_no_grad = pad_logits_no_grad if pad_logits_no_grad is not None else self.config.unpad_no_grad + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index + # Initialize weights and apply final processing + self._init_weights(reset_params=False) + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + super()._init_weights(module) + else: + assert isinstance(reset_params, bool) + self.bert._init_weights(reset_params=reset_params) + self.head._init_weights(reset_params=reset_params) + + # Output weights. 
+ if not self.config.tie_word_embeddings: + _init_modernbert_weights( + self.config, self.decoder, self.config.hidden_size, module_type=ModernBertModuleType.out_module + ) + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear): + self.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.unpad_inputs: + if indices is None and cu_seqlens is None and max_seqlen is None: + batch_size, seq_len = input_ids.shape[:2] + if self.config.unpad_no_grad: + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self._unpad_inputs_no_grad( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + else: + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self._unpad_inputs( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + output = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + ) + + if self.sparse_prediction and labels is not None: + # flatten labels and output first + labels = labels.view(-1) + output = output.view(labels.shape[0], -1) + + # then filter out the non-masked tokens + mask_tokens = labels != self.sparse_pred_ignore_index + output = output[mask_tokens] + labels = labels[mask_tokens] + + logits = self.decoder(self.head(output)) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) + + if self.pad_logits: + if self.pad_logits_no_grad: + logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len)[0] + else: + logits = self._pad_outputs(logits, indices, batch_size, seq_len)[0] + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) + else: + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + labels=labels, + ) diff --git a/tests/models/modernbert/__init__.py b/tests/models/modernbert/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 From dafb203c2b5c469f6effe09e2e912d6d8b102eb5 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 10 Dec 2024 16:26:50 -0600 Subject: [PATCH 02/88] small bug fixes --- src/transformers/models/auto/tokenization_auto.py | 2 +- src/transformers/models/modernbert/modeling_modernbert.py | 6 ++++-- src/transformers/models/modernbert/modular_modernbert.py | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index cce23de13dbcfa..c12948db48555d 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ 
b/src/transformers/models/auto/tokenization_auto.py @@ -203,7 +203,7 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), - ("modernbert", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index d5cbd8c0a94c87..0adb149d1b13d9 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -919,6 +919,8 @@ def forward( batch_size: Optional[int] = None, seq_len: Optional[int] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if batch_size is None and seq_len is None: batch_size, seq_len = input_ids.shape[:2] @@ -1075,9 +1077,9 @@ def forward( if self.pad_logits: if self.pad_logits_no_grad: - logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len)[0] + logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: - logits = self._pad_outputs(logits, indices, batch_size, seq_len)[0] + logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) return MaskedLMOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index dc9ba007ef71cd..8cc6b887f191c8 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1069,6 +1069,8 @@ def forward( batch_size: Optional[int] = None, seq_len: Optional[int] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if batch_size is None and seq_len is None: batch_size, seq_len = input_ids.shape[:2] @@ -1225,9 +1227,9 @@ def forward( if self.pad_logits: if self.pad_logits_no_grad: - logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len)[0] + logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: - logits = self._pad_outputs(logits, indices, batch_size, seq_len)[0] + logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) return MaskedLMOutput( loss=loss, logits=logits, From df13def437cf4ebf21125fa1e14c8d4a2d8c1634 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 10 Dec 2024 21:48:17 -0600 Subject: [PATCH 03/88] fixes --- .../modernbert/configuration_modernbert.py | 2 +- .../models/modernbert/modeling_modernbert.py | 36 ++++--------------- .../models/modernbert/modular_modernbert.py | 34 ++++-------------- 3 files changed, 15 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index f030f80c4d2403..296bae3da104a0 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -106,7 +106,7 @@ def __init__( intermediate_size=1152, num_hidden_layers=22, num_attention_heads=12, - 
hidden_activation="gelu_python", + hidden_activation="gelu", max_position_embeddings=8192, initializer_range=0.02, initalizer_cutoff_factor=2.0, diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 0adb149d1b13d9..2543f9bff500a0 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, MaskedLMOutput from ...modeling_utils import PreTrainedModel -from ...utils import is_flash_attn_2_available, logging +from ...utils import is_flash_attn_2_available from .configuration_modernbert import ModernBertConfig @@ -39,8 +39,6 @@ from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.ops.triton.rotary import apply_rotary -logger = logging.get_logger(__name__) - class ModernBertModuleType(str, Enum): in_module = "in" @@ -383,7 +381,7 @@ def forward(self, x, position_ids, seq_len=None): def flash_attention_forward( - config: "ModernBertAttention", + self: "ModernBertAttention", qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, @@ -408,8 +406,8 @@ def flash_attention_forward( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - dropout_p=config.attention_dropout if config.training else 0.0, - deterministic=config.deterministic_flash_attn, + dropout_p=self.attention_dropout if self.training else 0.0, + deterministic=self.deterministic_flash_attn, window_size=local_attention, ) attn = attn.to(orig_dtype) # type: ignore @@ -418,8 +416,8 @@ def flash_attention_forward( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - dropout_p=config.attention_dropout if config.training else 0.0, - deterministic=config.deterministic_flash_attn, + dropout_p=self.attention_dropout if self.training else 0.0, + deterministic=self.deterministic_flash_attn, window_size=local_attention, ) return attn.view(bs, dim) @@ -595,26 +593,6 @@ def forward( return self.out_drop(self.Wo(attn)) -class ModernBertFlashAttention2(ModernBertAttention): - def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): - super().__init__(config, layer_id) - self.config._attn_implementation = "flash_attention_2" - logger.warning_once( - "The `ModernBertFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" - ) - - -class ModernBertSdpaAttention(ModernBertAttention): - def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): - super().__init__(config, layer_id) - self.config._attn_implementation = "sdpa" - logger.warning_once( - "The `ModernBertFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! 
It will be removed in v4.48" - ) - - class ModernBertEncoderLayer(nn.Module): def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() @@ -964,7 +942,7 @@ def forward( hidden_states = self.final_norm(hidden_states) if repad: - hidden_states = self._pad_outputs(hidden_states, indices, batch_size, seq_len) + hidden_states, _ = self._pad_outputs(hidden_states, indices, batch_size, seq_len) if not return_dict: return hidden_states diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 8cc6b887f191c8..db1ddc733b1cde 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -138,7 +138,7 @@ def __init__( intermediate_size=1152, num_hidden_layers=22, num_attention_heads=12, - hidden_activation="gelu_python", + hidden_activation="gelu", max_position_embeddings=8192, initializer_range=0.02, initalizer_cutoff_factor=2.0, @@ -623,7 +623,7 @@ class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): def flash_attention_forward( - config: "ModernBertAttention", + self: "ModernBertAttention", qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, @@ -648,8 +648,8 @@ def flash_attention_forward( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - dropout_p=config.attention_dropout if config.training else 0.0, - deterministic=config.deterministic_flash_attn, + dropout_p=self.attention_dropout if self.training else 0.0, + deterministic=self.deterministic_flash_attn, window_size=local_attention, ) attn = attn.to(orig_dtype) # type: ignore @@ -658,8 +658,8 @@ def flash_attention_forward( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - dropout_p=config.attention_dropout if config.training else 0.0, - deterministic=config.deterministic_flash_attn, + dropout_p=self.attention_dropout if self.training else 0.0, + deterministic=self.deterministic_flash_attn, window_size=local_attention, ) return attn.view(bs, dim) @@ -835,26 +835,6 @@ def forward( return self.out_drop(self.Wo(attn)) -class ModernBertFlashAttention2(ModernBertAttention): - def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): - super().__init__(config, layer_id) - self.config._attn_implementation = "flash_attention_2" - logger.warning_once( - "The `ModernBertFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" - ) - - -class ModernBertSdpaAttention(ModernBertAttention): - def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): - super().__init__(config, layer_id) - self.config._attn_implementation = "sdpa" - logger.warning_once( - "The `ModernBertFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! 
It will be removed in v4.48" - ) - - class ModernBertEncoderLayer(nn.Module): def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() @@ -1114,7 +1094,7 @@ def forward( hidden_states = self.final_norm(hidden_states) if repad: - hidden_states = self._pad_outputs(hidden_states, indices, batch_size, seq_len) + hidden_states, _ = self._pad_outputs(hidden_states, indices, batch_size, seq_len) if not return_dict: return hidden_states From d09eabf3b4d12b25f9cb00c128e1be820d5673be Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Wed, 11 Dec 2024 13:18:46 +0100 Subject: [PATCH 04/88] Update import --- src/transformers/models/modernbert/modular_modernbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index db1ddc733b1cde..53bba7c1aec21c 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -19,7 +19,7 @@ from typing import Optional, Tuple, Union import torch -import torch.nn as nn +from torch import nn import torch.utils.checkpoint from ...activations import ACT2FN From 8c3afea0db04cf078825ad4f980336c0826b480a Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Wed, 11 Dec 2024 13:19:11 +0100 Subject: [PATCH 05/88] Use compiled mlp->mlp_norm to match research implementation --- src/transformers/models/modernbert/modular_modernbert.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 53bba7c1aec21c..9a7dd06beae024 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -852,6 +852,10 @@ def _init_weights(self, reset_params: bool = False): self.attn_norm.reset_parameters() self.mlp_norm.reset_parameters() + @torch.compile(dynamic=True) + def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(self.mlp_norm(hidden_states)) + def forward( self, hidden_states: torch.Tensor, @@ -876,7 +880,7 @@ def forward( max_seqlen=max_seqlen, attention_mask=attention_mask, ) - return attn_out + self.mlp(self.mlp_norm(attn_out)) + return attn_out + self.compiled_mlp(attn_out) class ModernBertPredictionHead(nn.Module): From a40aaa9203a557153644a3d9b12b9024739d3565 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 13:20:44 +0100 Subject: [PATCH 06/88] Propagate changes in modular to modeling --- src/transformers/models/modernbert/modeling_modernbert.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 2543f9bff500a0..a4b38c9eae3733 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -25,7 +25,7 @@ from typing import Optional, Tuple, Union import torch -import torch.nn as nn +from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, MaskedLMOutput @@ -610,6 +610,10 @@ def _init_weights(self, reset_params: bool = False): self.attn_norm.reset_parameters() self.mlp_norm.reset_parameters() + @torch.compile(dynamic=True) + def compiled_mlp(self, 
hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(self.mlp_norm(hidden_states)) + def forward( self, hidden_states: torch.Tensor, @@ -634,7 +638,7 @@ def forward( max_seqlen=max_seqlen, attention_mask=attention_mask, ) - return attn_out + self.mlp(self.mlp_norm(attn_out)) + return attn_out + self.compiled_mlp(attn_out) class ModernBertPredictionHead(nn.Module): From 9f0b8ca8d6fe2a9a6bd2c2e3c268fec707dea36e Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 13:30:31 +0100 Subject: [PATCH 07/88] Replace duplicate attn_out_dropout in favor of attention_dropout cc @warner-benjamin let me know if the two should remain separate! --- .../models/modernbert/configuration_modernbert.py | 2 -- src/transformers/models/modernbert/modeling_modernbert.py | 2 +- src/transformers/models/modernbert/modular_modernbert.py | 4 +--- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 296bae3da104a0..1e16cb9e45c86d 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -121,7 +121,6 @@ def __init__( global_rope_theta=160000.0, attention_bias=False, attention_dropout=0.0, - attn_out_dropout=0.1, global_attn_every_n_layers=3, local_attention=128, local_rope_theta=10000.0, @@ -163,7 +162,6 @@ def __init__( self.global_rope_theta = global_rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.attn_out_dropout = attn_out_dropout self.hidden_activation = hidden_activation self.global_attn_every_n_layers = global_attn_every_n_layers self.local_attention = local_attention diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index a4b38c9eae3733..9d3f4a7e7a77b5 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -539,7 +539,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): ) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) - self.out_drop = nn.Dropout(config.attn_out_dropout) if config.attn_out_dropout > 0.0 else nn.Identity() + self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() def _init_weights(self, reset_params: bool = False): _init_modernbert_weights(self.config, self.Wqkv, module_type=ModernBertModuleType.in_module) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 9a7dd06beae024..78bceafd5015cb 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -153,7 +153,6 @@ def __init__( global_rope_theta=160000.0, attention_bias=False, attention_dropout=0.0, - attn_out_dropout=0.1, global_attn_every_n_layers=3, local_attention=128, local_rope_theta=10000.0, @@ -195,7 +194,6 @@ def __init__( self.global_rope_theta = global_rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.attn_out_dropout = attn_out_dropout self.hidden_activation = hidden_activation self.global_attn_every_n_layers = global_attn_every_n_layers self.local_attention = local_attention @@ -781,7 +779,7 @@ def __init__(self, config: ModernBertConfig, layer_id: 
Optional[int] = None): ) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) - self.out_drop = nn.Dropout(config.attn_out_dropout) if config.attn_out_dropout > 0.0 else nn.Identity() + self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() def _init_weights(self, reset_params: bool = False): _init_modernbert_weights(self.config, self.Wqkv, module_type=ModernBertModuleType.in_module) From 900d8ec8ff51c5420e5c001692962b2768a4e016 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 13:42:10 +0100 Subject: [PATCH 08/88] Update BOS to CLS and EOS to SEP Please confirm @warner-benjamin --- .../models/modernbert/configuration_modernbert.py | 4 ++-- src/transformers/models/modernbert/modular_modernbert.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 1e16cb9e45c86d..23e21cb61f98d4 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -113,8 +113,8 @@ def __init__( norm_eps=1e-5, norm_bias=False, pad_token_id=50283, - eos_token_id=50281, - bos_token_id=50282, + eos_token_id=50282, + bos_token_id=50281, cls_token_id=50281, sep_token_id=50282, tie_word_embeddings=True, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 78bceafd5015cb..6b378665623cd0 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -145,8 +145,8 @@ def __init__( norm_eps=1e-5, norm_bias=False, pad_token_id=50283, - eos_token_id=50281, - bos_token_id=50282, + eos_token_id=50282, + bos_token_id=50281, cls_token_id=50281, sep_token_id=50282, tie_word_embeddings=True, From caf8901824677fda1aceba7fd5bc23feaeff9035 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 14:07:08 +0100 Subject: [PATCH 09/88] Set default classifier bias to False, matching research repo --- src/transformers/models/modernbert/configuration_modernbert.py | 2 +- src/transformers/models/modernbert/modular_modernbert.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 23e21cb61f98d4..ab01c53ccf9f1c 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -135,7 +135,7 @@ def __init__( classifier_dropout=0.0, classifier_pooling="mean", classifier_norm=True, - classifier_bias=True, + classifier_bias=False, classifier_activation=None, deterministic_flash_attn=False, **kwargs, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 6b378665623cd0..db33389538d2e3 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -167,7 +167,7 @@ def __init__( classifier_dropout=0.0, classifier_pooling="mean", classifier_norm=True, - classifier_bias=True, + classifier_bias=False, classifier_activation=None, deterministic_flash_attn=False, **kwargs, From 82766022df7571e9bf72feec7020b7fd3fac5137 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 
14:14:32 +0100 Subject: [PATCH 10/88] Update tie_word_embeddings description --- src/transformers/models/modernbert/configuration_modernbert.py | 3 ++- src/transformers/models/modernbert/modular_modernbert.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index ab01c53ccf9f1c..3a3e5e2fda691e 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -72,7 +72,8 @@ class ModernBertConfig(PretrainedConfig): bos_token_id (`int`, *optional*, defaults to 2): Beginning of stream token id. tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index db33389538d2e3..b4b0cbbb383cdd 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -104,7 +104,8 @@ class ModernBertConfig(PretrainedConfig): bos_token_id (`int`, *optional*, defaults to 2): Beginning of stream token id. tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): From 79e4bbb8bed8e49b9c46d6bb35c3d6faacb85d2c Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 14:15:31 +0100 Subject: [PATCH 11/88] Fix _init_weights for ForMaskedLM --- src/transformers/models/modernbert/modeling_modernbert.py | 4 +--- src/transformers/models/modernbert/modular_modernbert.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 9d3f4a7e7a77b5..cf1410638a5960 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -992,9 +992,7 @@ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Option # Output weights. if not self.config.tie_word_embeddings: - _init_modernbert_weights( - self.config, self.decoder, self.config.hidden_size, module_type=ModernBertModuleType.out_module - ) + _init_modernbert_weights(self.config, self.decoder, module_type=ModernBertModuleType.out_module) def get_output_embeddings(self): return self.decoder diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index b4b0cbbb383cdd..72eb8bd158aff0 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1143,9 +1143,7 @@ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Option # Output weights. 
if not self.config.tie_word_embeddings: - _init_modernbert_weights( - self.config, self.decoder, self.config.hidden_size, module_type=ModernBertModuleType.out_module - ) + _init_modernbert_weights(self.config, self.decoder, module_type=ModernBertModuleType.out_module) def get_output_embeddings(self): return self.decoder From b59bad978f1a0ab5faf96a7d2d45c407080d9f37 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 15:43:12 +0100 Subject: [PATCH 12/88] Match base_model_prefix --- src/transformers/models/modernbert/modeling_modernbert.py | 8 ++++---- src/transformers/models/modernbert/modular_modernbert.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index cf1410638a5960..8548446ba23af7 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -964,11 +964,11 @@ def __init__( ): super().__init__(config) self.config = config - self.bert = ModernBertModel(config) + self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) if config.tie_word_embeddings: - decoder_weights = self.bert.embeddings.tok_embeddings.weight + decoder_weights = self.model.embeddings.tok_embeddings.weight else: decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) @@ -987,7 +987,7 @@ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Option super()._init_weights(module) else: assert isinstance(reset_params, bool) - self.bert._init_weights(reset_params=reset_params) + self.model._init_weights(reset_params=reset_params) self.head._init_weights(reset_params=reset_params) # Output weights. @@ -1028,7 +1028,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) - output = self.bert( + output = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 72eb8bd158aff0..4074dcbaae4c24 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1115,11 +1115,11 @@ def __init__( ): super().__init__(config) self.config = config - self.bert = ModernBertModel(config) + self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) if config.tie_word_embeddings: - decoder_weights = self.bert.embeddings.tok_embeddings.weight + decoder_weights = self.model.embeddings.tok_embeddings.weight else: decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) @@ -1138,7 +1138,7 @@ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Option super()._init_weights(module) else: assert isinstance(reset_params, bool) - self.bert._init_weights(reset_params=reset_params) + self.model._init_weights(reset_params=reset_params) self.head._init_weights(reset_params=reset_params) # Output weights. 
@@ -1179,7 +1179,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) - output = self.bert( + output = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, From e7bef536776680cb93fef66321fd30b0d68b976b Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 15:44:36 +0100 Subject: [PATCH 13/88] Add compiled_head to match research repo outputs --- src/transformers/models/modernbert/modeling_modernbert.py | 8 +++++++- src/transformers/models/modernbert/modular_modernbert.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 8548446ba23af7..c926603aa2aa7a 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -1000,6 +1000,10 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings: nn.Linear): self.decoder = new_embeddings + @torch.compile(dynamic=True) + def compiled_head(self, output: torch.Tensor) -> torch.Tensor: + return self.decoder(self.head(output)) + def forward( self, input_ids: Optional[torch.Tensor], @@ -1037,7 +1041,9 @@ def forward( max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, + return_dict=return_dict, ) + output = output[0] if self.sparse_prediction and labels is not None: # flatten labels and output first @@ -1049,7 +1055,7 @@ def forward( output = output[mask_tokens] labels = labels[mask_tokens] - logits = self.decoder(self.head(output)) + logits = self.compiled_head(output) loss = None if labels is not None: diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 4074dcbaae4c24..cb921c4dab9cfb 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1151,6 +1151,10 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings: nn.Linear): self.decoder = new_embeddings + @torch.compile(dynamic=True) + def compiled_head(self, output: torch.Tensor) -> torch.Tensor: + return self.decoder(self.head(output)) + def forward( self, input_ids: Optional[torch.Tensor], @@ -1188,7 +1192,9 @@ def forward( max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, + return_dict=return_dict, ) + output = output[0] if self.sparse_prediction and labels is not None: # flatten labels and output first @@ -1200,7 +1206,7 @@ def forward( output = output[mask_tokens] labels = labels[mask_tokens] - logits = self.decoder(self.head(output)) + logits = self.compiled_head(output) loss = None if labels is not None: From 120578ba9146d40ef417f20a64c8d1df71288130 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 15:51:41 +0100 Subject: [PATCH 14/88] Fix imports for ModernBertForMaskedLM --- src/transformers/__init__.py | 1 + src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/modernbert/__init__.py | 8 ++------ 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b70544f6788774..fe7197694be8d0 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2302,6 +2302,7 @@ ) _import_structure["models.modernbert"].extend( [ + "ModernBertForMaskedLM", "ModernBertForCausalLM", "ModernBertForSequenceClassification", 
"ModernBertForTokenClassification", diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index fb08610df495bd..f36ae22755b67b 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -818,6 +818,7 @@ ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForMaskedLM"), ("mobilebert", "MobileBertForMaskedLM"), + ("modernbert", "ModernBertForMaskedLM"), ("mpnet", "MPNetForMaskedLM"), ("mra", "MraForMaskedLM"), ("mvp", "MvpForConditionalGeneration"), diff --git a/src/transformers/models/modernbert/__init__.py b/src/transformers/models/modernbert/__init__.py index 40de068897fa00..91e19878f84f30 100644 --- a/src/transformers/models/modernbert/__init__.py +++ b/src/transformers/models/modernbert/__init__.py @@ -31,11 +31,9 @@ pass else: _import_structure["modeling_modernbert"] = [ - "ModernBertForCausalLM", + "ModernBertForMaskedLM", "ModernBertModel", "ModernBertPreTrainedModel", - "ModernBertForSequenceClassification", - "ModernBertForTokenClassification", ] if TYPE_CHECKING: @@ -48,9 +46,7 @@ pass else: from .modeling_modernbert import ( - ModernBertForCausalLM, - ModernBertForSequenceClassification, - ModernBertForTokenClassification, + ModernBertForMaskedLM, ModernBertModel, ModernBertPreTrainedModel, ) From 142ff11ea201842b609acd23cbf58c3ac70bbfbb Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 15:53:29 +0100 Subject: [PATCH 15/88] Just use "gelu" default outright for classifier --- .../models/modernbert/configuration_modernbert.py | 10 +++++----- .../models/modernbert/modular_modernbert.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 3a3e5e2fda691e..5a6337bc6fba08 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -53,9 +53,9 @@ class ModernBertConfig(PretrainedConfig): `num_attention_heads`. head_dim (`int`, *optional*, defaults to 256): The attention head dimension. - hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu"` + if not specified. max_position_embeddings (`int`, *optional*, defaults to 8192): The maximum sequence length that this model might ever be used with. 
initializer_range (`float`, *optional*, defaults to 0.02): @@ -137,7 +137,7 @@ def __init__( classifier_pooling="mean", classifier_norm=True, classifier_bias=False, - classifier_activation=None, + classifier_activation="gelu", deterministic_flash_attn=False, **kwargs, ): @@ -179,5 +179,5 @@ def __init__( self.classifier_pooling = classifier_pooling self.classifier_bias = classifier_bias self.classifier_norm = classifier_norm - self.classifier_activation = classifier_activation if classifier_activation is not None else hidden_activation + self.classifier_activation = classifier_activation self.deterministic_flash_attn = deterministic_flash_attn diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index cb921c4dab9cfb..b2ba902d07bbbb 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -85,9 +85,9 @@ class ModernBertConfig(PretrainedConfig): `num_attention_heads`. head_dim (`int`, *optional*, defaults to 256): The attention head dimension. - hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu"` + if not specified. max_position_embeddings (`int`, *optional*, defaults to 8192): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): @@ -169,7 +169,7 @@ def __init__( classifier_pooling="mean", classifier_norm=True, classifier_bias=False, - classifier_activation=None, + classifier_activation="gelu", deterministic_flash_attn=False, **kwargs, ): @@ -211,7 +211,7 @@ def __init__( self.classifier_pooling = classifier_pooling self.classifier_bias = classifier_bias self.classifier_norm = classifier_norm - self.classifier_activation = classifier_activation if classifier_activation is not None else hidden_activation + self.classifier_activation = classifier_activation self.deterministic_flash_attn = deterministic_flash_attn From b44abdc45165b4f2c3c0cedf73d9f0e7814f9f43 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 16:14:50 +0100 Subject: [PATCH 16/88] Fix config name typo: initalizer -> initializer --- .../models/modernbert/configuration_modernbert.py | 4 ++-- src/transformers/models/modernbert/modeling_modernbert.py | 2 +- src/transformers/models/modernbert/modular_modernbert.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 5a6337bc6fba08..b5ea720e3e3825 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -110,7 +110,7 @@ def __init__( hidden_activation="gelu", max_position_embeddings=8192, initializer_range=0.02, - initalizer_cutoff_factor=2.0, + initializer_cutoff_factor=2.0, norm_eps=1e-5, norm_bias=False, pad_token_id=50283, @@ -157,7 +157,7 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.initializer_range = 
initializer_range - self.initalizer_cutoff_factor = initalizer_cutoff_factor + self.initializer_cutoff_factor = initializer_cutoff_factor self.norm_eps = norm_eps self.norm_bias = norm_bias self.global_rope_theta = global_rope_theta diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index c926603aa2aa7a..e1085ba138240e 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -253,7 +253,7 @@ def _init_modernbert_weights( if module_type is None: raise RuntimeError("When using the full megatron init, every module must have a type.") - cutoff_factor = config.initalizer_cutoff_factor + cutoff_factor = config.initializer_cutoff_factor if cutoff_factor is None: cutoff_factor = 3 diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index b2ba902d07bbbb..b36a6849cd620d 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -142,7 +142,7 @@ def __init__( hidden_activation="gelu", max_position_embeddings=8192, initializer_range=0.02, - initalizer_cutoff_factor=2.0, + initializer_cutoff_factor=2.0, norm_eps=1e-5, norm_bias=False, pad_token_id=50283, @@ -189,7 +189,7 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.initializer_range = initializer_range - self.initalizer_cutoff_factor = initalizer_cutoff_factor + self.initializer_cutoff_factor = initializer_cutoff_factor self.norm_eps = norm_eps self.norm_bias = norm_bias self.global_rope_theta = global_rope_theta @@ -246,7 +246,7 @@ def _init_modernbert_weights( if module_type is None: raise RuntimeError("When using the full megatron init, every module must have a type.") - cutoff_factor = config.initalizer_cutoff_factor + cutoff_factor = config.initializer_cutoff_factor if cutoff_factor is None: cutoff_factor = 3 From 3de8ebfb121399c4873442e11e3afbdeec6e76fc Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 11 Dec 2024 16:26:22 +0100 Subject: [PATCH 17/88] Remove some unused parameters in docstring. Still lots to edit there! --- .../models/modernbert/configuration_modernbert.py | 5 ----- src/transformers/models/modernbert/modular_modernbert.py | 5 ----- 2 files changed, 10 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index b5ea720e3e3825..7e034d3c3ebe5e 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -62,9 +62,6 @@ class ModernBertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. pad_token_id (`int`, *optional*, defaults to 0): Padding token id. eos_token_id (`int`, *optional*, defaults to 1): @@ -81,8 +78,6 @@ class ModernBertConfig(PretrainedConfig): attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. 
query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores - sliding_window (`int`, *optional*, defaults to 4096): in ModernBert, every other layer uses sliding window attention. This is the - size of the sliding window. final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index b36a6849cd620d..541f2554e7fc62 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -94,9 +94,6 @@ class ModernBertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. pad_token_id (`int`, *optional*, defaults to 0): Padding token id. eos_token_id (`int`, *optional*, defaults to 1): @@ -113,8 +110,6 @@ class ModernBertConfig(PretrainedConfig): attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores - sliding_window (`int`, *optional*, defaults to 4096): in ModernBert, every other layer uses sliding window attention. This is the - size of the sliding window. final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. From 7a05b3f109eb853ef78e339e8c7555f2f17c3129 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 12 Dec 2024 10:22:22 +0100 Subject: [PATCH 18/88] Compile the embeddings forward Not having this resulted in very slight differences - so small they weren't even noticeable for the base model, only for the large model. But for the large model, the tiny difference at the embedding layer propagated through the rest of the model, leading to notable differences of ~0.0084 average per value, up to 0.2343 for the worst case.
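For reference, a minimal sketch of how this kind of drift can be measured (illustrative only, not part of the patch: the stand-in module, sizes, and random inputs below are assumptions, and the numbers above came from comparing full-model outputs):

    import torch
    from torch import nn

    # Minimal stand-in for the embedding block (token embedding -> LayerNorm -> dropout),
    # mirroring the forward that this patch decorates with @torch.compile(dynamic=True).
    class Embeddings(nn.Module):
        def __init__(self, vocab_size=50368, hidden_size=1024):
            super().__init__()
            self.tok_embeddings = nn.Embedding(vocab_size, hidden_size)
            self.norm = nn.LayerNorm(hidden_size, eps=1e-5, bias=False)
            self.drop = nn.Identity()

        def forward(self, input_ids):
            return self.drop(self.norm(self.tok_embeddings(input_ids)))

    embeddings = Embeddings().eval()
    input_ids = torch.randint(0, 50368, (2, 128))

    with torch.no_grad():
        eager_out = embeddings(input_ids)
        compiled_out = torch.compile(embeddings.forward, dynamic=True)(input_ids)

    # Average and worst-case per-value divergence between eager and compiled forwards.
    diff = (eager_out - compiled_out).abs()
    print(f"mean abs diff: {diff.mean().item():.6f}, max abs diff: {diff.max().item():.6f}")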
--- src/transformers/models/modernbert/modeling_modernbert.py | 1 + src/transformers/models/modernbert/modular_modernbert.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index e1085ba138240e..82d1e1d7d778bc 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -298,6 +298,7 @@ def _init_weights(self, reset_params: bool = False): if reset_params: self.norm.reset_parameters() + @torch.compile(dynamic=True) def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 541f2554e7fc62..4dc31e117f5c18 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -558,6 +558,7 @@ def _init_weights(self, reset_params: bool = False): if reset_params: self.norm.reset_parameters() + @torch.compile(dynamic=True) def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) From 88b0ecfb94cc5a546e2d80c1736f079f58ec8943 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 12 Dec 2024 13:49:12 +0100 Subject: [PATCH 19/88] Add drafts for ForSequenceClassification/ForTokenClassification --- docs/source/en/index.md | 1 + src/transformers/__init__.py | 15 +- src/transformers/models/__init__.py | 2 +- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 8 +- .../models/auto/tokenization_auto.py | 2 +- .../models/modernbert/__init__.py | 4 + .../models/modernbert/modeling_modernbert.py | 178 +++++++++++++++++- .../models/modernbert/modular_modernbert.py | 176 ++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 28 +++ 10 files changed, 395 insertions(+), 23 deletions(-) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 8a9ccf45b69c26..817e0668abc227 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -224,6 +224,7 @@ Flax), PyTorch, and/or TensorFlow. 
| [MobileNetV2](model_doc/mobilenet_v2) | βœ… | ❌ | ❌ | | [MobileViT](model_doc/mobilevit) | βœ… | βœ… | ❌ | | [MobileViTV2](model_doc/mobilevitv2) | βœ… | ❌ | ❌ | +| [ModernBERT](model_doc/modernbert) | βœ… | ❌ | ❌ | | [Moshi](model_doc/moshi) | βœ… | ❌ | ❌ | | [MPNet](model_doc/mpnet) | βœ… | βœ… | ❌ | | [MPT](model_doc/mpt) | βœ… | ❌ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fe7197694be8d0..2ef6eaf73c5155 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -5339,7 +5339,6 @@ from .models.fuyu import FuyuConfig from .models.gemma import GemmaConfig from .models.gemma2 import Gemma2Config - from .models.modernbert import ModernBertConfig from .models.git import ( GitConfig, GitProcessor, @@ -5503,6 +5502,7 @@ from .models.mobilevitv2 import ( MobileViTV2Config, ) + from .models.modernbert import ModernBertConfig from .models.moshi import ( MoshiConfig, MoshiDepthConfig, @@ -7068,13 +7068,6 @@ Gemma2Model, Gemma2PreTrainedModel, ) - from .models.modernbert import ( - ModernBertForCausalLM, - ModernBertForSequenceClassification, - ModernBertForTokenClassification, - ModernBertModel, - ModernBertPreTrainedModel, - ) from .models.git import ( GitForCausalLM, GitModel, @@ -7466,6 +7459,12 @@ MobileViTV2Model, MobileViTV2PreTrainedModel, ) + from .models.modernbert import ( + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ModernBertModel, + ModernBertPreTrainedModel, + ) from .models.moshi import ( MoshiForCausalLM, MoshiForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 2d295287b39de0..a457b2ef0d8c0c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -96,7 +96,6 @@ fuyu, gemma, gemma2, - modernbert, git, glm, glpn, @@ -163,6 +162,7 @@ mobilenet_v2, mobilevit, mobilevitv2, + modernbert, moshi, mpnet, mpt, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9c1374c5b9ccb6..a809e0c84ab795 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -113,7 +113,6 @@ ("fuyu", "FuyuConfig"), ("gemma", "GemmaConfig"), ("gemma2", "Gemma2Config"), - ("modernbert", "ModernBertConfig"), ("git", "GitConfig"), ("glm", "GlmConfig"), ("glpn", "GLPNConfig"), @@ -181,6 +180,7 @@ ("mobilenet_v2", "MobileNetV2Config"), ("mobilevit", "MobileViTConfig"), ("mobilevitv2", "MobileViTV2Config"), + ("modernbert", "ModernBertConfig"), ("moshi", "MoshiConfig"), ("mpnet", "MPNetConfig"), ("mpt", "MptConfig"), @@ -418,7 +418,6 @@ ("fuyu", "Fuyu"), ("gemma", "Gemma"), ("gemma2", "Gemma2"), - ("modernbert", "ModernBERT"), ("git", "GIT"), ("glm", "GLM"), ("glpn", "GLPN"), @@ -496,6 +495,7 @@ ("mobilenet_v2", "MobileNetV2"), ("mobilevit", "MobileViT"), ("mobilevitv2", "MobileViTV2"), + ("modernbert", "ModernBERT"), ("moshi", "Moshi"), ("mpnet", "MPNet"), ("mpt", "MPT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f36ae22755b67b..0a773017e2bfd0 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -110,7 +110,6 @@ ("funnel", ("FunnelModel", "FunnelBaseModel")), ("gemma", "GemmaModel"), ("gemma2", "Gemma2Model"), - ("modernbert", "ModernBertModel"), ("git", "GitModel"), ("glm", "GlmModel"), ("glpn", "GLPNModel"), @@ -171,6 +170,7 @@ ("mobilenet_v2", 
"MobileNetV2Model"), ("mobilevit", "MobileViTModel"), ("mobilevitv2", "MobileViTV2Model"), + ("modernbert", "ModernBertModel"), ("moshi", "MoshiModel"), ("mpnet", "MPNetModel"), ("mpt", "MptModel"), @@ -488,7 +488,6 @@ ("fuyu", "FuyuForCausalLM"), ("gemma", "GemmaForCausalLM"), ("gemma2", "Gemma2ForCausalLM"), - ("modernbert", "ModernBertForCausalLM"), ("git", "GitForCausalLM"), ("glm", "GlmForCausalLM"), ("gpt-sw3", "GPT2LMHeadModel"), @@ -512,6 +511,7 @@ ("mistral", "MistralForCausalLM"), ("mixtral", "MixtralForCausalLM"), ("mllama", "MllamaForCausalLM"), + ("modernbert", "ModernBertForCausalLM"), ("moshi", "MoshiForCausalLM"), ("mpt", "MptForCausalLM"), ("musicgen", "MusicgenForCausalLM"), @@ -948,7 +948,6 @@ ("funnel", "FunnelForSequenceClassification"), ("gemma", "GemmaForSequenceClassification"), ("gemma2", "Gemma2ForSequenceClassification"), - ("modernbert", "ModernBertForSequenceClassification"), ("glm", "GlmForSequenceClassification"), ("gpt-sw3", "GPT2ForSequenceClassification"), ("gpt2", "GPT2ForSequenceClassification"), @@ -974,6 +973,7 @@ ("mistral", "MistralForSequenceClassification"), ("mixtral", "MixtralForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"), + ("modernbert", "ModernBertForSequenceClassification"), ("mpnet", "MPNetForSequenceClassification"), ("mpt", "MptForSequenceClassification"), ("mra", "MraForSequenceClassification"), @@ -1140,7 +1140,6 @@ ("funnel", "FunnelForTokenClassification"), ("gemma", "GemmaForTokenClassification"), ("gemma2", "Gemma2ForTokenClassification"), - ("modernbert", "ModernBertForTokenClassification"), ("glm", "GlmForTokenClassification"), ("gpt-sw3", "GPT2ForTokenClassification"), ("gpt2", "GPT2ForTokenClassification"), @@ -1161,6 +1160,7 @@ ("mistral", "MistralForTokenClassification"), ("mixtral", "MixtralForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), + ("modernbert", "ModernBertForTokenClassification"), ("mpnet", "MPNetForTokenClassification"), ("mpt", "MptForTokenClassification"), ("mra", "MraForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c12948db48555d..b5e2dbc6edf017 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -203,7 +203,6 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), - ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), @@ -311,6 +310,7 @@ ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), + ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/modernbert/__init__.py b/src/transformers/models/modernbert/__init__.py index 
91e19878f84f30..9a586d4ca54553 100644 --- a/src/transformers/models/modernbert/__init__.py +++ b/src/transformers/models/modernbert/__init__.py @@ -34,6 +34,8 @@ "ModernBertForMaskedLM", "ModernBertModel", "ModernBertPreTrainedModel", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", ] if TYPE_CHECKING: @@ -47,6 +49,8 @@ else: from .modeling_modernbert import ( ModernBertForMaskedLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, ModernBertModel, ModernBertPreTrainedModel, ) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 82d1e1d7d778bc..b70f92995806dd 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -26,9 +26,10 @@ import torch from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, MaskedLMOutput +from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import is_flash_attn_2_available from .configuration_modernbert import ModernBertConfig @@ -825,7 +826,7 @@ def _unpad_inputs( attention_mask: torch.Tensor, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, - ): + ) -> Tuple[torch.Tensor | int | None]: return _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) @@ -883,7 +884,7 @@ def set_input_embeddings(self, value): def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): if module and hasattr(module, "_init_weights"): - super()._init_weights(module, reset_params) + super()._init_weights(module, reset_params=reset_params) elif isinstance(reset_params, bool): self.embeddings._init_weights(reset_params=reset_params) @@ -985,7 +986,7 @@ def __init__( def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" if module: - super()._init_weights(module) + super()._init_weights(module, reset_params=reset_params) else: assert isinstance(reset_params, bool) self.model._init_weights(reset_params=reset_params) @@ -1086,3 +1087,172 @@ def forward( seq_len=seq_len, labels=labels, ) + + +class ModernBertForSequenceClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPoolingHead(config) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self._init_weights(reset_params=False) + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + super()._init_weights(module, reset_params=reset_params) + else: + assert isinstance(reset_params, bool) + self.model._init_weights(reset_params=reset_params) + self.head._init_weights(reset_params=reset_params) + _init_modernbert_weights(self.config, self.classifier, type_of_module=ModernBertModuleType.final_out) + + def 
forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + pooled_output = self.head(last_hidden_state) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) + + +class ModernBertForTokenClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.drop = nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity() + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self._init_weights(reset_params=False) + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + super()._init_weights(module, reset_params=reset_params) + else: + assert isinstance(reset_params, bool) + self.model._init_weights(reset_params=reset_params) + _init_modernbert_weights(self.config, self.classifier, type_of_module=ModernBertModuleType.final_out) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: 
Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 4dc31e117f5c18..7ec498ad448991 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -21,6 +21,7 @@ import torch from torch import nn import torch.utils.checkpoint +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...configuration_utils import PretrainedConfig @@ -971,7 +972,7 @@ def _unpad_inputs( attention_mask: torch.Tensor, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, - ): + ) -> Tuple[torch.Tensor | int | None]: return _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) @@ -1029,7 +1030,7 @@ def set_input_embeddings(self, value): def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): if module and hasattr(module, "_init_weights"): - super()._init_weights(module, reset_params) + super()._init_weights(module, reset_params=reset_params) elif isinstance(reset_params, bool): self.embeddings._init_weights(reset_params=reset_params) @@ -1131,7 +1132,7 @@ def __init__( def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" if module: - super()._init_weights(module) + super()._init_weights(module, reset_params=reset_params) else: assert isinstance(reset_params, bool) self.model._init_weights(reset_params=reset_params) @@ -1232,3 +1233,172 @@ def forward( seq_len=seq_len, labels=labels, ) + + +class ModernBertForSequenceClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPoolingHead(config) + self.classifier = 
nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self._init_weights(reset_params=False) + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + super()._init_weights(module, reset_params=reset_params) + else: + assert isinstance(reset_params, bool) + self.model._init_weights(reset_params=reset_params) + self.head._init_weights(reset_params=reset_params) + _init_modernbert_weights(self.config, self.classifier, type_of_module=ModernBertModuleType.final_out) + + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + pooled_output = self.head(last_hidden_state) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) + + +class ModernBertForTokenClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.drop = nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity() + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + 
self._init_weights(reset_params=False) + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + super()._init_weights(module, reset_params=reset_params) + else: + assert isinstance(reset_params, bool) + self.model._init_weights(reset_params=reset_params) + _init_modernbert_weights(self.config, self.classifier, type_of_module=ModernBertModuleType.final_out) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1238f058783c18..6d11147391d97d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4375,6 +4375,34 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ModernBertForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class GitForCausalLM(metaclass=DummyObject): _backends = ["torch"] From 5e3d61d7180540003f5753c3f4186b456d160291 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 12 Dec 2024 15:13:32 +0100 Subject: [PATCH 20/88] Add initial SDPA support (not exactly equivalent to FA2 yet!) During testing, FA2 and SDPA still differ by about 0.0098 per value in the token embeddings. 
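For reference, a comparison of this kind can be sketched roughly as follows (not the actual test script: the checkpoint path is a placeholder, and the dtype/device choices are assumptions):

    import torch
    from transformers import AutoModel, AutoTokenizer

    checkpoint = "path/to/local-modernbert-checkpoint"  # placeholder, not a released repo id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tokenizer("Paris is the capital of France.", return_tensors="pt").to("cuda")

    outputs = {}
    for impl in ("flash_attention_2", "sdpa"):
        model = AutoModel.from_pretrained(
            checkpoint, attn_implementation=impl, torch_dtype=torch.bfloat16
        ).to("cuda").eval()
        with torch.no_grad():
            # Compare the final token embeddings produced by the two attention backends.
            outputs[impl] = model(**inputs).last_hidden_state.float()

    diff = (outputs["flash_attention_2"] - outputs["sdpa"]).abs()
    print(f"mean abs diff per value: {diff.mean().item():.4f}, max: {diff.max().item():.4f}")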
It still predicts the correct mask fills, but I'd like to get it fully 1-1 if possible. --- .../modernbert/configuration_modernbert.py | 5 +- .../models/modernbert/modeling_modernbert.py | 173 ++++++++++------- .../models/modernbert/modular_modernbert.py | 174 +++++++++++------- 3 files changed, 224 insertions(+), 128 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 7e034d3c3ebe5e..da08d860167b61 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -125,7 +125,7 @@ def __init__( embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, - unpad_inputs=True, + unpad_inputs=None, unpad_no_grad=True, decoder_bias=True, classifier_dropout=0.0, @@ -176,3 +176,6 @@ def __init__( self.classifier_norm = classifier_norm self.classifier_activation = classifier_activation self.deterministic_flash_attn = deterministic_flash_attn + + if unpad_inputs is None: + self.unpad_inputs = self._attn_implementation == "flash_attention_2" diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index b70f92995806dd..c90c4a9b7ad6a1 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -25,10 +25,12 @@ from typing import Optional, Tuple, Union import torch +import torch.nn.functional as F from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import is_flash_attn_2_available @@ -39,6 +41,8 @@ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.ops.triton.rotary import apply_rotary +else: + RotaryEmbedding = None class ModernBertModuleType(str, Enum): @@ -356,6 +360,40 @@ def forward(self, x, position_ids, seq_len=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + # def eager_attention_forward( # config: ModernBertConfig, # query: torch.Tensor, @@ -452,44 +490,41 @@ def flash_attention_forward( # return attn_output, attn_weights -# def sdpa_attention_forward( -# config: ModernBertConfig, -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# mask: Optional[torch.Tensor], -# **_kwargs, -# ) -> Tuple[torch.Tensor, None]: - -# causal_mask = mask -# if mask is not None: -# causal_mask = causal_mask[:, :, :, : key.shape[-2]] - -# # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, -# # Reference: https://github.com/pytorch/pytorch/issues/112577. -# if query.device.type == "cuda" and causal_mask is not None: -# query = query.contiguous() -# key = key.contiguous() -# value = value.contiguous() - -# attn_output = torch.nn.functional.scaled_dot_product_attention( -# query, -# key, -# value, -# attn_mask=causal_mask, -# dropout_p=config.attention_dropout if config.training else 0.0, -# is_causal=False, -# scale=config.scaling, -# ) -# attn_output = attn_output.transpose(1, 2).contiguous() -# return attn_output, None +def sdpa_attention_forward( + self: "ModernBertAttention", + qkv: torch.Tensor, + position_ids: Optional[torch.LongTensor], + attention_mask: torch.Tensor, + bs: int, + seqlen: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor, None]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = self.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + + attn_output = F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=self.attention_dropout, + attn_mask=attention_mask, + ).transpose(1, 2) + attn_output = attn_output.view(bs, seqlen, dim) + return attn_output MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, # "flex_attention": flex_attention_forward, # "eager": eager_attention_forward, - # "sdpa": sdpa_attention_forward, + "sdpa": sdpa_attention_forward, } @@ -575,21 +610,41 @@ def forward( Returns: attention: (total_nnz, dim) """ - bs, dim = hidden_states.shape[:2] qkv = self.Wqkv(hidden_states) - qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + + attn_kwargs = {} + if self.config._attn_implementation == "flash_attention_2": + bs, dim = hidden_states.shape[:2] + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + + attn_kwargs.update( + { + "local_attention": self.local_attention, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "bs": bs, + "dim": dim, + } + ) + else: + bs, seqlen, dim = hidden_states.shape + qkv = qkv.view(bs, seqlen, 3, self.num_heads, self.head_dim) + + attn_kwargs.update( + { + "position_ids": position_ids, + "attention_mask": attention_mask, + "bs": bs, + "seqlen": seqlen, + "dim": dim, + } + ) attn = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( self, qkv=qkv, 
rotary_emb=self.rotary_emb, - position_ids=position_ids, - attention_mask=attention_mask, - local_attention=self.local_attention, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - bs=bs, - dim=dim, + **attn_kwargs, ) return self.out_drop(self.Wo(attn)) @@ -799,7 +854,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True - _supports_sdpa = False # TODO: Enable SDPA + _supports_sdpa = True def _init_weights( self, @@ -922,10 +977,15 @@ def forward( else: input_ids, indices, cu_seqlens, max_seqlen, *_ = self._unpad_inputs(input_ids, attention_mask) elif position_ids is None: - position_ids = torch.arange(seq_len, device=input_ids.device) + position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) hidden_states = self.embeddings(input_ids) + # expand attention_mask + if self.config._attn_implementation != "flash_attention_2" and attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + for encoder_layer in self.layers: if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( @@ -1063,30 +1123,17 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.pad_logits: + if self.config.unpad_inputs: if self.pad_logits_no_grad: logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) - return MaskedLMOutput( - loss=loss, - logits=logits, - hidden_states=None, - attentions=None, - ) - else: - return MaskedLMOutput( - loss=loss, - logits=logits, - hidden_states=None, - attentions=None, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - batch_size=batch_size, - seq_len=seq_len, - labels=labels, - ) + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) class ModernBertForSequenceClassification(ModernBertPreTrainedModel): diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 7ec498ad448991..738161b3322a09 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -22,9 +22,11 @@ from torch import nn import torch.utils.checkpoint from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +import torch.nn.functional as F from ...activations import ACT2FN from ...configuration_utils import PretrainedConfig +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -46,6 +48,8 @@ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.ops.triton.rotary import apply_rotary +else: + RotaryEmbedding = None if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import flex_attention @@ -158,7 +162,7 @@ def __init__( embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, - unpad_inputs=True, + unpad_inputs=None, unpad_no_grad=True, decoder_bias=True, classifier_dropout=0.0, @@ -210,6 +214,9 @@ def __init__( self.classifier_activation = classifier_activation self.deterministic_flash_attn = deterministic_flash_attn + if unpad_inputs 
is None: + self.unpad_inputs = self._attn_implementation == "flash_attention_2" + class ModernBertModuleType(str, Enum): in_module = "in" @@ -592,6 +599,40 @@ class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): pass +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + # def eager_attention_forward( # config: ModernBertConfig, # query: torch.Tensor, @@ -688,44 +729,41 @@ def flash_attention_forward( # return attn_output, attn_weights -# def sdpa_attention_forward( -# config: ModernBertConfig, -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# mask: Optional[torch.Tensor], -# **_kwargs, -# ) -> Tuple[torch.Tensor, None]: - -# causal_mask = mask -# if mask is not None: -# causal_mask = causal_mask[:, :, :, : key.shape[-2]] - -# # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, -# # Reference: https://github.com/pytorch/pytorch/issues/112577. 
-# if query.device.type == "cuda" and causal_mask is not None: -# query = query.contiguous() -# key = key.contiguous() -# value = value.contiguous() - -# attn_output = torch.nn.functional.scaled_dot_product_attention( -# query, -# key, -# value, -# attn_mask=causal_mask, -# dropout_p=config.attention_dropout if config.training else 0.0, -# is_causal=False, -# scale=config.scaling, -# ) -# attn_output = attn_output.transpose(1, 2).contiguous() -# return attn_output, None +def sdpa_attention_forward( + self: "ModernBertAttention", + qkv: torch.Tensor, + position_ids: Optional[torch.LongTensor], + attention_mask: torch.Tensor, + bs: int, + seqlen: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor, None]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = self.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + + attn_output = F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=self.attention_dropout, + attn_mask=attention_mask, + ).transpose(1, 2) + attn_output = attn_output.view(bs, seqlen, dim) + return attn_output MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, # "flex_attention": flex_attention_forward, # "eager": eager_attention_forward, - # "sdpa": sdpa_attention_forward, + "sdpa": sdpa_attention_forward, } @@ -811,21 +849,37 @@ def forward( Returns: attention: (total_nnz, dim) """ - bs, dim = hidden_states.shape[:2] qkv = self.Wqkv(hidden_states) - qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + + attn_kwargs = {} + if self.config._attn_implementation == "flash_attention_2": + bs, dim = hidden_states.shape[:2] + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + + attn_kwargs.update({ + "local_attention": self.local_attention, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "bs": bs, + "dim": dim, + }) + else: + bs, seqlen, dim = hidden_states.shape + qkv = qkv.view(bs, seqlen, 3, self.num_heads, self.head_dim) + + attn_kwargs.update({ + "position_ids": position_ids, + "attention_mask": attention_mask, + "bs": bs, + "seqlen": seqlen, + "dim": dim, + }) attn = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( self, qkv=qkv, rotary_emb=self.rotary_emb, - position_ids=position_ids, - attention_mask=attention_mask, - local_attention=self.local_attention, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - bs=bs, - dim=dim, + **attn_kwargs, ) return self.out_drop(self.Wo(attn)) @@ -945,7 +999,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True - _supports_sdpa = False # TODO: Enable SDPA + _supports_sdpa = True def _init_weights( self, @@ -1068,10 +1122,15 @@ def forward( else: input_ids, indices, cu_seqlens, max_seqlen, *_ = self._unpad_inputs(input_ids, attention_mask) elif position_ids is None: - position_ids = torch.arange(seq_len, device=input_ids.device) + position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) hidden_states = self.embeddings(input_ids) + # expand attention_mask + if self.config._attn_implementation != "flash_attention_2" and attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = 
_prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + for encoder_layer in self.layers: if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( @@ -1209,30 +1268,17 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.pad_logits: + if self.config.unpad_inputs: if self.pad_logits_no_grad: logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) - return MaskedLMOutput( - loss=loss, - logits=logits, - hidden_states=None, - attentions=None, - ) - else: - return MaskedLMOutput( - loss=loss, - logits=logits, - hidden_states=None, - attentions=None, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - batch_size=batch_size, - seq_len=seq_len, - labels=labels, - ) + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) class ModernBertForSequenceClassification(ModernBertPreTrainedModel): From 2a3d3783865cec4ff10838e98e0e7d6992e6d89e Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 12 Dec 2024 15:22:10 +0100 Subject: [PATCH 21/88] Only use attention dropout if training --- src/transformers/models/modernbert/modeling_modernbert.py | 2 +- src/transformers/models/modernbert/modular_modernbert.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index c90c4a9b7ad6a1..a1d3e218272efa 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -513,7 +513,7 @@ def sdpa_attention_forward( query, key, value, - dropout_p=self.attention_dropout, + dropout_p=self.attention_dropout if self.training else 0.0, attn_mask=attention_mask, ).transpose(1, 2) attn_output = attn_output.view(bs, seqlen, dim) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 738161b3322a09..38fed002b65b1d 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -752,7 +752,7 @@ def sdpa_attention_forward( query, key, value, - dropout_p=self.attention_dropout, + dropout_p=self.attention_dropout if self.training else 0.0, attn_mask=attention_mask, ).transpose(1, 2) attn_output = attn_output.view(bs, seqlen, dim) From a2051d6b4a0c771ecd36ba05603fe5823110acb5 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 12 Dec 2024 17:02:16 +0100 Subject: [PATCH 22/88] Add initial eager attention support (also not equivalent to FA2 yet!) Frustratingly, I also can't get eager to be equivalent to FA2 (or sdpa), but it does get really close, i.e. avg ~0.010 difference per value. Especially if I use fp32 for both FA2&eager, avg ~0.0029 difference per value The fill-mask results are good with eager. 
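A minimal sketch of this kind of comparison, assuming a tiny randomly initialized config rather than the converted checkpoint (illustrative only, not the exact script behind the numbers above); it switches backends by flipping the private `_attn_implementation` field that `ModernBertAttention.forward` dispatches on:

    import torch
    from transformers import ModernBertConfig
    from transformers.models.modernbert.modeling_modernbert import ModernBertForMaskedLM

    # Tiny made-up config so this runs quickly on CPU.
    config = ModernBertConfig(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
    )
    model = ModernBertForMaskedLM(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 32))

    logits = {}
    for impl in ("eager", "sdpa"):
        # ModernBertAttention.forward looks up the attention function with this key.
        model.config._attn_implementation = impl
        with torch.no_grad():
            logits[impl] = model(input_ids).logits

    diff = (logits["eager"] - logits["sdpa"]).abs()
    print(f"max abs diff: {diff.max().item():.2e}, mean abs diff: {diff.mean().item():.2e}")
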
--- .../models/modernbert/modeling_modernbert.py | 57 +++++++++++-------- .../models/modernbert/modular_modernbert.py | 57 +++++++++++-------- 2 files changed, 64 insertions(+), 50 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index a1d3e218272efa..7276ef4aede93e 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -394,30 +394,37 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -# def eager_attention_forward( -# config: ModernBertConfig, -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# mask: Optional[torch.Tensor], -# **_kwargs, -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling - -# if config.attn_logit_softcapping is not None: -# attn_weights = attn_weights / config.attn_logit_softcapping -# attn_weights = torch.tanh(attn_weights) -# attn_weights = attn_weights * config.attn_logit_softcapping -# if mask is not None: # no matter the length, we just slice it -# causal_mask = mask[:, :, :, : key_states.shape[-2]] -# attn_weights = attn_weights + causal_mask - -# # upcast attention to fp32 -# attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) -# attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) -# attn_output = torch.matmul(attn_weights, value_states) -# attn_output = attn_output.transpose(1, 2).contiguous() -# return attn_output, attn_weights +def eager_attention_forward( + self: "ModernBertAttention", + qkv: torch.Tensor, + position_ids: Optional[torch.LongTensor], + attention_mask: torch.Tensor, + bs: int, + seqlen: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = self.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + scale = self.head_dim**-0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + if self.training: + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bs, seqlen, dim) + return attn_output def flash_attention_forward( @@ -523,7 +530,7 @@ def sdpa_attention_forward( MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, # "flex_attention": flex_attention_forward, - # "eager": eager_attention_forward, + "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 38fed002b65b1d..36991801b15907 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -633,30 +633,37 @@ def 
apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -# def eager_attention_forward( -# config: ModernBertConfig, -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# mask: Optional[torch.Tensor], -# **_kwargs, -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling - -# if config.attn_logit_softcapping is not None: -# attn_weights = attn_weights / config.attn_logit_softcapping -# attn_weights = torch.tanh(attn_weights) -# attn_weights = attn_weights * config.attn_logit_softcapping -# if mask is not None: # no matter the length, we just slice it -# causal_mask = mask[:, :, :, : key_states.shape[-2]] -# attn_weights = attn_weights + causal_mask - -# # upcast attention to fp32 -# attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) -# attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) -# attn_output = torch.matmul(attn_weights, value_states) -# attn_output = attn_output.transpose(1, 2).contiguous() -# return attn_output, attn_weights +def eager_attention_forward( + self: "ModernBertAttention", + qkv: torch.Tensor, + position_ids: Optional[torch.LongTensor], + attention_mask: torch.Tensor, + bs: int, + seqlen: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = self.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + scale = self.head_dim**-0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + if self.training: + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bs, seqlen, dim) + return attn_output def flash_attention_forward( @@ -762,7 +769,7 @@ def sdpa_attention_forward( MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, # "flex_attention": flex_attention_forward, - # "eager": eager_attention_forward, + "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } From 124f1fd42fe7c68c728e998e0c69060e011681a7 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 13 Dec 2024 12:55:03 +0100 Subject: [PATCH 23/88] Add initial tests, output_attentions, output_hidden_states, prune_heads Tests are based on BERT, not all tests pass yet: 23 failed, 79 passed, 100 skipped --- src/transformers/__init__.py | 1 + .../models/modernbert/modeling_modernbert.py | 185 +++++++--- .../models/modernbert/modular_modernbert.py | 211 +++++++++--- src/transformers/pytorch_utils.py | 40 +++ .../modernbert/test_modeling_modernbert.py | 320 ++++++++++++++++++ 5 files changed, 656 insertions(+), 101 deletions(-) create mode 100644 tests/models/modernbert/test_modeling_modernbert.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2ef6eaf73c5155..4c5533bfa2c102 100755 --- 
a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -7460,6 +7460,7 @@ MobileViTV2PreTrainedModel, ) from .models.modernbert import ( + ModernBertForMaskedLM, ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertModel, diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 7276ef4aede93e..35a3063c44aba0 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -33,7 +33,8 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput from ...modeling_utils import PreTrainedModel -from ...utils import is_flash_attn_2_available +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, prune_qkv_linear_layer +from ...utils import is_flash_attn_2_available, logging from .configuration_modernbert import ModernBertConfig @@ -44,6 +45,8 @@ else: RotaryEmbedding = None +logger = logging.get_logger(__name__) + class ModernBertModuleType(str, Enum): in_module = "in" @@ -402,8 +405,9 @@ def eager_attention_forward( bs: int, seqlen: int, dim: int, + output_attentions: Optional[bool] = False, **_kwargs, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = self.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) @@ -424,7 +428,9 @@ def eager_attention_forward( attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bs, seqlen, dim) - return attn_output + if output_attentions: + return (attn_output, attn_weights) + return (attn_output,) def flash_attention_forward( @@ -438,7 +444,7 @@ def flash_attention_forward( dim: int, target_dtype: torch.dtype = torch.bfloat16, **_kwargs, -) -> torch.Tensor: +) -> Tuple[torch.Tensor]: # (total_seqlen, 3, nheads, headdim) qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) @@ -467,7 +473,7 @@ def flash_attention_forward( deterministic=self.deterministic_flash_attn, window_size=local_attention, ) - return attn.view(bs, dim) + return (attn.view(bs, dim),) # def flex_attention_forward( @@ -506,7 +512,7 @@ def sdpa_attention_forward( seqlen: int, dim: int, **_kwargs, -) -> Tuple[torch.Tensor, None]: +) -> Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = self.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) @@ -524,7 +530,7 @@ def sdpa_attention_forward( attn_mask=attention_mask, ).transpose(1, 2) attn_output = attn_output.view(bs, seqlen, dim) - return attn_output + return (attn_output,) MODERNBERT_ATTENTION_FUNCTION = { @@ -559,7 +565,8 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.deterministic_flash_attn = config.deterministic_flash_attn self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads - self.Wqkv = nn.Linear(config.hidden_size, 3 * self.head_dim * self.num_heads, bias=config.attention_bias) + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) if layer_id % config.global_attn_every_n_layers != 0: 
self.local_attention = (config.local_attention // 2, config.local_attention // 2) @@ -584,6 +591,21 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + + # Prune linear layers + self.Wqkv = prune_qkv_linear_layer(self.Wqkv, index) + self.Wo = prune_linear_layer(self.Wo, index, dim=1) + + # Update hyper params and store pruned heads + self.num_heads = self.num_heads - len(heads) + self.all_head_size = self.head_dim * self.num_heads + self.pruned_heads = self.pruned_heads.union(heads) def _init_weights(self, reset_params: bool = False): _init_modernbert_weights(self.config, self.Wqkv, module_type=ModernBertModuleType.in_module) @@ -596,6 +618,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = False, **kwargs, ) -> torch.Tensor: """Perform self-attention. @@ -619,9 +642,12 @@ def forward( """ qkv = self.Wqkv(hidden_states) - attn_kwargs = {} + attn_kwargs = { + "output_attentions": output_attentions, + "dim": self.all_head_size, + } if self.config._attn_implementation == "flash_attention_2": - bs, dim = hidden_states.shape[:2] + bs = hidden_states.shape[0] qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) attn_kwargs.update( @@ -630,11 +656,10 @@ def forward( "cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, "bs": bs, - "dim": dim, } ) else: - bs, seqlen, dim = hidden_states.shape + bs, seqlen = hidden_states.shape[:2] qkv = qkv.view(bs, seqlen, 3, self.num_heads, self.head_dim) attn_kwargs.update( @@ -643,18 +668,32 @@ def forward( "attention_mask": attention_mask, "bs": bs, "seqlen": seqlen, - "dim": dim, } ) - attn = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + if output_attentions: + if self.config._attn_implementation == "sdpa": + logger.warning_once( + "Outputting attentions is only supported with the 'eager' attention implementation, " + 'not with "sdpa". Falling back to `attn_implementation="eager"`.' + ) + self.config._attn_implementation = "eager" + elif self.config._attn_implementation != "eager": + logger.warning_once( + "Outputting attentions is only supported with the eager attention implementation, " + f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' + " Setting `output_attentions=False`." + ) + + attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( self, qkv=qkv, rotary_emb=self.rotary_emb, **attn_kwargs, ) + hidden_states = attn_outputs[0] - return self.out_drop(self.Wo(attn)) + return (self.out_drop(self.Wo(hidden_states)),) + attn_outputs[1:] # add attentions if outputted class ModernBertEncoderLayer(nn.Module): @@ -685,6 +724,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = False, ) -> torch.Tensor: """Forward pass for a ModernBert layer, including both attention and MLP. 
@@ -695,14 +735,18 @@ def forward( cu_seqlens: (batch + 1,) max_seqlen: int """ - attn_out = hidden_states + self.attn( + attn_outputs = self.attn( self.attn_norm(hidden_states), position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask, + output_attentions=output_attentions, ) - return attn_out + self.compiled_mlp(attn_out) + hidden_states = hidden_states + attn_outputs[0] + hidden_states = hidden_states + self.compiled_mlp(hidden_states) + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted class ModernBertPredictionHead(nn.Module): @@ -868,7 +912,8 @@ def _init_weights( module: Union[ModernBertEncoderLayer, ModernBertAttention, ModernBertMLP, ModernBertEmbeddings], reset_params: bool = False, ): - module._init_weights(reset_params) + if isinstance(module, (ModernBertEncoderLayer, ModernBertAttention, ModernBertMLP, ModernBertEmbeddings)): + module._init_weights(reset_params=reset_params) @torch.no_grad() def _unpad_inputs_no_grad( @@ -944,6 +989,14 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.layers[layer].attn.prune_heads(heads) + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): if module and hasattr(module, "_init_weights"): super()._init_weights(module, reset_params=reset_params) @@ -958,15 +1011,24 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + if batch_size is None and seq_len is None: batch_size, seq_len = input_ids.shape[:2] @@ -994,23 +1056,34 @@ def forward( attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, position_ids, cu_seqlens, max_seqlen, + output_attentions, ) else: - hidden_states = encoder_layer( + layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + output_attentions=output_attentions, ) + hidden_states = layer_outputs[0] + if output_attentions and len(layer_outputs) 
> 1: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) hidden_states = self.final_norm(hidden_states) @@ -1018,11 +1091,17 @@ def forward( hidden_states, _ = self._pad_outputs(hidden_states, indices, batch_size, seq_len) if not return_dict: - return hidden_states - return BaseModelOutput(last_hidden_state=hidden_states) + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) class ModernBertForMaskedLM(ModernBertPreTrainedModel): + _tied_weights_keys = ["decoder.weight"] + def __init__( self, config: ModernBertConfig, @@ -1079,12 +1158,14 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1101,7 +1182,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) - output = self.model( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1110,21 +1191,23 @@ def forward( max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=return_dict, ) - output = output[0] + last_hidden_state = outputs[0] if self.sparse_prediction and labels is not None: # flatten labels and output first labels = labels.view(-1) - output = output.view(labels.shape[0], -1) + last_hidden_state = last_hidden_state.view(labels.shape[0], -1) # then filter out the non-masked tokens mask_tokens = labels != self.sparse_pred_ignore_index - output = output[mask_tokens] + last_hidden_state = last_hidden_state[mask_tokens] labels = labels[mask_tokens] - logits = self.compiled_head(output) + logits = self.compiled_head(last_hidden_state) loss = None if labels is not None: @@ -1135,11 +1218,15 @@ def forward( logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + return MaskedLMOutput( loss=loss, logits=logits, - hidden_states=None, - attentions=None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) @@ -1164,7 +1251,7 @@ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Option assert isinstance(reset_params, bool) self.model._init_weights(reset_params=reset_params) self.head._init_weights(reset_params=reset_params) - _init_modernbert_weights(self.config, self.classifier, type_of_module=ModernBertModuleType.final_out) + _init_modernbert_weights(self.config, self.classifier, module_type=ModernBertModuleType.final_out) def forward( self, @@ -1172,12 +1259,14 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: 
Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: r""" @@ -1197,6 +1286,8 @@ def forward( max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = outputs[0] @@ -1234,8 +1325,8 @@ def forward( return SequenceClassifierOutput( loss=loss, logits=logits, - hidden_states=None, - attentions=None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) @@ -1258,17 +1349,19 @@ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Option else: assert isinstance(reset_params, bool) self.model._init_weights(reset_params=reset_params) - _init_modernbert_weights(self.config, self.classifier, type_of_module=ModernBertModuleType.final_out) + _init_modernbert_weights(self.config, self.classifier, module_type=ModernBertModuleType.final_out) def forward( self, - input_ids: Optional[torch.Tensor] = None, + input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -1282,10 +1375,12 @@ def forward( outputs = self.model( input_ids, attention_mask=attention_mask, - token_type_ids=token_type_ids, position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1307,6 +1402,6 @@ def forward( return TokenClassifierOutput( loss=loss, logits=logits, - hidden_states=None, - attentions=None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 36991801b15907..368281cedc82a7 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -41,6 +41,7 @@ is_torch_greater_or_equal, logging, ) +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, prune_qkv_linear_layer from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb @@ -641,8 +642,9 @@ def eager_attention_forward( bs: int, seqlen: int, dim: int, + output_attentions: Optional[bool] = False, **_kwargs, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = self.rotary_emb(qkv, 
position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) @@ -663,7 +665,9 @@ def eager_attention_forward( attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bs, seqlen, dim) - return attn_output + if output_attentions: + return (attn_output, attn_weights) + return (attn_output,) def flash_attention_forward( @@ -677,7 +681,7 @@ def flash_attention_forward( dim: int, target_dtype: torch.dtype = torch.bfloat16, **_kwargs, -) -> torch.Tensor: +) -> Tuple[torch.Tensor]: # (total_seqlen, 3, nheads, headdim) qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) @@ -706,7 +710,7 @@ def flash_attention_forward( deterministic=self.deterministic_flash_attn, window_size=local_attention, ) - return attn.view(bs, dim) + return (attn.view(bs, dim),) # def flex_attention_forward( @@ -745,7 +749,7 @@ def sdpa_attention_forward( seqlen: int, dim: int, **_kwargs, -) -> Tuple[torch.Tensor, None]: +) -> Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = self.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) @@ -763,7 +767,7 @@ def sdpa_attention_forward( attn_mask=attention_mask, ).transpose(1, 2) attn_output = attn_output.view(bs, seqlen, dim) - return attn_output + return (attn_output,) MODERNBERT_ATTENTION_FUNCTION = { @@ -798,7 +802,8 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.deterministic_flash_attn = config.deterministic_flash_attn self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads - self.Wqkv = nn.Linear(config.hidden_size, 3 * self.head_dim * self.num_heads, bias=config.attention_bias) + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) if layer_id % config.global_attn_every_n_layers != 0: self.local_attention = (config.local_attention // 2, config.local_attention // 2) @@ -823,6 +828,23 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.num_heads, self.head_dim, self.pruned_heads + ) + + # Prune linear layers + self.Wqkv = prune_qkv_linear_layer(self.Wqkv, index) + self.Wo = prune_linear_layer(self.Wo, index, dim=1) + + # Update hyper params and store pruned heads + self.num_heads = self.num_heads - len(heads) + self.all_head_size = self.head_dim * self.num_heads + self.pruned_heads = self.pruned_heads.union(heads) def _init_weights(self, reset_params: bool = False): _init_modernbert_weights(self.config, self.Wqkv, module_type=ModernBertModuleType.in_module) @@ -835,6 +857,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = False, **kwargs, ) -> torch.Tensor: """Perform self-attention. 
@@ -858,38 +881,58 @@ def forward( """ qkv = self.Wqkv(hidden_states) - attn_kwargs = {} + attn_kwargs = { + "output_attentions": output_attentions, + "dim": self.all_head_size, + } if self.config._attn_implementation == "flash_attention_2": - bs, dim = hidden_states.shape[:2] + bs = hidden_states.shape[0] qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) - attn_kwargs.update({ - "local_attention": self.local_attention, - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "bs": bs, - "dim": dim, - }) + attn_kwargs.update( + { + "local_attention": self.local_attention, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "bs": bs, + } + ) else: - bs, seqlen, dim = hidden_states.shape + bs, seqlen = hidden_states.shape[:2] qkv = qkv.view(bs, seqlen, 3, self.num_heads, self.head_dim) - attn_kwargs.update({ - "position_ids": position_ids, - "attention_mask": attention_mask, - "bs": bs, - "seqlen": seqlen, - "dim": dim, - }) + attn_kwargs.update( + { + "position_ids": position_ids, + "attention_mask": attention_mask, + "bs": bs, + "seqlen": seqlen, + } + ) + + if output_attentions: + if self.config._attn_implementation == "sdpa": + logger.warning_once( + "Outputting attentions is only supported with the 'eager' attention implementation, " + 'not with "sdpa". Falling back to `attn_implementation="eager"`.' + ) + self.config._attn_implementation = "eager" + elif self.config._attn_implementation != "eager": + logger.warning_once( + "Outputting attentions is only supported with the eager attention implementation, " + f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' + " Setting `output_attentions=False`." + ) - attn = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( self, qkv=qkv, rotary_emb=self.rotary_emb, **attn_kwargs, ) + hidden_states = attn_outputs[0] - return self.out_drop(self.Wo(attn)) + return (self.out_drop(self.Wo(hidden_states)),) + attn_outputs[1:] # add attentions if outputted class ModernBertEncoderLayer(nn.Module): @@ -920,6 +963,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = False, ) -> torch.Tensor: """Forward pass for a ModernBert layer, including both attention and MLP. 
@@ -930,14 +974,18 @@ def forward( cu_seqlens: (batch + 1,) max_seqlen: int """ - attn_out = hidden_states + self.attn( + attn_outputs = self.attn( self.attn_norm(hidden_states), position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, attention_mask=attention_mask, + output_attentions=output_attentions, ) - return attn_out + self.compiled_mlp(attn_out) + hidden_states = hidden_states + attn_outputs[0] + hidden_states = hidden_states + self.compiled_mlp(hidden_states) + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted class ModernBertPredictionHead(nn.Module): @@ -1013,7 +1061,8 @@ def _init_weights( module: Union[ModernBertEncoderLayer, ModernBertAttention, ModernBertMLP, ModernBertEmbeddings], reset_params: bool = False, ): - module._init_weights(reset_params) + if isinstance(module, (ModernBertEncoderLayer, ModernBertAttention, ModernBertMLP, ModernBertEmbeddings)): + module._init_weights(reset_params=reset_params) @torch.no_grad() def _unpad_inputs_no_grad( @@ -1089,6 +1138,14 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.layers[layer].attn.prune_heads(heads) + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): if module and hasattr(module, "_init_weights"): super()._init_weights(module, reset_params=reset_params) @@ -1103,15 +1160,24 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + if batch_size is None and seq_len is None: batch_size, seq_len = input_ids.shape[:2] @@ -1139,23 +1205,34 @@ def forward( attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, position_ids, cu_seqlens, max_seqlen, + output_attentions, ) else: - hidden_states = encoder_layer( + layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + output_attentions=output_attentions, ) + hidden_states = layer_outputs[0] + if output_attentions and 
len(layer_outputs) > 1: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) hidden_states = self.final_norm(hidden_states) @@ -1163,11 +1240,17 @@ def forward( hidden_states, _ = self._pad_outputs(hidden_states, indices, batch_size, seq_len) if not return_dict: - return hidden_states - return BaseModelOutput(last_hidden_state=hidden_states) + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) class ModernBertForMaskedLM(ModernBertPreTrainedModel): + _tied_weights_keys = ["decoder.weight"] + def __init__( self, config: ModernBertConfig, @@ -1224,12 +1307,14 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1246,7 +1331,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) - output = self.model( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1255,21 +1340,23 @@ def forward( max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=return_dict, ) - output = output[0] + last_hidden_state = outputs[0] if self.sparse_prediction and labels is not None: # flatten labels and output first labels = labels.view(-1) - output = output.view(labels.shape[0], -1) + last_hidden_state = last_hidden_state.view(labels.shape[0], -1) # then filter out the non-masked tokens mask_tokens = labels != self.sparse_pred_ignore_index - output = output[mask_tokens] + last_hidden_state = last_hidden_state[mask_tokens] labels = labels[mask_tokens] - logits = self.compiled_head(output) + logits = self.compiled_head(last_hidden_state) loss = None if labels is not None: @@ -1280,11 +1367,15 @@ def forward( logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + return MaskedLMOutput( loss=loss, logits=logits, - hidden_states=None, - attentions=None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) @@ -1309,7 +1400,7 @@ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Option assert isinstance(reset_params, bool) self.model._init_weights(reset_params=reset_params) self.head._init_weights(reset_params=reset_params) - _init_modernbert_weights(self.config, self.classifier, type_of_module=ModernBertModuleType.final_out) + _init_modernbert_weights(self.config, self.classifier, module_type=ModernBertModuleType.final_out) def forward( self, @@ -1317,12 +1408,14 @@ def forward( attention_mask: Optional[torch.Tensor] = None, 
position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: r""" @@ -1342,6 +1435,8 @@ def forward( max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = outputs[0] @@ -1379,8 +1474,8 @@ def forward( return SequenceClassifierOutput( loss=loss, logits=logits, - hidden_states=None, - attentions=None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) @@ -1403,17 +1498,19 @@ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Option else: assert isinstance(reset_params, bool) self.model._init_weights(reset_params=reset_params) - _init_modernbert_weights(self.config, self.classifier, type_of_module=ModernBertModuleType.final_out) + _init_modernbert_weights(self.config, self.classifier, module_type=ModernBertModuleType.final_out) def forward( self, - input_ids: Optional[torch.Tensor] = None, + input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -1427,10 +1524,12 @@ def forward( outputs = self.model( input_ids, attention_mask=attention_mask, - token_type_ids=token_type_ids, position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1452,6 +1551,6 @@ def forward( return TokenClassifierOutput( loss=loss, logits=logits, - hidden_states=None, - attentions=None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 5bdf8a355ddfaa..587b91feabe414 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -92,6 +92,46 @@ def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) return new_layer +def prune_qkv_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear: + """ + Prune a QKV linear layer to keep only entries in index. + + Used to remove heads. + + Args: + layer (`torch.nn.Linear`): The layer to prune. + index (`torch.LongTensor`): The indices to keep in the layer. + dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices. + + Returns: + `torch.nn.Linear`: The pruned QKV layer as a new layer with `requires_grad=True`. 
+ """ + assert layer.out_features % 3 == 0, "The output features of the linear layer should be divisible by 3" + index = torch.cat([ + index, + index + layer.out_features // 3, + index + 2 * layer.out_features // 3, + ]) + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).clone().detach() + if layer.bias is not None: + if dim == 1: + b = layer.bias.clone().detach() + else: + b = layer.bias[index].clone().detach() + new_size = list(layer.weight.size()) + new_size[dim] = len(index) + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) + new_layer.weight.requires_grad = False + new_layer.weight.copy_(W.contiguous()) + new_layer.weight.requires_grad = True + if layer.bias is not None: + new_layer.bias.requires_grad = False + new_layer.bias.copy_(b.contiguous()) + new_layer.bias.requires_grad = True + return new_layer + + class Conv1D(nn.Module): """ 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py new file mode 100644 index 00000000000000..2efb465a348f1c --- /dev/null +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -0,0 +1,320 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import tempfile +import unittest + +from packaging import version + +from transformers import AutoTokenizer, ModernBertConfig, is_torch_available +from transformers.models.auto import get_values +from transformers.testing_utils import ( + CaptureLogger, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, + ModernBertForMaskedLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ModernBertModel, + logging, + ) + + +class ModernBertModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + pad_token_id=0, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_activation="gelu", + mlp_dropout=0.0, + attention_dropout=0.0, + embedding_dropout=0.0, + classifier_dropout=0.0, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.pad_token_id = pad_token_id + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.mlp_dropout = mlp_dropout + self.attention_dropout = attention_dropout + self.embedding_dropout = embedding_dropout + self.classifier_dropout = classifier_dropout + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + """ + Returns a tiny configuration by default. 
+ """ + return ModernBertConfig( + vocab_size=self.vocab_size, + pad_token_id=self.pad_token_id, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_activation=self.hidden_activation, + mlp_dropout=self.mlp_dropout, + attention_dropout=self.attention_dropout, + embedding_dropout=self.embedding_dropout, + classifier_dropout=self.classifier_dropout, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + def create_and_check_model( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = ModernBertModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def create_and_check_for_masked_lm( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = ModernBertForMaskedLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_for_sequence_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = ModernBertForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_token_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = ModernBertForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + ModernBertModel, + ModernBertForMaskedLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = () + pipeline_model_mapping = ( + { + "feature-extraction": ModernBertModel, + "fill-mask": ModernBertForMaskedLM, + "text-classification": ModernBertForSequenceClassification, + "token-classification": ModernBertForTokenClassification, + "zero-shot": ModernBertForSequenceClassification, + } + if is_torch_available() + else {} + ) + fx_compatible = True + model_split_percents = [0.5, 0.8, 0.9] + + # special case for ForPreTraining model 
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if inputs_dict.get("output_attentions", False): + inputs_dict["output_attentions"] = True + + if return_labels: + if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING): + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + inputs_dict["next_sentence_label"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + + def setUp(self): + self.model_tester = ModernBertModelTester(self) + self.config_tester = ConfigTester(self, config_class=ModernBertConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_3d_mask_shapes(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + # manipulate input_mask + config_and_inputs = list(config_and_inputs) + batch_size, seq_length = config_and_inputs[3].shape + config_and_inputs[3] = random_attention_mask([batch_size, seq_length, seq_length]) + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_pretraining(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_pretraining(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + def test_for_warning_if_padding_and_no_attention_mask(self): + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.model_tester.prepare_config_and_inputs() + + # Set pad tokens in the input_ids + input_ids[0, 0] = config.pad_token_id + + # Check for warnings if the attention_mask is missing. + logger = logging.get_logger("transformers.modeling_utils") + # clear cache so we can test the warning is emitted (from `warning_once`). + logger.warning_once.cache_clear() + + with CaptureLogger(logger) as cl: + model = ModernBertModel(config=config) + model.to(torch_device) + model.eval() + model(input_ids, attention_mask=None) + self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out) + + @slow + def test_model_from_pretrained(self): + model_name = "google-bert/bert-base-uncased" + model = ModernBertModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +@require_torch +class ModernBertModelIntegrationTest(unittest.TestCase): + """ + These still need to be written, once public models are available. 
+ """ \ No newline at end of file From 38f959bf096bd05e1174ed4495daac004bf115fa Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 13 Dec 2024 13:27:11 +0100 Subject: [PATCH 24/88] Remove kwargs from ModernBertForMaskedLM Disable sparse_prediction by default to match the normal HF, can be enabled via config --- .../modernbert/configuration_modernbert.py | 4 ++++ .../models/modernbert/modeling_modernbert.py | 17 ++++----------- .../models/modernbert/modular_modernbert.py | 21 +++++++------------ 3 files changed, 16 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index da08d860167b61..ccc47280453e79 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -134,6 +134,8 @@ def __init__( classifier_bias=False, classifier_activation="gelu", deterministic_flash_attn=False, + sparse_prediction=False, + sparse_pred_ignore_index=-100, **kwargs, ): super().__init__( @@ -176,6 +178,8 @@ def __init__( self.classifier_norm = classifier_norm self.classifier_activation = classifier_activation self.deterministic_flash_attn = deterministic_flash_attn + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index if unpad_inputs is None: self.unpad_inputs = self._attn_implementation == "flash_attention_2" diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 35a3063c44aba0..1b4a8fe3558368 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -1102,14 +1102,7 @@ def forward( class ModernBertForMaskedLM(ModernBertPreTrainedModel): _tied_weights_keys = ["decoder.weight"] - def __init__( - self, - config: ModernBertConfig, - pad_logits: bool = True, - pad_logits_no_grad: Optional[bool] = None, - sparse_prediction: bool = True, - sparse_pred_ignore_index: int = -100, - ): + def __init__(self, config: ModernBertConfig): super().__init__(config) self.config = config self.model = ModernBertModel(config) @@ -1122,10 +1115,8 @@ def __init__( self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) self.decoder.weight = decoder_weights - self.pad_logits = pad_logits - self.pad_logits_no_grad = pad_logits_no_grad if pad_logits_no_grad is not None else self.config.unpad_no_grad - self.sparse_prediction = sparse_prediction - self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.sparse_prediction = self.config.sparse_prediction + self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index # Initialize weights and apply final processing self._init_weights(reset_params=False) @@ -1214,7 +1205,7 @@ def forward( loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) if self.config.unpad_inputs: - if self.pad_logits_no_grad: + if self.config.unpad_no_grad: logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 368281cedc82a7..e10fc9f12e7878 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -172,6 +172,8 @@ def 
__init__( classifier_bias=False, classifier_activation="gelu", deterministic_flash_attn=False, + sparse_prediction=False, + sparse_pred_ignore_index=-100, **kwargs, ): super().__init__( @@ -214,6 +216,8 @@ def __init__( self.classifier_norm = classifier_norm self.classifier_activation = classifier_activation self.deterministic_flash_attn = deterministic_flash_attn + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index if unpad_inputs is None: self.unpad_inputs = self._attn_implementation == "flash_attention_2" @@ -1251,14 +1255,7 @@ def forward( class ModernBertForMaskedLM(ModernBertPreTrainedModel): _tied_weights_keys = ["decoder.weight"] - def __init__( - self, - config: ModernBertConfig, - pad_logits: bool = True, - pad_logits_no_grad: Optional[bool] = None, - sparse_prediction: bool = True, - sparse_pred_ignore_index: int = -100, - ): + def __init__(self, config: ModernBertConfig): super().__init__(config) self.config = config self.model = ModernBertModel(config) @@ -1271,10 +1268,8 @@ def __init__( self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) self.decoder.weight = decoder_weights - self.pad_logits = pad_logits - self.pad_logits_no_grad = pad_logits_no_grad if pad_logits_no_grad is not None else self.config.unpad_no_grad - self.sparse_prediction = sparse_prediction - self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.sparse_prediction = self.config.sparse_prediction + self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index # Initialize weights and apply final processing self._init_weights(reset_params=False) @@ -1363,7 +1358,7 @@ def forward( loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) if self.config.unpad_inputs: - if self.pad_logits_no_grad: + if self.config.unpad_no_grad: logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) From f7169434facc8e02c67bbb9aa066ba80eb845018 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 13 Dec 2024 14:50:11 +0100 Subject: [PATCH 25/88] Remove/adjust/skip improper tests; warn if padding but no attn mask --- .../models/modernbert/modeling_modernbert.py | 2 ++ .../models/modernbert/modular_modernbert.py | 2 ++ .../modernbert/test_modeling_modernbert.py | 16 ++++++++++------ 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 1b4a8fe3558368..038607dfd7ebbb 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -1029,6 +1029,8 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + if batch_size is None and seq_len is None: batch_size, seq_len = input_ids.shape[:2] diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index e10fc9f12e7878..1a8abeb6f60ca8 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1182,6 +1182,8 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None + 
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + if batch_size is None and seq_len is None: batch_size, seq_len = input_ids.shape[:2] diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index 2efb465a348f1c..a6fc1ce2d86dcd 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -151,7 +151,6 @@ def create_and_check_model( result = model(input_ids) result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) def create_and_check_for_masked_lm( self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels @@ -220,7 +219,8 @@ class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste if is_torch_available() else {} ) - fx_compatible = True + fx_compatible = False + test_head_masking = False model_split_percents = [0.5, 0.8, 0.9] # special case for ForPreTraining model @@ -265,14 +265,14 @@ def test_model_3d_mask_shapes(self): config_and_inputs[3] = random_attention_mask([batch_size, seq_length, seq_length]) self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip("ModernBert doesn't use `inputs_embeds` as input.") + def test_inputs_embeds(self): + pass + def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) - def test_for_pretraining(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_for_pretraining(*config_and_inputs) - def test_for_sequence_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) @@ -306,6 +306,10 @@ def test_for_warning_if_padding_and_no_attention_mask(self): model(input_ids, attention_mask=None) self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out) + @unittest.skip("ModernBert doesn't use separate classes for SDPA, but a function instead.") + def test_sdpa_can_dispatch_non_composite_models(self): + pass + @slow def test_model_from_pretrained(self): model_name = "google-bert/bert-base-uncased" From f41adaa6b71b3edae5b6f91cac1e287054d07ec9 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 13 Dec 2024 15:02:08 +0100 Subject: [PATCH 26/88] Run formatting etc. 
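The padding warning exercised by the test above goes through logger.warning_once, which is memoized, so the test clears its cache before capturing output. A minimal standalone sketch of that capture pattern (triggering the logger directly instead of through a model):

from transformers.testing_utils import CaptureLogger
from transformers.utils import logging

logger = logging.get_logger("transformers.modeling_utils")
logger.warning_once.cache_clear()  # warning_once is lru-cached; clear so it fires again

with CaptureLogger(logger) as cl:
    logger.warning_once("We strongly recommend passing in an `attention_mask`")
print("attention_mask" in cl.out)  # True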
--- src/transformers/__init__.py | 1 - src/transformers/models/auto/modeling_auto.py | 1 - .../models/modernbert/modular_modernbert.py | 48 ++------------ src/transformers/pytorch_utils.py | 12 ++-- src/transformers/utils/dummy_pt_objects.py | 63 ++++++++++--------- .../modernbert/test_modeling_modernbert.py | 15 ++--- 6 files changed, 51 insertions(+), 89 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4c5533bfa2c102..cf607b572d7e3c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2303,7 +2303,6 @@ _import_structure["models.modernbert"].extend( [ "ModernBertForMaskedLM", - "ModernBertForCausalLM", "ModernBertForSequenceClassification", "ModernBertForTokenClassification", "ModernBertModel", diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0a773017e2bfd0..13b9983220cfbc 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -511,7 +511,6 @@ ("mistral", "MistralForCausalLM"), ("mixtral", "MixtralForCausalLM"), ("mllama", "MllamaForCausalLM"), - ("modernbert", "ModernBertForCausalLM"), ("moshi", "MoshiForCausalLM"), ("mpt", "MptForCausalLM"), ("musicgen", "MusicgenForCausalLM"), diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 1a8abeb6f60ca8..66d05a76e0bc75 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union import torch -from torch import nn +import torch.nn.functional as F import torch.utils.checkpoint +from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -import torch.nn.functional as F from ...activations import ACT2FN from ...configuration_utils import PretrainedConfig @@ -30,18 +30,16 @@ from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, - MultipleChoiceModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, prune_qkv_linear_layer from ...utils import ( is_flash_attn_2_available, - is_flash_attn_greater_or_equal, is_torch_greater_or_equal, logging, ) -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, prune_qkv_linear_layer from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb @@ -53,7 +51,7 @@ RotaryEmbedding = None if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import flex_attention + pass _CHECKPOINT_FOR_DOC = "answerdotai/modernbert-base" @@ -604,40 +602,6 @@ class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): pass -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. 
- unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def eager_attention_forward( self: "ModernBertAttention", qkv: torch.Tensor, @@ -837,9 +801,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): def prune_heads(self, heads): if len(heads) == 0: return - heads, index = find_pruneable_heads_and_indices( - heads, self.num_heads, self.head_dim, self.pruned_heads - ) + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) # Prune linear layers self.Wqkv = prune_qkv_linear_layer(self.Wqkv, index) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 587b91feabe414..b2461fc1e73dba 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -107,11 +107,13 @@ def prune_qkv_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = `torch.nn.Linear`: The pruned QKV layer as a new layer with `requires_grad=True`. 
""" assert layer.out_features % 3 == 0, "The output features of the linear layer should be divisible by 3" - index = torch.cat([ - index, - index + layer.out_features // 3, - index + 2 * layer.out_features // 3, - ]) + index = torch.cat( + [ + index, + index + layer.out_features // 3, + index + 2 * layer.out_features // 3, + ] + ) index = index.to(layer.weight.device) W = layer.weight.index_select(dim, index).clone().detach() if layer.bias is not None: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 6d11147391d97d..98a9911b43c343 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4375,34 +4375,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class ModernBertForSequenceClassification(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - -class ModernBertForTokenClassification(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - -class ModernBertModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - -class ModernBertPreTrainedModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class GitForCausalLM(metaclass=DummyObject): _backends = ["torch"] @@ -6317,6 +6289,41 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ModernBertForMaskedLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MoshiForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index a6fc1ce2d86dcd..db97595ae2fdfc 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -12,25 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import tempfile import unittest -from packaging import version - -from transformers import AutoTokenizer, ModernBertConfig, is_torch_available +from transformers import ModernBertConfig, is_torch_available from transformers.models.auto import get_values from transformers.testing_utils import ( CaptureLogger, require_torch, - require_torch_accelerator, slow, torch_device, ) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_pipeline_mixin import PipelineTesterMixin @@ -141,9 +136,7 @@ def get_config(self): initializer_range=self.initializer_range, ) - def create_and_check_model( - self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels - ): + def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): model = ModernBertModel(config=config) model.to(torch_device) model.eval() @@ -321,4 +314,4 @@ def test_model_from_pretrained(self): class ModernBertModelIntegrationTest(unittest.TestCase): """ These still need to be written, once public models are available. - """ \ No newline at end of file + """ From d06654a17d701803b4c08bb93989c257f4b3476e Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Sat, 14 Dec 2024 09:37:18 +0100 Subject: [PATCH 27/88] Run python utils/custom_init_isort.py --- src/transformers/__init__.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cf607b572d7e3c..fc7d6de025dcba 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -450,7 +450,6 @@ "models.fuyu": ["FuyuConfig"], "models.gemma": ["GemmaConfig"], "models.gemma2": ["Gemma2Config"], - "models.modernbert": ["ModernBertConfig"], "models.git": [ "GitConfig", "GitProcessor", @@ -595,6 +594,7 @@ "models.mobilenet_v2": ["MobileNetV2Config"], "models.mobilevit": ["MobileViTConfig"], "models.mobilevitv2": ["MobileViTV2Config"], + "models.modernbert": ["ModernBertConfig"], "models.moshi": [ "MoshiConfig", "MoshiDepthConfig", @@ -2300,15 +2300,6 @@ "Gemma2PreTrainedModel", ] ) - _import_structure["models.modernbert"].extend( - [ - "ModernBertForMaskedLM", - "ModernBertForSequenceClassification", - "ModernBertForTokenClassification", - "ModernBertModel", - "ModernBertPreTrainedModel", - ] - ) _import_structure["models.git"].extend( [ "GitForCausalLM", @@ -2820,6 +2811,15 @@ "MobileViTV2PreTrainedModel", ] ) + _import_structure["models.modernbert"].extend( + [ + "ModernBertForMaskedLM", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", + "ModernBertModel", + "ModernBertPreTrainedModel", + ] + ) _import_structure["models.moshi"].extend( [ "MoshiForCausalLM", From f9301f4cf9886aa757c93f01d8c76e1421406722 Mon Sep 17 00:00:00 2001 From: Said Taghadouini Date: Sun, 15 Dec 2024 22:45:31 +0000 Subject: [PATCH 28/88] FlexAttention with unpadded sequences(matches FA2 within bf16 numerics) --- .../modernbert/configuration_modernbert.py | 8 +- .../models/modernbert/modeling_modernbert.py | 130 +++++++++++++---- .../models/modernbert/modular_modernbert.py | 136 +++++++++++++----- 3 files changed, 209 insertions(+), 65 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py 
b/src/transformers/models/modernbert/configuration_modernbert.py index ccc47280453e79..9243967e08d299 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -27,7 +27,7 @@ class ModernBertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the ModernBert-7B. + defaults will yield a similar configuration to that of the ModernBert-base. e.g. [answerdotai/modernbert-base](https://huggingface.co/answerdotai/modernbert-base) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -84,9 +84,9 @@ class ModernBertConfig(PretrainedConfig): ```python >>> from transformers import ModernBertModel, ModernBertConfig - >>> # Initializing a ModernBert modernbert-7b style configuration + >>> # Initializing a ModernBert modernbert-base style configuration >>> configuration = ModernBertConfig() - >>> # Initializing a model from the modernbert-7b style configuration + >>> # Initializing a model from the modernbert-base style configuration >>> model = ModernBertModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -182,4 +182,4 @@ def __init__( self.sparse_pred_ignore_index = sparse_pred_ignore_index if unpad_inputs is None: - self.unpad_inputs = self._attn_implementation == "flash_attention_2" + self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 038607dfd7ebbb..2aca5ba7b8c627 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -34,7 +34,7 @@ from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, prune_qkv_linear_layer -from ...utils import is_flash_attn_2_available, logging +from ...utils import is_flash_attn_2_available, is_torch_greater_or_equal, logging from .configuration_modernbert import ModernBertConfig @@ -45,6 +45,9 @@ else: RotaryEmbedding = None +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + logger = logging.get_logger(__name__) @@ -476,31 +479,36 @@ def flash_attention_forward( return (attn.view(bs, dim),) -# def flex_attention_forward( -# config: ModernBertConfig, -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# mask: Optional[torch.Tensor], -# output_attentions: bool = False, -# **_kwargs, -# ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - -# attn_output = flex_attention( -# query, -# key, -# value, -# enable_gqa=True, -# scale=config.scaling, -# return_lse=output_attentions, -# ) -# if not output_attentions: -# attn_weights = None -# else: -# attn_output, attn_weights = attn_output - -# attn_output = attn_output.transpose(1, 2).contiguous() -# return attn_output, attn_weights +def flex_attention_forward( + self: 
"ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + block_mask: BlockMask, + max_seqlen: int, + bs: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + query, key, value = qkv.unbind(dim=1) + + query = query.transpose(0, 1).unsqueeze(0) + key = key.transpose(0, 1).unsqueeze(0) + value = value.transpose(0, 1).unsqueeze(0) + + attn_output = flex_attention( + query, + key, + value, + block_mask=block_mask, + enable_gqa=False, + scale=None, + return_lse=False, + ) + + attn_output = attn_output.squeeze(0).transpose(0, 1).contiguous() + return (attn_output.view(bs, dim),) def sdpa_attention_forward( @@ -535,7 +543,7 @@ def sdpa_attention_forward( MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, - # "flex_attention": flex_attention_forward, + "flex_attention": flex_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } @@ -580,7 +588,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): rope_theta = config.local_rope_theta max_position_embeddings = config.local_attention - if config._attn_implementation == "flash_attention_2": + if config._attn_implementation in {"flash_attention_2", "flex_attention"}: self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) @@ -617,6 +625,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, + block_mask: Optional[BlockMask] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, **kwargs, @@ -658,6 +667,20 @@ def forward( "bs": bs, } ) + elif self.config._attn_implementation == "flex_attention": + bs, dim = hidden_states.shape[:2] + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + + attn_kwargs.update( + { + "local_attention": self.local_attention, + "block_mask": block_mask, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "bs": bs, + } + ) + else: bs, seqlen = hidden_states.shape[:2] qkv = qkv.view(bs, seqlen, 3, self.num_heads, self.head_dim) @@ -723,6 +746,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, + block_mask: Optional[BlockMask] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: @@ -739,6 +763,7 @@ def forward( self.attn_norm(hidden_states), position_ids=position_ids, cu_seqlens=cu_seqlens, + block_mask=block_mask, max_seqlen=max_seqlen, attention_mask=attention_mask, output_attentions=output_attentions, @@ -906,6 +931,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights( self, @@ -970,6 +996,40 @@ def _pad_outputs( inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index ) + @classmethod + def offsets_to_sequence_ids_tensor(cls, offsets): + device = offsets.device + counts = offsets[1:] - offsets[:-1] + return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts) + + @torch.compile(dynamic=False) + def create_attention_mask(self, sequence_ids, cu_seqlens, 
window_size): + """ + Creates a block mask combining sequence masking and local/or global attention masking. + """ + + def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): + # only allow attention within the same sequence + same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx] + + # get position within the sequence + q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]] + kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]] + + # sliding window within each sequence + in_window = (q_pos - kv_pos).abs() <= window_size + + return same_seq & in_window + + block_mask = create_block_mask( + sliding_window_seq_mask_mod, + B=None, + H=None, + Q_LEN=cu_seqlens[-1], + KV_LEN=cu_seqlens[-1], + ) + return block_mask + class ModernBertModel(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): @@ -1057,6 +1117,19 @@ def forward( # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + # create block mask + if self.config._attn_implementation == "flex_attention": + sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens) + + if self.config.local_attention != (-1, -1): + window_size = self.config.local_attention // 2 + else: + window_size = max_seqlen + + block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size) + else: + block_mask = None + for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1077,6 +1150,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, + block_mask=block_mask, max_seqlen=max_seqlen, output_attentions=output_attentions, ) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 66d05a76e0bc75..60694f32e0b4c9 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -51,8 +51,7 @@ RotaryEmbedding = None if is_torch_greater_or_equal("2.5"): - pass - + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention _CHECKPOINT_FOR_DOC = "answerdotai/modernbert-base" @@ -63,7 +62,7 @@ class ModernBertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the ModernBert-7B. + defaults will yield a similar configuration to that of the ModernBert-base. e.g. [answerdotai/modernbert-base](https://huggingface.co/answerdotai/modernbert-base) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. 
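The block mask built in create_attention_mask above combines two constraints for unpadded (packed) inputs: a query may only attend to keys from the same sequence, and only within the sliding window. A toy sketch of that masking rule, evaluated densely for readability (sizes invented; the real code hands an equivalent mask_mod to create_block_mask and flex_attention):

import torch

cu_seqlens = torch.tensor([0, 3, 7])  # two packed sequences of lengths 3 and 4
window_size = 1                       # each token sees +/- 1 neighbours

counts = cu_seqlens[1:] - cu_seqlens[:-1]
sequence_ids = torch.repeat_interleave(torch.arange(len(counts)), counts)

def mask_mod(q_idx, kv_idx):
    # only allow attention within the same packed sequence
    same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx]
    # positions relative to the start of each sequence
    q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]]
    kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]]
    return same_seq & ((q_pos - kv_pos).abs() <= window_size)

total_tokens = int(cu_seqlens[-1])
idx = torch.arange(total_tokens)
print(mask_mod(idx[:, None], idx[None, :]).int())  # block-diagonal, banded within each sequence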
@@ -120,9 +119,9 @@ class ModernBertConfig(PretrainedConfig): ```python >>> from transformers import ModernBertModel, ModernBertConfig - >>> # Initializing a ModernBert modernbert-7b style configuration + >>> # Initializing a ModernBert modernbert-base style configuration >>> configuration = ModernBertConfig() - >>> # Initializing a model from the modernbert-7b style configuration + >>> # Initializing a model from the modernbert-base style configuration >>> model = ModernBertModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -218,7 +217,7 @@ def __init__( self.sparse_pred_ignore_index = sparse_pred_ignore_index if unpad_inputs is None: - self.unpad_inputs = self._attn_implementation == "flash_attention_2" + self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} class ModernBertModuleType(str, Enum): @@ -681,31 +680,36 @@ def flash_attention_forward( return (attn.view(bs, dim),) -# def flex_attention_forward( -# config: ModernBertConfig, -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# mask: Optional[torch.Tensor], -# output_attentions: bool = False, -# **_kwargs, -# ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - -# attn_output = flex_attention( -# query, -# key, -# value, -# enable_gqa=True, -# scale=config.scaling, -# return_lse=output_attentions, -# ) -# if not output_attentions: -# attn_weights = None -# else: -# attn_output, attn_weights = attn_output - -# attn_output = attn_output.transpose(1, 2).contiguous() -# return attn_output, attn_weights +def flex_attention_forward( + self: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + block_mask: BlockMask, + max_seqlen: int, + bs: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + query, key, value = qkv.unbind(dim=1) + + query = query.transpose(0, 1).unsqueeze(0) + key = key.transpose(0, 1).unsqueeze(0) + value = value.transpose(0, 1).unsqueeze(0) + + attn_output = flex_attention( + query, + key, + value, + block_mask=block_mask, + enable_gqa=False, + scale=None, + return_lse=False, + ) + + attn_output = attn_output.squeeze(0).transpose(0, 1).contiguous() + return (attn_output.view(bs, dim),) def sdpa_attention_forward( @@ -740,7 +744,7 @@ def sdpa_attention_forward( MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, - # "flex_attention": flex_attention_forward, + "flex_attention": flex_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } @@ -785,7 +789,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): rope_theta = config.local_rope_theta max_position_embeddings = config.local_attention - if config._attn_implementation == "flash_attention_2": + if config._attn_implementation in {"flash_attention_2", "flex_attention"}: self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) @@ -822,6 +826,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, + block_mask: Optional[BlockMask] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, **kwargs, @@ -863,6 +868,20 @@ def forward( "bs": bs, } ) + elif self.config._attn_implementation == "flex_attention": + bs, dim = 
hidden_states.shape[:2] + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + + attn_kwargs.update( + { + "local_attention": self.local_attention, + "block_mask": block_mask, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "bs": bs, + } + ) + else: bs, seqlen = hidden_states.shape[:2] qkv = qkv.view(bs, seqlen, 3, self.num_heads, self.head_dim) @@ -928,6 +947,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, + block_mask: Optional[BlockMask] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: @@ -944,6 +964,7 @@ def forward( self.attn_norm(hidden_states), position_ids=position_ids, cu_seqlens=cu_seqlens, + block_mask=block_mask, max_seqlen=max_seqlen, attention_mask=attention_mask, output_attentions=output_attentions, @@ -1021,6 +1042,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights( self, @@ -1085,6 +1107,40 @@ def _pad_outputs( inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index ) + @classmethod + def offsets_to_sequence_ids_tensor(cls, offsets): + device = offsets.device + counts = offsets[1:] - offsets[:-1] + return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts) + + @torch.compile(dynamic=False) + def create_attention_mask(self, sequence_ids, cu_seqlens, window_size): + """ + Creates a block mask combining sequence masking and local/or global attention masking. + """ + + def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): + # only allow attention within the same sequence + same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx] + + # get position within the sequence + q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]] + kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]] + + # sliding window within each sequence + in_window = (q_pos - kv_pos).abs() <= window_size + + return same_seq & in_window + + block_mask = create_block_mask( + sliding_window_seq_mask_mod, + B=None, + H=None, + Q_LEN=cu_seqlens[-1], + KV_LEN=cu_seqlens[-1], + ) + return block_mask + class ModernBertModel(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): @@ -1172,6 +1228,19 @@ def forward( # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + # create block mask + if self.config._attn_implementation == "flex_attention": + sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens) + + if self.config.local_attention != (-1, -1): + window_size = self.config.local_attention // 2 + else: + window_size = max_seqlen + + block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size) + else: + block_mask = None + for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1192,6 +1261,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, + block_mask=block_mask, max_seqlen=max_seqlen, output_attentions=output_attentions, ) From a356708e9be7ad4da9d9c6e042e14fa5028eee36 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 12:46:39 +0100 Subject: [PATCH 29/88] Reformat init_weights based on review --- .../models/modernbert/modeling_modernbert.py | 183 
+++++------------- .../models/modernbert/modular_modernbert.py | 179 +++++------------ .../modernbert/test_modeling_modernbert.py | 21 +- tests/test_modeling_common.py | 2 + 4 files changed, 114 insertions(+), 271 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 2aca5ba7b8c627..4bee4e6c240b25 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -51,13 +51,6 @@ logger = logging.get_logger(__name__) -class ModernBertModuleType(str, Enum): - in_module = "in" - out_module = "out" - embedding = "emb" - final_out = "final_out" - - class ModernBertPoolingType(str, Enum): cls = "cls" mean = "mean" @@ -246,52 +239,6 @@ def extra_repr(self) -> str: return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" -# Copyright 2023 OLMo Authors -# License: Apache-2.0 - - -def _init_modernbert_weights( - config: ModernBertConfig, - module: Union[nn.Linear, nn.Embedding], - module_type: ModernBertModuleType, -) -> None: - """ - Initialize weights of a linear or embedding module. - - :param config: The model config. - :param module: The linear or embedding submodule to initialize. - """ - if module_type is None: - raise RuntimeError("When using the full megatron init, every module must have a type.") - - cutoff_factor = config.initializer_cutoff_factor - if cutoff_factor is None: - cutoff_factor = 3 - - if module_type == ModernBertModuleType.in_module: - std = config.initializer_range # for att_proj (same as QKV), ff_proj - elif module_type == ModernBertModuleType.out_module: - std = config.initializer_range / math.sqrt(2.0 * config.num_hidden_layers) # for attn_out, ff_out - elif module_type == ModernBertModuleType.embedding: - std = config.initializer_range # token embeddings (wte) - elif module_type == ModernBertModuleType.final_out: - std = config.hidden_size**-0.5 # final output (ff_out) - else: - raise RuntimeError(f"Unknown module type '{module_type}'") - - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-cutoff_factor * std, - b=cutoff_factor * std, - ) - - if isinstance(module, nn.Linear): - if module.bias is not None: - nn.init.zeros_(module.bias) - - class ModernBertEmbeddings(nn.Module): """ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
@@ -304,11 +251,6 @@ def __init__(self, config: ModernBertConfig): self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.drop = nn.Dropout(config.embedding_dropout) if config.embedding_dropout > 0.0 else nn.Identity() - def _init_weights(self, reset_params: bool = False): - _init_modernbert_weights(self.config, self.tok_embeddings, module_type=ModernBertModuleType.embedding) - if reset_params: - self.norm.reset_parameters() - @torch.compile(dynamic=True) def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) @@ -329,10 +271,6 @@ def __init__(self, config: ModernBertConfig): self.drop = nn.Dropout(config.mlp_dropout) if config.mlp_dropout > 0.0 else nn.Identity() self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) - def _init_weights(self, reset_params: bool = False): - _init_modernbert_weights(self.config, self.Wi, module_type=ModernBertModuleType.in_module) - _init_modernbert_weights(self.config, self.Wo, module_type=ModernBertModuleType.out_module) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input, gate = self.Wi(hidden_states).chunk(2, dim=-1) return self.Wo(self.drop(self.act(input) * gate)) @@ -615,10 +553,6 @@ def prune_heads(self, heads): self.all_head_size = self.head_dim * self.num_heads self.pruned_heads = self.pruned_heads.union(heads) - def _init_weights(self, reset_params: bool = False): - _init_modernbert_weights(self.config, self.Wqkv, module_type=ModernBertModuleType.in_module) - _init_modernbert_weights(self.config, self.Wo, module_type=ModernBertModuleType.out_module) - def forward( self, hidden_states: torch.Tensor, @@ -731,11 +665,6 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.mlp = ModernBertMLP(config) - def _init_weights(self, reset_params: bool = False): - if reset_params: - self.attn_norm.reset_parameters() - self.mlp_norm.reset_parameters() - @torch.compile(dynamic=True) def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.mlp(self.mlp_norm(hidden_states)) @@ -786,14 +715,6 @@ def __init__(self, config: ModernBertConfig): else nn.Identity() ) - def _init_weights(self, reset_params: bool = False): - if reset_params: - self.norm.reset_parameters() - _init_modernbert_weights(self.config, self.dense, module_type=ModernBertModuleType.in_module) - - def reset_parameters(self): - self._init_weights(reset_params=True) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(self.act(self.dense(hidden_states))) @@ -825,13 +746,9 @@ def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> t return self.drop(self.norm(self.act(self.dense(output)))) - def _init_weights(self, reset_params: bool = False): - _init_modernbert_weights(self.config, self.dense, module_type=ModernBertModuleType.out_module) - if reset_params and hasattr(self.norm, "reset_parameters"): - self.norm.reset_parameters() - def reset_parameters(self): - self._init_weights(reset_params=True) +# Copyright 2023 OLMo Authors +# License: Apache-2.0 def _unpad_modernbert_input( @@ -933,13 +850,47 @@ class ModernBertPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - def _init_weights( - self, - module: Union[ModernBertEncoderLayer, ModernBertAttention, 
ModernBertMLP, ModernBertEmbeddings], - reset_params: bool = False, - ): - if isinstance(module, (ModernBertEncoderLayer, ModernBertAttention, ModernBertMLP, ModernBertEmbeddings)): - module._init_weights(reset_params=reset_params) + def _init_weights(self, module: nn.Module): + cutoff_factor = self.config.initializer_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + def init_weight(module: nn.Module, std: float): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + stds = { + "in": self.config.initializer_range, + "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), + "embedding": self.config.initializer_range, + "final_out": self.config.hidden_size**-0.5, + } + + if isinstance(module, ModernBertEmbeddings): + init_weight(module.tok_embeddings, stds["embedding"]) + elif isinstance(module, ModernBertMLP): + init_weight(module.Wi, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertAttention): + init_weight(module.Wqkv, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertPredictionHead): + init_weight(module.dense, stds["in"]) # TODO: Should this be "out"/"final_out"? + elif isinstance(module, ModernBertPoolingHead): + init_weight(module.dense, stds["out"]) + elif isinstance(module, ModernBertForMaskedLM): + init_weight(module.decoder, stds["out"]) + elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): + init_weight(module.classifier, stds["final_out"]) @torch.no_grad() def _unpad_inputs_no_grad( @@ -1057,15 +1008,6 @@ class PreTrainedModel for layer, heads in heads_to_prune.items(): self.layers[layer].attn.prune_heads(heads) - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - if module and hasattr(module, "_init_weights"): - super()._init_weights(module, reset_params=reset_params) - elif isinstance(reset_params, bool): - self.embeddings._init_weights(reset_params=reset_params) - - if reset_params: - self.final_norm.reset_parameters() - def forward( self, input_ids: torch.LongTensor = None, @@ -1193,21 +1135,9 @@ def __init__(self, config: ModernBertConfig): self.sparse_prediction = self.config.sparse_prediction self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index - # Initialize weights and apply final processing - self._init_weights(reset_params=False) - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - super()._init_weights(module, reset_params=reset_params) - else: - assert isinstance(reset_params, bool) - self.model._init_weights(reset_params=reset_params) - self.head._init_weights(reset_params=reset_params) - - # Output weights. 
- if not self.config.tie_word_embeddings: - _init_modernbert_weights(self.config, self.decoder, module_type=ModernBertModuleType.out_module) + # Initialize weights and apply final processing + self.post_init() def get_output_embeddings(self): return self.decoder @@ -1308,17 +1238,7 @@ def __init__(self, config: ModernBertConfig): self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing - self._init_weights(reset_params=False) - - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - super()._init_weights(module, reset_params=reset_params) - else: - assert isinstance(reset_params, bool) - self.model._init_weights(reset_params=reset_params) - self.head._init_weights(reset_params=reset_params) - _init_modernbert_weights(self.config, self.classifier, module_type=ModernBertModuleType.final_out) + self.post_init() def forward( self, @@ -1407,16 +1327,7 @@ def __init__(self, config: ModernBertConfig): self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing - self._init_weights(reset_params=False) - - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - super()._init_weights(module, reset_params=reset_params) - else: - assert isinstance(reset_params, bool) - self.model._init_weights(reset_params=reset_params) - _init_modernbert_weights(self.config, self.classifier, module_type=ModernBertModuleType.final_out) + self.post_init() def forward( self, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 60694f32e0b4c9..eb0c2f3e4a5cdc 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -220,13 +220,6 @@ def __init__( self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} -class ModernBertModuleType(str, Enum): - in_module = "in" - out_module = "out" - embedding = "emb" - final_out = "final_out" - - class ModernBertPoolingType(str, Enum): cls = "cls" mean = "mean" @@ -237,48 +230,6 @@ class ModernBertPoolingType(str, Enum): # License: Apache-2.0 -def _init_modernbert_weights( - config: ModernBertConfig, - module: Union[nn.Linear, nn.Embedding], - module_type: ModernBertModuleType, -) -> None: - """ - Initialize weights of a linear or embedding module. - - :param config: The model config. - :param module: The linear or embedding submodule to initialize. 
- """ - if module_type is None: - raise RuntimeError("When using the full megatron init, every module must have a type.") - - cutoff_factor = config.initializer_cutoff_factor - if cutoff_factor is None: - cutoff_factor = 3 - - if module_type == ModernBertModuleType.in_module: - std = config.initializer_range # for att_proj (same as QKV), ff_proj - elif module_type == ModernBertModuleType.out_module: - std = config.initializer_range / math.sqrt(2.0 * config.num_hidden_layers) # for attn_out, ff_out - elif module_type == ModernBertModuleType.embedding: - std = config.initializer_range # token embeddings (wte) - elif module_type == ModernBertModuleType.final_out: - std = config.hidden_size**-0.5 # final output (ff_out) - else: - raise RuntimeError(f"Unknown module type '{module_type}'") - - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-cutoff_factor * std, - b=cutoff_factor * std, - ) - - if isinstance(module, nn.Linear): - if module.bias is not None: - nn.init.zeros_(module.bias) - - def _unpad_modernbert_input( inputs: torch.Tensor, attention_mask: torch.Tensor, @@ -563,11 +514,6 @@ def __init__(self, config: ModernBertConfig): self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.drop = nn.Dropout(config.embedding_dropout) if config.embedding_dropout > 0.0 else nn.Identity() - def _init_weights(self, reset_params: bool = False): - _init_modernbert_weights(self.config, self.tok_embeddings, module_type=ModernBertModuleType.embedding) - if reset_params: - self.norm.reset_parameters() - @torch.compile(dynamic=True) def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) @@ -588,10 +534,6 @@ def __init__(self, config: ModernBertConfig): self.drop = nn.Dropout(config.mlp_dropout) if config.mlp_dropout > 0.0 else nn.Identity() self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) - def _init_weights(self, reset_params: bool = False): - _init_modernbert_weights(self.config, self.Wi, module_type=ModernBertModuleType.in_module) - _init_modernbert_weights(self.config, self.Wo, module_type=ModernBertModuleType.out_module) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input, gate = self.Wi(hidden_states).chunk(2, dim=-1) return self.Wo(self.drop(self.act(input) * gate)) @@ -816,10 +758,6 @@ def prune_heads(self, heads): self.all_head_size = self.head_dim * self.num_heads self.pruned_heads = self.pruned_heads.union(heads) - def _init_weights(self, reset_params: bool = False): - _init_modernbert_weights(self.config, self.Wqkv, module_type=ModernBertModuleType.in_module) - _init_modernbert_weights(self.config, self.Wo, module_type=ModernBertModuleType.out_module) - def forward( self, hidden_states: torch.Tensor, @@ -932,11 +870,6 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.mlp = ModernBertMLP(config) - def _init_weights(self, reset_params: bool = False): - if reset_params: - self.attn_norm.reset_parameters() - self.mlp_norm.reset_parameters() - @torch.compile(dynamic=True) def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.mlp(self.mlp_norm(hidden_states)) @@ -987,14 +920,6 @@ def __init__(self, config: ModernBertConfig): else nn.Identity() ) - def _init_weights(self, reset_params: bool = False): - if reset_params: 
- self.norm.reset_parameters() - _init_modernbert_weights(self.config, self.dense, module_type=ModernBertModuleType.in_module) - - def reset_parameters(self): - self._init_weights(reset_params=True) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(self.act(self.dense(hidden_states))) @@ -1026,14 +951,6 @@ def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> t return self.drop(self.norm(self.act(self.dense(output)))) - def _init_weights(self, reset_params: bool = False): - _init_modernbert_weights(self.config, self.dense, module_type=ModernBertModuleType.out_module) - if reset_params and hasattr(self.norm, "reset_parameters"): - self.norm.reset_parameters() - - def reset_parameters(self): - self._init_weights(reset_params=True) - class ModernBertPreTrainedModel(PreTrainedModel): config_class = ModernBertConfig @@ -1044,13 +961,47 @@ class ModernBertPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - def _init_weights( - self, - module: Union[ModernBertEncoderLayer, ModernBertAttention, ModernBertMLP, ModernBertEmbeddings], - reset_params: bool = False, - ): - if isinstance(module, (ModernBertEncoderLayer, ModernBertAttention, ModernBertMLP, ModernBertEmbeddings)): - module._init_weights(reset_params=reset_params) + def _init_weights(self, module: nn.Module): + cutoff_factor = self.config.initializer_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + def init_weight(module: nn.Module, std: float): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + stds = { + "in": self.config.initializer_range, + "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), + "embedding": self.config.initializer_range, + "final_out": self.config.hidden_size**-0.5, + } + + if isinstance(module, ModernBertEmbeddings): + init_weight(module.tok_embeddings, stds["embedding"]) + elif isinstance(module, ModernBertMLP): + init_weight(module.Wi, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertAttention): + init_weight(module.Wqkv, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertPredictionHead): + init_weight(module.dense, stds["in"]) # TODO: Should this be "out"/"final_out"? 
+ elif isinstance(module, ModernBertPoolingHead): + init_weight(module.dense, stds["out"]) + elif isinstance(module, ModernBertForMaskedLM): + init_weight(module.decoder, stds["out"]) + elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): + init_weight(module.classifier, stds["final_out"]) @torch.no_grad() def _unpad_inputs_no_grad( @@ -1168,15 +1119,6 @@ class PreTrainedModel for layer, heads in heads_to_prune.items(): self.layers[layer].attn.prune_heads(heads) - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - if module and hasattr(module, "_init_weights"): - super()._init_weights(module, reset_params=reset_params) - elif isinstance(reset_params, bool): - self.embeddings._init_weights(reset_params=reset_params) - - if reset_params: - self.final_norm.reset_parameters() - def forward( self, input_ids: torch.LongTensor = None, @@ -1304,21 +1246,9 @@ def __init__(self, config: ModernBertConfig): self.sparse_prediction = self.config.sparse_prediction self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index - # Initialize weights and apply final processing - self._init_weights(reset_params=False) - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - super()._init_weights(module, reset_params=reset_params) - else: - assert isinstance(reset_params, bool) - self.model._init_weights(reset_params=reset_params) - self.head._init_weights(reset_params=reset_params) - - # Output weights. - if not self.config.tie_word_embeddings: - _init_modernbert_weights(self.config, self.decoder, module_type=ModernBertModuleType.out_module) + # Initialize weights and apply final processing + self.post_init() def get_output_embeddings(self): return self.decoder @@ -1419,17 +1349,7 @@ def __init__(self, config: ModernBertConfig): self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing - self._init_weights(reset_params=False) - - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - super()._init_weights(module, reset_params=reset_params) - else: - assert isinstance(reset_params, bool) - self.model._init_weights(reset_params=reset_params) - self.head._init_weights(reset_params=reset_params) - _init_modernbert_weights(self.config, self.classifier, module_type=ModernBertModuleType.final_out) + self.post_init() def forward( self, @@ -1518,16 +1438,7 @@ def __init__(self, config: ModernBertConfig): self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing - self._init_weights(reset_params=False) - - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - super()._init_weights(module, reset_params=reset_params) - else: - assert isinstance(reset_params, bool) - self.model._init_weights(reset_params=reset_params) - _init_modernbert_weights(self.config, self.classifier, module_type=ModernBertModuleType.final_out) + self.post_init() def forward( self, diff --git a/tests/models/modernbert/test_modeling_modernbert.py 
b/tests/models/modernbert/test_modeling_modernbert.py index db97595ae2fdfc..bd485a2078a987 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -25,7 +25,7 @@ from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask from ...test_pipeline_mixin import PipelineTesterMixin @@ -258,6 +258,25 @@ def test_model_3d_mask_shapes(self): config_and_inputs[3] = random_attention_mask([batch_size, seq_length, seq_length]) self.model_tester.create_and_check_model(*config_and_inputs) + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + # The classifier.weight from ModernBertForSequenceClassification and ModernBertForTokenClassification + # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init + if param.requires_grad and not ( + name == "classifier.weight" + and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification] + ): + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + @unittest.skip("ModernBert doesn't use `inputs_embeds` as input.") def test_inputs_embeds(self): pass diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 99d0a8058c67f8..318bf5a0a1cd22 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3443,6 +3443,8 @@ def test_mismatched_shapes_have_properly_initialized_weights(self): "Data2VecAudioForSequenceClassification", "UniSpeechForSequenceClassification", "PvtForImageClassification", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", ] special_param_names = [ r"^bit\.", From f83fdc0cb09c72e44182d8a03761402f96ad3b4d Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 12:51:32 +0100 Subject: [PATCH 30/88] self -> module in attention forwards --- .../models/modernbert/modeling_modernbert.py | 28 +++++++++---------- .../models/modernbert/modular_modernbert.py | 28 +++++++++---------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 4bee4e6c240b25..58edc751fdf1a5 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -339,7 +339,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): def eager_attention_forward( - self: "ModernBertAttention", + module: "ModernBertAttention", qkv: torch.Tensor, position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, @@ -350,12 +350,12 @@ def eager_attention_forward( **_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] - cos, sin = self.rotary_emb(qkv, position_ids=position_ids) + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 
1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) - scale = self.head_dim**-0.5 + scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale if attention_mask is not None: # no matter the length, we just slice it @@ -364,8 +364,8 @@ def eager_attention_forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - if self.training: - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout) + if module.training: + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bs, seqlen, dim) @@ -375,7 +375,7 @@ def eager_attention_forward( def flash_attention_forward( - self: "ModernBertAttention", + module: "ModernBertAttention", qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, @@ -400,8 +400,8 @@ def flash_attention_forward( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - dropout_p=self.attention_dropout if self.training else 0.0, - deterministic=self.deterministic_flash_attn, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, window_size=local_attention, ) attn = attn.to(orig_dtype) # type: ignore @@ -410,15 +410,15 @@ def flash_attention_forward( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - dropout_p=self.attention_dropout if self.training else 0.0, - deterministic=self.deterministic_flash_attn, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, window_size=local_attention, ) return (attn.view(bs, dim),) def flex_attention_forward( - self: "ModernBertAttention", + module: "ModernBertAttention", qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, @@ -450,7 +450,7 @@ def flex_attention_forward( def sdpa_attention_forward( - self: "ModernBertAttention", + module: "ModernBertAttention", qkv: torch.Tensor, position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, @@ -460,7 +460,7 @@ def sdpa_attention_forward( **_kwargs, ) -> Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] - cos, sin = self.rotary_emb(qkv, position_ids=position_ids) + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) @@ -472,7 +472,7 @@ def sdpa_attention_forward( query, key, value, - dropout_p=self.attention_dropout if self.training else 0.0, + dropout_p=module.attention_dropout if module.training else 0.0, attn_mask=attention_mask, ).transpose(1, 2) attn_output = attn_output.view(bs, seqlen, dim) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index eb0c2f3e4a5cdc..9b7bf49f0070be 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -544,7 +544,7 @@ class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): def eager_attention_forward( - self: "ModernBertAttention", + module: "ModernBertAttention", qkv: torch.Tensor, position_ids: Optional[torch.LongTensor], attention_mask: 
torch.Tensor, @@ -555,12 +555,12 @@ def eager_attention_forward( **_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] - cos, sin = self.rotary_emb(qkv, position_ids=position_ids) + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) - scale = self.head_dim**-0.5 + scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale if attention_mask is not None: # no matter the length, we just slice it @@ -569,8 +569,8 @@ def eager_attention_forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - if self.training: - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout) + if module.training: + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bs, seqlen, dim) @@ -580,7 +580,7 @@ def eager_attention_forward( def flash_attention_forward( - self: "ModernBertAttention", + module: "ModernBertAttention", qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, @@ -605,8 +605,8 @@ def flash_attention_forward( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - dropout_p=self.attention_dropout if self.training else 0.0, - deterministic=self.deterministic_flash_attn, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, window_size=local_attention, ) attn = attn.to(orig_dtype) # type: ignore @@ -615,15 +615,15 @@ def flash_attention_forward( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - dropout_p=self.attention_dropout if self.training else 0.0, - deterministic=self.deterministic_flash_attn, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, window_size=local_attention, ) return (attn.view(bs, dim),) def flex_attention_forward( - self: "ModernBertAttention", + module: "ModernBertAttention", qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, @@ -655,7 +655,7 @@ def flex_attention_forward( def sdpa_attention_forward( - self: "ModernBertAttention", + module: "ModernBertAttention", qkv: torch.Tensor, position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, @@ -665,7 +665,7 @@ def sdpa_attention_forward( **_kwargs, ) -> Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] - cos, sin = self.rotary_emb(qkv, position_ids=position_ids) + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) @@ -677,7 +677,7 @@ def sdpa_attention_forward( query, key, value, - dropout_p=self.attention_dropout if self.training else 0.0, + dropout_p=module.attention_dropout if module.training else 0.0, attn_mask=attention_mask, ).transpose(1, 2) attn_output = attn_output.view(bs, seqlen, dim) From b444c15e3936036a7a1100b44fedb61f0c923261 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 13:37:07 +0100 Subject: [PATCH 31/88] Remove if config.tie_word_embeddings --- 
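Note: with the tie_word_embeddings conditional removed, tying is left to the standard post_init() / tie_weights() path, which shares the decoder weight with the token embeddings when config.tie_word_embeddings is True. A rough sketch of what that tying amounts to, illustrative only (arbitrary sizes, not the transformers implementation):

    import torch.nn as nn

    vocab_size, hidden_size = 1000, 64  # arbitrary example sizes
    tok_embeddings = nn.Embedding(vocab_size, hidden_size)
    decoder = nn.Linear(hidden_size, vocab_size, bias=False)
    decoder.weight = tok_embeddings.weight  # tying: both modules share one Parameter
    assert decoder.weight is tok_embeddings.weight

When tie_word_embeddings is False, the decoder simply keeps the weight created in __init__ and initialized by _init_weights, so the removed branch is no longer needed.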
src/transformers/models/modernbert/modeling_modernbert.py | 7 ++----- src/transformers/models/modernbert/modular_modernbert.py | 7 ++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 58edc751fdf1a5..762db6cbc77617 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -1118,7 +1118,7 @@ def forward( class ModernBertForMaskedLM(ModernBertPreTrainedModel): - _tied_weights_keys = ["decoder.weight"] + _tied_weights_keys = ["model.embeddings.tok_embeddings.weight"] def __init__(self, config: ModernBertConfig): super().__init__(config) @@ -1126,10 +1126,7 @@ def __init__(self, config: ModernBertConfig): self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) - if config.tie_word_embeddings: - decoder_weights = self.model.embeddings.tok_embeddings.weight - else: - decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight + decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) self.decoder.weight = decoder_weights diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 9b7bf49f0070be..1818ae8d65d326 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1229,7 +1229,7 @@ def forward( class ModernBertForMaskedLM(ModernBertPreTrainedModel): - _tied_weights_keys = ["decoder.weight"] + _tied_weights_keys = ["model.embeddings.tok_embeddings.weight"] def __init__(self, config: ModernBertConfig): super().__init__(config) @@ -1237,10 +1237,7 @@ def __init__(self, config: ModernBertConfig): self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) - if config.tie_word_embeddings: - decoder_weights = self.model.embeddings.tok_embeddings.weight - else: - decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight + decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) self.decoder.weight = decoder_weights From 5aaf273e5ea9eadff4b2b19efb4435874eabe270 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 13:40:52 +0100 Subject: [PATCH 32/88] Reformat output projection on a different line --- src/transformers/models/modernbert/modeling_modernbert.py | 3 ++- src/transformers/models/modernbert/modular_modernbert.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 762db6cbc77617..d6fe2a131dc79a 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -649,8 +649,9 @@ def forward( **attn_kwargs, ) hidden_states = attn_outputs[0] + hidden_states = self.out_drop(self.Wo(hidden_states)) - return (self.out_drop(self.Wo(hidden_states)),) + attn_outputs[1:] # add attentions if outputted + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted class ModernBertEncoderLayer(nn.Module): diff --git 
a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 1818ae8d65d326..13911deb187b06 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -854,8 +854,9 @@ def forward( **attn_kwargs, ) hidden_states = attn_outputs[0] + hidden_states = self.out_drop(self.Wo(hidden_states)) - return (self.out_drop(self.Wo(hidden_states)),) + attn_outputs[1:] # add attentions if outputted + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted class ModernBertEncoderLayer(nn.Module): From 0a8d04458c7c8b110ae10d3500d7dc73cfc4bcb7 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 13:45:24 +0100 Subject: [PATCH 33/88] Remove pruning --- .../models/modernbert/modeling_modernbert.py | 23 ------------------- .../models/modernbert/modular_modernbert.py | 23 ------------------- .../modernbert/test_modeling_modernbert.py | 1 + 3 files changed, 1 insertion(+), 46 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index d6fe2a131dc79a..d2d9066c53ce3c 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -33,7 +33,6 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, prune_qkv_linear_layer from ...utils import is_flash_attn_2_available, is_torch_greater_or_equal, logging from .configuration_modernbert import ModernBertConfig @@ -539,20 +538,6 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() self.pruned_heads = set() - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) - - # Prune linear layers - self.Wqkv = prune_qkv_linear_layer(self.Wqkv, index) - self.Wo = prune_linear_layer(self.Wo, index, dim=1) - - # Update hyper params and store pruned heads - self.num_heads = self.num_heads - len(heads) - self.all_head_size = self.head_dim * self.num_heads - self.pruned_heads = self.pruned_heads.union(heads) - def forward( self, hidden_states: torch.Tensor, @@ -1001,14 +986,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.layers[layer].attn.prune_heads(heads) - def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 13911deb187b06..6f9b31df1a718e 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -34,7 +34,6 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, prune_qkv_linear_layer from ...utils import ( is_flash_attn_2_available, is_torch_greater_or_equal, @@ -744,20 +743,6 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() self.pruned_heads = set() - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) - - # Prune linear layers - self.Wqkv = prune_qkv_linear_layer(self.Wqkv, index) - self.Wo = prune_linear_layer(self.Wo, index, dim=1) - - # Update hyper params and store pruned heads - self.num_heads = self.num_heads - len(heads) - self.all_head_size = self.head_dim * self.num_heads - self.pruned_heads = self.pruned_heads.union(heads) - def forward( self, hidden_states: torch.Tensor, @@ -1112,14 +1097,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.layers[layer].attn.prune_heads(heads) - def forward( self, input_ids: torch.LongTensor = None, diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index bd485a2078a987..2a2067e337a0f9 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -214,6 +214,7 @@ class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste ) fx_compatible = False test_head_masking = False + test_pruning = False model_split_percents = [0.5, 0.8, 0.9] # special case for ForPreTraining model From 382e481dbb3df8ece138d81f2a6c8f91065210de Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 14:05:09 +0100 Subject: [PATCH 34/88] Remove assert --- src/transformers/models/modernbert/modeling_modernbert.py | 1 - src/transformers/models/modernbert/modular_modernbert.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index d2d9066c53ce3c..1765cd8593a390 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -68,7 +68,6 @@ def forward( ): # (total_nnz, 3, nheads, headdim) total_nnz, three, nheads, headdim = qkv.shape - assert three == 3 if qkv.is_contiguous(): # Call 1 kernel instead of 2 kernels # We need qkv to be contiguous so that when we reshape to combine (3, nheads) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 6f9b31df1a718e..b8c6f43f66f5f7 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -331,7 +331,6 @@ def forward( ): # (total_nnz, 3, nheads, headdim) total_nnz, three, nheads, headdim = qkv.shape - assert three == 3 if qkv.is_contiguous(): # Call 1 kernel instead of 2 kernels # We need qkv to be contiguous so that when we reshape to combine (3, nheads) From 5d05e8ebe6e9e54856440850769a89cc665acf8d Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 14:06:49 +0100 Subject: [PATCH 35/88] Call contiguous() to simplify paths --- .../models/modernbert/modeling_modernbert.py | 111 +++++------------- .../models/modernbert/modular_modernbert.py | 111 +++++------------- 2 files changed, 64 insertions(+), 158 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 1765cd8593a390..50d2f581141996 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -67,45 +67,22 @@ def forward( max_seqlen: Optional[int] = None, ): # (total_nnz, 3, nheads, headdim) - total_nnz, three, nheads, headdim = qkv.shape - if qkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need qkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") - qk = qkv[:, :2].view(total_nnz, -1, headdim) - apply_rotary( - qk, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=False, - inplace=True, - ) - else: - q, 
k = qkv[:, 0, :, :], qkv[:, 1, :, :] - apply_rotary( - q, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=False, - inplace=True, - ) - apply_rotary( - k, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=False, - inplace=True, - ) + qkv = qkv.contiguous() + total_nnz, _three, _nheads, headdim = qkv.shape + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) ctx.save_for_backward(cos, sin, cu_seqlens) ctx.max_seqlen = max_seqlen @@ -114,46 +91,22 @@ def forward( @staticmethod def backward(ctx, do): cos, sin, cu_seqlens = ctx.saved_tensors - if do.is_contiguous(): - total_nnz, three, nheads, headdim = do.shape - # Call 1 kernel instead of 2 kernels - # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, we get the same tensor - dqk = do[:, :2].view(total_nnz, -1, headdim) - apply_rotary( - dqk, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=False, - inplace=True, - conjugate=True, - ) - else: - dq, dk = do[:, 0, :, :], do[:, 1, :, :] - apply_rotary( - dq, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=False, - inplace=True, - conjugate=True, - ) - apply_rotary( - dk, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=False, - inplace=True, - conjugate=True, - ) + do = do.contiguous() + total_nnz, _three, _nheads, headdim = do.shape + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) return do, None, None, None, None, None, None diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index b8c6f43f66f5f7..99dfd06d97976b 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -330,45 +330,22 @@ def forward( max_seqlen: Optional[int] = None, ): # (total_nnz, 3, nheads, headdim) - total_nnz, three, nheads, headdim = qkv.shape - if qkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need qkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") - qk = qkv[:, :2].view(total_nnz, -1, headdim) - apply_rotary( - qk, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=False, - inplace=True, - ) - else: - q, k = qkv[:, 0, :, :], qkv[:, 1, :, :] - apply_rotary( - q, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=False, - inplace=True, - ) - apply_rotary( - k, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=False, - inplace=True, - ) + qkv = qkv.contiguous() + total_nnz, _three, _nheads, headdim = qkv.shape 
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) ctx.save_for_backward(cos, sin, cu_seqlens) ctx.max_seqlen = max_seqlen @@ -377,46 +354,22 @@ def forward( @staticmethod def backward(ctx, do): cos, sin, cu_seqlens = ctx.saved_tensors - if do.is_contiguous(): - total_nnz, three, nheads, headdim = do.shape - # Call 1 kernel instead of 2 kernels - # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, we get the same tensor - dqk = do[:, :2].view(total_nnz, -1, headdim) - apply_rotary( - dqk, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=False, - inplace=True, - conjugate=True, - ) - else: - dq, dk = do[:, 0, :, :], do[:, 1, :, :] - apply_rotary( - dq, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=False, - inplace=True, - conjugate=True, - ) - apply_rotary( - dk, - cos, - sin, - seqlen_offsets=0, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=False, - inplace=True, - conjugate=True, - ) + do = do.contiguous() + total_nnz, _three, _nheads, headdim = do.shape + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) return do, None, None, None, None, None, None From 98508c7afdaa626f60273bc6678f4855bc16c5f6 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 14:07:07 +0100 Subject: [PATCH 36/88] Remove prune_qkv_linear_layer --- src/transformers/pytorch_utils.py | 42 ------------------------------- 1 file changed, 42 deletions(-) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index b2461fc1e73dba..5bdf8a355ddfaa 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -92,48 +92,6 @@ def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) return new_layer -def prune_qkv_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear: - """ - Prune a QKV linear layer to keep only entries in index. - - Used to remove heads. - - Args: - layer (`torch.nn.Linear`): The layer to prune. - index (`torch.LongTensor`): The indices to keep in the layer. - dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices. - - Returns: - `torch.nn.Linear`: The pruned QKV layer as a new layer with `requires_grad=True`. 
- """ - assert layer.out_features % 3 == 0, "The output features of the linear layer should be divisible by 3" - index = torch.cat( - [ - index, - index + layer.out_features // 3, - index + 2 * layer.out_features // 3, - ] - ) - index = index.to(layer.weight.device) - W = layer.weight.index_select(dim, index).clone().detach() - if layer.bias is not None: - if dim == 1: - b = layer.bias.clone().detach() - else: - b = layer.bias[index].clone().detach() - new_size = list(layer.weight.size()) - new_size[dim] = len(index) - new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) - new_layer.weight.requires_grad = False - new_layer.weight.copy_(W.contiguous()) - new_layer.weight.requires_grad = True - if layer.bias is not None: - new_layer.bias.requires_grad = False - new_layer.bias.copy_(b.contiguous()) - new_layer.bias.requires_grad = True - return new_layer - - class Conv1D(nn.Module): """ 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). From 2c076c8ea4b3637b690f301845c47260106bf364 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 14:08:01 +0100 Subject: [PATCH 37/88] Format code --- src/transformers/models/modernbert/modular_modernbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 99dfd06d97976b..9e43323b9d78ec 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -791,7 +791,7 @@ def forward( **attn_kwargs, ) hidden_states = attn_outputs[0] - hidden_states = self.out_drop(self.Wo(hidden_states)) + hidden_states = self.out_drop(self.Wo(hidden_states)) return (hidden_states,) + attn_outputs[1:] # add attentions if outputted From 986c6feb9990b4e49c3ed810005e0855dd1c8b4f Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 16:28:24 +0100 Subject: [PATCH 38/88] Keep as kwargs, only use if needed --- .../models/modernbert/modeling_modernbert.py | 60 ++++--------------- .../models/modernbert/modular_modernbert.py | 60 ++++--------------- 2 files changed, 20 insertions(+), 100 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 50d2f581141996..d823d491b0df51 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -295,7 +295,6 @@ def eager_attention_forward( position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, bs: int, - seqlen: int, dim: int, output_attentions: Optional[bool] = False, **_kwargs, @@ -319,7 +318,7 @@ def eager_attention_forward( attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bs, seqlen, dim) + attn_output = attn_output.view(bs, -1, dim) if output_attentions: return (attn_output, attn_weights) return (attn_output,) @@ -406,7 +405,6 @@ def sdpa_attention_forward( position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, bs: int, - seqlen: int, dim: int, **_kwargs, ) -> Tuple[torch.Tensor]: @@ -426,7 +424,7 @@ def sdpa_attention_forward( dropout_p=module.attention_dropout if module.training else 0.0, attn_mask=attention_mask, ).transpose(1, 2) - attn_output = 
attn_output.view(bs, seqlen, dim) + attn_output = attn_output.view(bs, -1, dim) return (attn_output,) @@ -493,11 +491,6 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - block_mask: Optional[BlockMask] = None, - max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, **kwargs, ) -> torch.Tensor: @@ -522,48 +515,11 @@ def forward( """ qkv = self.Wqkv(hidden_states) - attn_kwargs = { - "output_attentions": output_attentions, - "dim": self.all_head_size, - } - if self.config._attn_implementation == "flash_attention_2": - bs = hidden_states.shape[0] - qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) - - attn_kwargs.update( - { - "local_attention": self.local_attention, - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "bs": bs, - } - ) - elif self.config._attn_implementation == "flex_attention": - bs, dim = hidden_states.shape[:2] + bs = hidden_states.shape[0] + if self.config._attn_implementation in ("flash_attention_2", "flex_attention"): qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) - - attn_kwargs.update( - { - "local_attention": self.local_attention, - "block_mask": block_mask, - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "bs": bs, - } - ) - else: - bs, seqlen = hidden_states.shape[:2] - qkv = qkv.view(bs, seqlen, 3, self.num_heads, self.head_dim) - - attn_kwargs.update( - { - "position_ids": position_ids, - "attention_mask": attention_mask, - "bs": bs, - "seqlen": seqlen, - } - ) + qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) if output_attentions: if self.config._attn_implementation == "sdpa": @@ -583,7 +539,11 @@ def forward( self, qkv=qkv, rotary_emb=self.rotary_emb, - **attn_kwargs, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + output_attentions=output_attentions, + **kwargs, ) hidden_states = attn_outputs[0] hidden_states = self.out_drop(self.Wo(hidden_states)) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 9e43323b9d78ec..9c5b14956e035e 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -500,7 +500,6 @@ def eager_attention_forward( position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, bs: int, - seqlen: int, dim: int, output_attentions: Optional[bool] = False, **_kwargs, @@ -524,7 +523,7 @@ def eager_attention_forward( attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bs, seqlen, dim) + attn_output = attn_output.view(bs, -1, dim) if output_attentions: return (attn_output, attn_weights) return (attn_output,) @@ -611,7 +610,6 @@ def sdpa_attention_forward( position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, bs: int, - seqlen: int, dim: int, **_kwargs, ) -> Tuple[torch.Tensor]: @@ -631,7 +629,7 @@ def sdpa_attention_forward( dropout_p=module.attention_dropout if module.training else 0.0, attn_mask=attention_mask, ).transpose(1, 2) - attn_output = attn_output.view(bs, seqlen, dim) + attn_output = attn_output.view(bs, -1, dim) return (attn_output,) @@ -698,11 +696,6 @@ def __init__(self, config: 
ModernBertConfig, layer_id: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - block_mask: Optional[BlockMask] = None, - max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, **kwargs, ) -> torch.Tensor: @@ -727,48 +720,11 @@ def forward( """ qkv = self.Wqkv(hidden_states) - attn_kwargs = { - "output_attentions": output_attentions, - "dim": self.all_head_size, - } - if self.config._attn_implementation == "flash_attention_2": - bs = hidden_states.shape[0] + bs = hidden_states.shape[0] + if self.config._attn_implementation in ("flash_attention_2", "flex_attention"): qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) - - attn_kwargs.update( - { - "local_attention": self.local_attention, - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "bs": bs, - } - ) - elif self.config._attn_implementation == "flex_attention": - bs, dim = hidden_states.shape[:2] - qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) - - attn_kwargs.update( - { - "local_attention": self.local_attention, - "block_mask": block_mask, - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "bs": bs, - } - ) - else: - bs, seqlen = hidden_states.shape[:2] - qkv = qkv.view(bs, seqlen, 3, self.num_heads, self.head_dim) - - attn_kwargs.update( - { - "position_ids": position_ids, - "attention_mask": attention_mask, - "bs": bs, - "seqlen": seqlen, - } - ) + qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) if output_attentions: if self.config._attn_implementation == "sdpa": @@ -788,7 +744,11 @@ def forward( self, qkv=qkv, rotary_emb=self.rotary_emb, - **attn_kwargs, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + output_attentions=output_attentions, + **kwargs, ) hidden_states = attn_outputs[0] hidden_states = self.out_drop(self.Wo(hidden_states)) From 5cd39ad940cdaa4875ca76f73cdb60b0c418ed4b Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 16:51:39 +0100 Subject: [PATCH 39/88] Remove unused codepaths & related config options --- .../modernbert/configuration_modernbert.py | 6 ----- .../models/modernbert/modeling_modernbert.py | 18 ++++---------- .../models/modernbert/modular_modernbert.py | 24 ++++--------------- 3 files changed, 10 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 9243967e08d299..094c8180f928d8 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -120,8 +120,6 @@ def __init__( global_attn_every_n_layers=3, local_attention=128, local_rope_theta=10000.0, - skip_first_prenorm=True, - embedding_norm=True, embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, @@ -130,7 +128,6 @@ def __init__( decoder_bias=True, classifier_dropout=0.0, classifier_pooling="mean", - classifier_norm=True, classifier_bias=False, classifier_activation="gelu", deterministic_flash_attn=False, @@ -164,8 +161,6 @@ def __init__( self.global_attn_every_n_layers = global_attn_every_n_layers self.local_attention = local_attention self.local_rope_theta = local_rope_theta - self.skip_first_prenorm = skip_first_prenorm - self.embedding_norm = embedding_norm self.embedding_dropout = embedding_dropout self.mlp_bias = mlp_bias self.mlp_dropout = mlp_dropout @@ -175,7 +170,6 @@ def 
__init__( self.classifier_dropout = classifier_dropout self.classifier_pooling = classifier_pooling self.classifier_bias = classifier_bias - self.classifier_norm = classifier_norm self.classifier_activation = classifier_activation self.deterministic_flash_attn = deterministic_flash_attn self.sparse_prediction = sparse_prediction diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index d823d491b0df51..71bf535e74933e 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -555,7 +555,7 @@ class ModernBertEncoderLayer(nn.Module): def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config - if config.skip_first_prenorm and config.embedding_norm and layer_id == 0: + if layer_id == 0: self.attn_norm = nn.Identity() else: self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) @@ -606,12 +606,8 @@ def __init__(self, config: ModernBertConfig): super().__init__() self.config = config self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) - self.act = ACT2FN[config.classifier_activation] if config.classifier_activation else nn.Identity() - self.norm = ( - nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) - if config.classifier_norm - else nn.Identity() - ) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(self.act(self.dense(hidden_states))) @@ -622,12 +618,8 @@ def __init__(self, config: ModernBertConfig): super().__init__() self.config = config self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) - self.act = ACT2FN[config.classifier_activation] if config.classifier_activation else nn.Identity() - self.norm = ( - nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) - if config.classifier_norm - else nn.Identity() - ) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.drop = torch.nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity() self.pooling_type = ModernBertPoolingType(config.classifier_pooling) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 9c5b14956e035e..0840e33917c5d9 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -154,8 +154,6 @@ def __init__( global_attn_every_n_layers=3, local_attention=128, local_rope_theta=10000.0, - skip_first_prenorm=True, - embedding_norm=True, embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, @@ -164,7 +162,6 @@ def __init__( decoder_bias=True, classifier_dropout=0.0, classifier_pooling="mean", - classifier_norm=True, classifier_bias=False, classifier_activation="gelu", deterministic_flash_attn=False, @@ -198,8 +195,6 @@ def __init__( self.global_attn_every_n_layers = global_attn_every_n_layers self.local_attention = local_attention self.local_rope_theta = local_rope_theta - self.skip_first_prenorm = skip_first_prenorm - self.embedding_norm = embedding_norm self.embedding_dropout = embedding_dropout self.mlp_bias = 
mlp_bias self.mlp_dropout = mlp_dropout @@ -209,7 +204,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.classifier_pooling = classifier_pooling self.classifier_bias = classifier_bias - self.classifier_norm = classifier_norm self.classifier_activation = classifier_activation self.deterministic_flash_attn = deterministic_flash_attn self.sparse_prediction = sparse_prediction @@ -760,7 +754,7 @@ class ModernBertEncoderLayer(nn.Module): def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config - if config.skip_first_prenorm and config.embedding_norm and layer_id == 0: + if layer_id == 0: self.attn_norm = nn.Identity() else: self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) @@ -811,12 +805,8 @@ def __init__(self, config: ModernBertConfig): super().__init__() self.config = config self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) - self.act = ACT2FN[config.classifier_activation] if config.classifier_activation else nn.Identity() - self.norm = ( - nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) - if config.classifier_norm - else nn.Identity() - ) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(self.act(self.dense(hidden_states))) @@ -827,12 +817,8 @@ def __init__(self, config: ModernBertConfig): super().__init__() self.config = config self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) - self.act = ACT2FN[config.classifier_activation] if config.classifier_activation else nn.Identity() - self.norm = ( - nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) - if config.classifier_norm - else nn.Identity() - ) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.drop = torch.nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity() self.pooling_type = ModernBertPoolingType(config.classifier_pooling) From 2d606b9f49e022c373e739a934ec899814d4598f Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 17:34:38 +0100 Subject: [PATCH 40/88] Remove 3d attn_mask test; fix token classification tuple output --- src/transformers/models/modernbert/modeling_modernbert.py | 2 +- src/transformers/models/modernbert/modular_modernbert.py | 2 +- tests/models/modernbert/test_modeling_modernbert.py | 8 -------- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 71bf535e74933e..9be56e23191bcb 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -1253,7 +1253,7 @@ def forward( loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: - output = (logits,) + outputs[2:] + output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 0840e33917c5d9..ec177b93cc8fa6 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ 
b/src/transformers/models/modernbert/modular_modernbert.py @@ -1358,7 +1358,7 @@ def forward( loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: - output = (logits,) + outputs[2:] + output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index 2a2067e337a0f9..7d44bca09721b3 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -251,14 +251,6 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_3d_mask_shapes(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - # manipulate input_mask - config_and_inputs = list(config_and_inputs) - batch_size, seq_length = config_and_inputs[3].shape - config_and_inputs[3] = random_attention_mask([batch_size, seq_length, seq_length]) - self.model_tester.create_and_check_model(*config_and_inputs) - def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 8eb87e88780538988736dee19be2a44b5eb30a10 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 18:03:51 +0100 Subject: [PATCH 41/88] Reorder: attention_mask above position_ids, fixes gradient checkpointing --- src/transformers/models/modernbert/modeling_modernbert.py | 8 ++++---- src/transformers/models/modernbert/modular_modernbert.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 9be56e23191bcb..ffe31ccc295b3a 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -292,8 +292,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): def eager_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, - position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], bs: int, dim: int, output_attentions: Optional[bool] = False, @@ -402,8 +402,8 @@ def flex_attention_forward( def sdpa_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, - position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], bs: int, dim: int, **_kwargs, @@ -570,8 +570,8 @@ def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, block_mask: Optional[BlockMask] = None, max_seqlen: Optional[int] = None, @@ -581,8 +581,8 @@ def forward( Args: hidden_states: (total_nnz, dim) - position_ids: (total_nnz,) attention_mask: (batch, max_seqlen) + position_ids: (total_nnz,) cu_seqlens: (batch + 1,) max_seqlen: int """ diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index ec177b93cc8fa6..67708673700802 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ 
-491,8 +491,8 @@ class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): def eager_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, - position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], bs: int, dim: int, output_attentions: Optional[bool] = False, @@ -601,8 +601,8 @@ def flex_attention_forward( def sdpa_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, - position_ids: Optional[torch.LongTensor], attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], bs: int, dim: int, **_kwargs, @@ -769,8 +769,8 @@ def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, block_mask: Optional[BlockMask] = None, max_seqlen: Optional[int] = None, @@ -780,8 +780,8 @@ def forward( Args: hidden_states: (total_nnz, dim) - position_ids: (total_nnz,) attention_mask: (batch, max_seqlen) + position_ids: (total_nnz,) cu_seqlens: (batch + 1,) max_seqlen: int """ From 3a24af438d0c4ae6bc8a4443a93f72f5428144c0 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 16 Dec 2024 18:30:45 +0100 Subject: [PATCH 42/88] Fix usage if no FA2 or torch v2.5+ --- src/transformers/models/modernbert/modeling_modernbert.py | 6 +++--- src/transformers/models/modernbert/modular_modernbert.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index ffe31ccc295b3a..d0e8fb4f6c59dd 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -42,7 +42,7 @@ from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.ops.triton.rotary import apply_rotary else: - RotaryEmbedding = None + RotaryEmbedding = object if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention @@ -372,7 +372,7 @@ def flex_attention_forward( qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, - block_mask: BlockMask, + block_mask: "BlockMask", max_seqlen: int, bs: int, dim: int, @@ -573,7 +573,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, - block_mask: Optional[BlockMask] = None, + block_mask: Optional["BlockMask"] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 67708673700802..47ebf8d6c56012 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -47,7 +47,7 @@ from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.ops.triton.rotary import apply_rotary else: - RotaryEmbedding = None + RotaryEmbedding = object if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention @@ -571,7 +571,7 @@ def flex_attention_forward( qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, - block_mask: 
BlockMask, + block_mask: "BlockMask", max_seqlen: int, bs: int, dim: int, @@ -772,7 +772,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, - block_mask: Optional[BlockMask] = None, + block_mask: Optional["BlockMask"] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: From 37a60302048150d12fbf4cdc3691317c3a40ce21 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 09:56:36 +0100 Subject: [PATCH 43/88] Make torch.compile/triton optional Should we rename 'compile'? It's a bit vague --- .../modernbert/configuration_modernbert.py | 6 +++++ .../models/modernbert/modeling_modernbert.py | 21 ++++++++++++--- .../models/modernbert/modular_modernbert.py | 27 ++++++++++++++++--- src/transformers/utils/import_utils.py | 6 ++++- 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 094c8180f928d8..9e6ec29baffc4c 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -21,6 +21,7 @@ # limitations under the License. from ...configuration_utils import PretrainedConfig +from ...utils.import_utils import is_triton_available class ModernBertConfig(PretrainedConfig): @@ -133,6 +134,7 @@ def __init__( deterministic_flash_attn=False, sparse_prediction=False, sparse_pred_ignore_index=-100, + compile=None, **kwargs, ): super().__init__( @@ -174,6 +176,10 @@ def __init__( self.deterministic_flash_attn = deterministic_flash_attn self.sparse_prediction = sparse_prediction self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.compile = compile + + if self.compile is None: + self.compile = is_triton_available() if unpad_inputs is None: self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index d0e8fb4f6c59dd..10070d921cb747 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -203,9 +203,17 @@ def __init__(self, config: ModernBertConfig): self.drop = nn.Dropout(config.embedding_dropout) if config.embedding_dropout > 0.0 else nn.Identity() @torch.compile(dynamic=True) - def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) + def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.compile + else self.drop(self.norm(self.tok_embeddings(input_ids))) + ) + return hidden_states + class ModernBertMLP(nn.Module): """Applies the GLU at the end of each ModernBERT layer. 
@@ -596,7 +604,10 @@ def forward( output_attentions=output_attentions, ) hidden_states = hidden_states + attn_outputs[0] - hidden_states = hidden_states + self.compiled_mlp(hidden_states) + mlp_output = ( + self.compiled_mlp(hidden_states) if self.config.compile else self.mlp(self.mlp_norm(hidden_states)) + ) + hidden_states = hidden_states + mlp_output return (hidden_states,) + attn_outputs[1:] # add attentions if outputted @@ -1083,7 +1094,11 @@ def forward( last_hidden_state = last_hidden_state[mask_tokens] labels = labels[mask_tokens] - logits = self.compiled_head(last_hidden_state) + logits = ( + self.compiled_head(last_hidden_state) + if self.config.compile + else self.decoder(self.head(last_hidden_state)) + ) loss = None if labels is not None: diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 47ebf8d6c56012..30ad0f558dab62 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -39,6 +39,7 @@ is_torch_greater_or_equal, logging, ) +from ...utils.import_utils import is_triton_available from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb @@ -167,6 +168,7 @@ def __init__( deterministic_flash_attn=False, sparse_prediction=False, sparse_pred_ignore_index=-100, + compile=None, **kwargs, ): super().__init__( @@ -208,6 +210,10 @@ def __init__( self.deterministic_flash_attn = deterministic_flash_attn self.sparse_prediction = sparse_prediction self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.compile = compile + + if self.compile is None: + self.compile = is_triton_available() if unpad_inputs is None: self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} @@ -460,9 +466,17 @@ def __init__(self, config: ModernBertConfig): self.drop = nn.Dropout(config.embedding_dropout) if config.embedding_dropout > 0.0 else nn.Identity() @torch.compile(dynamic=True) - def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) + def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.compile + else self.drop(self.norm(self.tok_embeddings(input_ids))) + ) + return hidden_states + class ModernBertMLP(nn.Module): """Applies the GLU at the end of each ModernBERT layer. 
@@ -795,7 +809,10 @@ def forward( output_attentions=output_attentions, ) hidden_states = hidden_states + attn_outputs[0] - hidden_states = hidden_states + self.compiled_mlp(hidden_states) + mlp_output = ( + self.compiled_mlp(hidden_states) if self.config.compile else self.mlp(self.mlp_norm(hidden_states)) + ) + hidden_states = hidden_states + mlp_output return (hidden_states,) + attn_outputs[1:] # add attentions if outputted @@ -1188,7 +1205,11 @@ def forward( last_hidden_state = last_hidden_state[mask_tokens] labels = labels[mask_tokens] - logits = self.compiled_head(last_hidden_state) + logits = ( + self.compiled_head(last_hidden_state) + if self.config.compile + else self.decoder(self.head(last_hidden_state)) + ) loss = None if labels is not None: diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 32a647594741dd..92823a4ee016c3 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -192,7 +192,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _tiktoken_available = _is_package_available("tiktoken") _blobfile_available = _is_package_available("blobfile") _liger_kernel_available = _is_package_available("liger_kernel") - +_triton_available = _is_package_available("triton") _torch_version = "N/A" _torch_available = False @@ -1243,6 +1243,10 @@ def is_liger_kernel_available(): return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0") +def is_triton_available(): + return _triton_available + + # docstyle-ignore AV_IMPORT_ERROR = """ {0} requires the PyAv library but it was not found in your environment. You can install it with: From b3b4028e826d14b623bc6c35f3541ab51a67b234 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 10:54:44 +0100 Subject: [PATCH 44/88] Separate pooling options into separate functions (cls, mean) - cls as default --- .../modernbert/configuration_modernbert.py | 9 +++- .../models/modernbert/modeling_modernbert.py | 44 ++++++++------- .../models/modernbert/modular_modernbert.py | 53 ++++++++++--------- 3 files changed, 57 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 9e6ec29baffc4c..4c0f9dbbdb9de5 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -20,6 +20,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Literal + from ...configuration_utils import PretrainedConfig from ...utils.import_utils import is_triton_available @@ -128,7 +130,7 @@ def __init__( unpad_no_grad=True, decoder_bias=True, classifier_dropout=0.0, - classifier_pooling="mean", + classifier_pooling: Literal["cls", "mean"] = "cls", classifier_bias=False, classifier_activation="gelu", deterministic_flash_attn=False, @@ -178,6 +180,11 @@ def __init__( self.sparse_pred_ignore_index = sparse_pred_ignore_index self.compile = compile + if self.classifier_pooling not in ["cls", "mean"]: + raise ValueError( + f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' 
+ ) + if self.compile is None: self.compile = is_triton_available() diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 10070d921cb747..c0c24a7389a642 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -21,7 +21,6 @@ # limitations under the License. import math -from enum import Enum from typing import Optional, Tuple, Union import torch @@ -50,12 +49,6 @@ logger = logging.get_logger(__name__) -class ModernBertPoolingType(str, Enum): - cls = "cls" - mean = "mean" - max = "max" - - class ApplyRotaryEmbUnpad(torch.autograd.Function): @staticmethod def forward( @@ -624,6 +617,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(self.act(self.dense(hidden_states))) +def cls_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + return hidden_states[:, 0] + + +def mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + return (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True) + + +MODERNBERT_POOLING_FUNCTION = { + "cls": cls_pooling, + "mean": mean_pooling, +} + + class ModernBertPoolingHead(nn.Module): def __init__(self, config: ModernBertConfig): super().__init__() @@ -632,24 +639,15 @@ def __init__(self, config: ModernBertConfig): self.act = ACT2FN[config.classifier_activation] self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.drop = torch.nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity() - self.pooling_type = ModernBertPoolingType(config.classifier_pooling) + self.pooling = MODERNBERT_POOLING_FUNCTION[config.classifier_pooling] - def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool: Optional[bool] = True + ) -> torch.Tensor: if pool: - if self.pooling_type == ModernBertPoolingType.cls: - output = hidden_states[:, 0] - elif self.pooling_type == ModernBertPoolingType.mean: - output = hidden_states.mean(dim=1) - elif self.pooling_type == ModernBertPoolingType.max: - output = hidden_states.max(dim=1)[0] - else: - output = hidden_states - - return self.drop(self.norm(self.act(self.dense(output)))) - + hidden_states = self.pooling(hidden_states, attention_mask) -# Copyright 2023 OLMo Authors -# License: Apache-2.0 + return self.drop(self.norm(self.act(self.dense(hidden_states)))) def _unpad_modernbert_input( @@ -1173,7 +1171,7 @@ def forward( ) last_hidden_state = outputs[0] - pooled_output = self.head(last_hidden_state) + pooled_output = self.head(last_hidden_state, attention_mask) logits = self.classifier(pooled_output) loss = None diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 30ad0f558dab62..5506ba26f123d3 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -15,8 +15,7 @@ # limitations under the License. 
import math -from enum import Enum -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -162,7 +161,7 @@ def __init__( unpad_no_grad=True, decoder_bias=True, classifier_dropout=0.0, - classifier_pooling="mean", + classifier_pooling: Literal["cls", "mean"] = "cls", classifier_bias=False, classifier_activation="gelu", deterministic_flash_attn=False, @@ -212,6 +211,11 @@ def __init__( self.sparse_pred_ignore_index = sparse_pred_ignore_index self.compile = compile + if self.classifier_pooling not in ["cls", "mean"]: + raise ValueError( + f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' + ) + if self.compile is None: self.compile = is_triton_available() @@ -219,16 +223,6 @@ def __init__( self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} -class ModernBertPoolingType(str, Enum): - cls = "cls" - mean = "mean" - max = "max" - - -# Copyright 2023 OLMo Authors -# License: Apache-2.0 - - def _unpad_modernbert_input( inputs: torch.Tensor, attention_mask: torch.Tensor, @@ -829,6 +823,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(self.act(self.dense(hidden_states))) +def cls_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + return hidden_states[:, 0] + + +def mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + return (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True) + + +MODERNBERT_POOLING_FUNCTION = { + "cls": cls_pooling, + "mean": mean_pooling, +} + + class ModernBertPoolingHead(nn.Module): def __init__(self, config: ModernBertConfig): super().__init__() @@ -837,20 +845,15 @@ def __init__(self, config: ModernBertConfig): self.act = ACT2FN[config.classifier_activation] self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.drop = torch.nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity() - self.pooling_type = ModernBertPoolingType(config.classifier_pooling) + self.pooling = MODERNBERT_POOLING_FUNCTION[config.classifier_pooling] - def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool: Optional[bool] = True + ) -> torch.Tensor: if pool: - if self.pooling_type == ModernBertPoolingType.cls: - output = hidden_states[:, 0] - elif self.pooling_type == ModernBertPoolingType.mean: - output = hidden_states.mean(dim=1) - elif self.pooling_type == ModernBertPoolingType.max: - output = hidden_states.max(dim=1)[0] - else: - output = hidden_states + hidden_states = self.pooling(hidden_states, attention_mask) - return self.drop(self.norm(self.act(self.dense(output)))) + return self.drop(self.norm(self.act(self.dense(hidden_states)))) class ModernBertPreTrainedModel(PreTrainedModel): @@ -1284,7 +1287,7 @@ def forward( ) last_hidden_state = outputs[0] - pooled_output = self.head(last_hidden_state) + pooled_output = self.head(last_hidden_state, attention_mask) logits = self.classifier(pooled_output) loss = None From b241a7e97a1413be7efcbf0e1435dd1581a2dc0f Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 11:07:13 +0100 Subject: [PATCH 45/88] Simplify _pad_modernbert_output, remove unused labels path --- .../models/modernbert/modeling_modernbert.py | 38 
+++---------------- .../models/modernbert/modular_modernbert.py | 38 +++---------------- 2 files changed, 12 insertions(+), 64 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index c0c24a7389a642..781d11d7a78eea 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -696,9 +696,7 @@ def _pad_modernbert_output( indices: torch.Tensor, batch: int, seqlen: int, - labels: Optional[torch.Tensor] = None, - ignore_index: int = -100, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> torch.Tensor: """ Add padding to sequences. @@ -707,12 +705,9 @@ def _pad_modernbert_output( indices: (total_nnz) batch: int, batch size seqlen: int, max sequence length - position_ids: (total_nnz) or None - labels: (total_nnz) or None Returns: padded_inputs: (batch, seqlen, ...) or (batch, seqlen) - padded_labels: (batch, seqlen) or None """ if inputs.dim() == 1: output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) @@ -724,20 +719,7 @@ def _pad_modernbert_output( output[indices] = inputs padded_inputs = output.view(batch, seqlen, *rest) - padded_labels = None - if labels is not None: - padded_labels = torch.full( - (batch * seqlen,), fill_value=ignore_index, dtype=labels.dtype, device=labels.device - ) - padded_labels[indices] = labels - padded_labels = padded_labels.view(batch, seqlen) - - return padded_inputs, padded_labels - - # Copyright (c) 2023, Tri Dao. - # License: Apache-2.0 - - # if is_flash_attn_2_available(): + return padded_inputs class ModernBertPreTrainedModel(PreTrainedModel): @@ -821,16 +803,12 @@ def _pad_outputs_no_grad( indices: torch.Tensor, batch_size: int, seqlen: int, - labels: Optional[torch.Tensor] = None, - ignore_index: int = -100, ): return self._pad_outputs( inputs=inputs, indices=indices, batch_size=batch_size, seqlen=seqlen, - labels=labels, - ignore_index=ignore_index, ) def _pad_outputs( @@ -839,12 +817,8 @@ def _pad_outputs( indices: torch.Tensor, batch_size: int, seqlen: int, - labels: Optional[torch.Tensor] = None, - ignore_index: int = -100, ): - return _pad_modernbert_output( - inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index - ) + return _pad_modernbert_output(inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen) @classmethod def offsets_to_sequence_ids_tensor(cls, offsets): @@ -997,7 +971,7 @@ def forward( hidden_states = self.final_norm(hidden_states) if repad: - hidden_states, _ = self._pad_outputs(hidden_states, indices, batch_size, seq_len) + hidden_states = self._pad_outputs(hidden_states, indices, batch_size, seq_len) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) @@ -1104,9 +1078,9 @@ def forward( if self.config.unpad_inputs: if self.config.unpad_no_grad: - logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) + logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: - logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) + logits = self._pad_outputs(logits, indices, batch_size, seq_len) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 5506ba26f123d3..51cd3b2a4d858e 100644 --- 
a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -269,9 +269,7 @@ def _pad_modernbert_output( indices: torch.Tensor, batch: int, seqlen: int, - labels: Optional[torch.Tensor] = None, - ignore_index: int = -100, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> torch.Tensor: """ Add padding to sequences. @@ -280,12 +278,9 @@ def _pad_modernbert_output( indices: (total_nnz) batch: int, batch size seqlen: int, max sequence length - position_ids: (total_nnz) or None - labels: (total_nnz) or None Returns: padded_inputs: (batch, seqlen, ...) or (batch, seqlen) - padded_labels: (batch, seqlen) or None """ if inputs.dim() == 1: output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) @@ -297,20 +292,7 @@ def _pad_modernbert_output( output[indices] = inputs padded_inputs = output.view(batch, seqlen, *rest) - padded_labels = None - if labels is not None: - padded_labels = torch.full( - (batch * seqlen,), fill_value=ignore_index, dtype=labels.dtype, device=labels.device - ) - padded_labels[indices] = labels - padded_labels = padded_labels.view(batch, seqlen) - - return padded_inputs, padded_labels - - # Copyright (c) 2023, Tri Dao. - # License: Apache-2.0 - - # if is_flash_attn_2_available(): + return padded_inputs class ApplyRotaryEmbUnpad(torch.autograd.Function): @@ -937,16 +919,12 @@ def _pad_outputs_no_grad( indices: torch.Tensor, batch_size: int, seqlen: int, - labels: Optional[torch.Tensor] = None, - ignore_index: int = -100, ): return self._pad_outputs( inputs=inputs, indices=indices, batch_size=batch_size, seqlen=seqlen, - labels=labels, - ignore_index=ignore_index, ) def _pad_outputs( @@ -955,12 +933,8 @@ def _pad_outputs( indices: torch.Tensor, batch_size: int, seqlen: int, - labels: Optional[torch.Tensor] = None, - ignore_index: int = -100, ): - return _pad_modernbert_output( - inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index - ) + return _pad_modernbert_output(inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen) @classmethod def offsets_to_sequence_ids_tensor(cls, offsets): @@ -1113,7 +1087,7 @@ def forward( hidden_states = self.final_norm(hidden_states) if repad: - hidden_states, _ = self._pad_outputs(hidden_states, indices, batch_size, seq_len) + hidden_states = self._pad_outputs(hidden_states, indices, batch_size, seq_len) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) @@ -1220,9 +1194,9 @@ def forward( if self.config.unpad_inputs: if self.config.unpad_no_grad: - logits, _ = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) + logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: - logits, _ = self._pad_outputs(logits, indices, batch_size, seq_len) + logits = self._pad_outputs(logits, indices, batch_size, seq_len) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output From 66f460340646bc55faf8092370115347b00e7617 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 12:12:22 +0100 Subject: [PATCH 46/88] Update tied weights to remove decoder.weight, simplify decoder loading --- src/transformers/models/modernbert/modeling_modernbert.py | 7 ++----- src/transformers/models/modernbert/modular_modernbert.py | 7 ++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py 
b/src/transformers/models/modernbert/modeling_modernbert.py index 781d11d7a78eea..771e26a8d3d390 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -983,17 +983,14 @@ def forward( class ModernBertForMaskedLM(ModernBertPreTrainedModel): - _tied_weights_keys = ["model.embeddings.tok_embeddings.weight"] + _tied_weights_keys = ["decoder.weight"] def __init__(self, config: ModernBertConfig): super().__init__(config) self.config = config self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) - - decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight - self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) - self.decoder.weight = decoder_weights + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) self.sparse_prediction = self.config.sparse_prediction self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 51cd3b2a4d858e..d0d5d00474e66b 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1099,17 +1099,14 @@ def forward( class ModernBertForMaskedLM(ModernBertPreTrainedModel): - _tied_weights_keys = ["model.embeddings.tok_embeddings.weight"] + _tied_weights_keys = ["decoder.weight"] def __init__(self, config: ModernBertConfig): super().__init__(config) self.config = config self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) - - decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight - self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) - self.decoder.weight = decoder_weights + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) self.sparse_prediction = self.config.sparse_prediction self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index From 3eb786b60f5f5752d65dbd3be77a389d0bd8d1ac Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 13:59:34 +0100 Subject: [PATCH 47/88] Adaptively set config.compile based on hf_device_map/device/resize, etc. --- .../modernbert/configuration_modernbert.py | 4 -- .../models/modernbert/modeling_modernbert.py | 40 ++++++++++++++++++ .../models/modernbert/modular_modernbert.py | 42 +++++++++++++++++-- .../modernbert/test_modeling_modernbert.py | 12 +++++- 4 files changed, 90 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 4c0f9dbbdb9de5..7279dc67dff745 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -23,7 +23,6 @@ from typing import Literal from ...configuration_utils import PretrainedConfig -from ...utils.import_utils import is_triton_available class ModernBertConfig(PretrainedConfig): @@ -185,8 +184,5 @@ def __init__( f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' 
) - if self.compile is None: - self.compile = is_triton_available() - if unpad_inputs is None: self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 771e26a8d3d390..93abd0b7ee1d7c 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -33,6 +33,7 @@ from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import is_flash_attn_2_available, is_torch_greater_or_equal, logging +from ...utils.import_utils import is_triton_available from .configuration_modernbert import ModernBertConfig @@ -773,6 +774,41 @@ def init_weight(module: nn.Module, std: float): elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): init_weight(module.classifier, stds["final_out"]) + def _maybe_set_compile(self): + if self.config.compile is False: + return + + if hasattr(self, "hf_device_map"): + if self.config.compile: + logger.warning_once( + "If `accelerate` split the model across devices, `torch.compile` will not work. " + "Falling back to non-compiled mode." + ) + self.config.compile = False + + if self.device.type == "mps": + if self.config.compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.compile = False + + if self.config.compile is None: + self.config.compile = is_triton_available() + + def resize_token_embeddings(self, *args, **kwargs): + model_embeds = super().resize_token_embeddings(*args, **kwargs) + + if self.config.compile in {True, None}: + if self.config.compile: + logger.warning_once( + "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." + ) + self.config.compile = False + + return model_embeds + @torch.no_grad() def _unpad_inputs_no_grad( self, @@ -896,6 +932,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None + self._maybe_set_compile() self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) if batch_size is None and seq_len is None: @@ -1025,6 +1062,7 @@ def forward( **kwargs, ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() if self.config.unpad_inputs: if indices is None and cu_seqlens is None and max_seqlen is None: @@ -1126,6 +1164,7 @@ def forward( `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() outputs = self.model( input_ids, @@ -1212,6 +1251,7 @@ def forward( Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() outputs = self.model( input_ids, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index d0d5d00474e66b..57baafb4a03010 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -216,9 +216,6 @@ def __init__( f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' ) - if self.compile is None: - self.compile = is_triton_available() - if unpad_inputs is None: self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} @@ -889,6 +886,41 @@ def init_weight(module: nn.Module, std: float): elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): init_weight(module.classifier, stds["final_out"]) + def _maybe_set_compile(self): + if self.config.compile is False: + return + + if hasattr(self, "hf_device_map"): + if self.config.compile: + logger.warning_once( + "If `accelerate` split the model across devices, `torch.compile` will not work. " + "Falling back to non-compiled mode." + ) + self.config.compile = False + + if self.device.type == "mps": + if self.config.compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.compile = False + + if self.config.compile is None: + self.config.compile = is_triton_available() + + def resize_token_embeddings(self, *args, **kwargs): + model_embeds = super().resize_token_embeddings(*args, **kwargs) + + if self.config.compile in {True, None}: + if self.config.compile: + logger.warning_once( + "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." + ) + self.config.compile = False + + return model_embeds + @torch.no_grad() def _unpad_inputs_no_grad( self, @@ -1012,6 +1044,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None + self._maybe_set_compile() self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) if batch_size is None and seq_len is None: @@ -1141,6 +1174,7 @@ def forward( **kwargs, ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() if self.config.unpad_inputs: if indices is None and cu_seqlens is None and max_seqlen is None: @@ -1242,6 +1276,7 @@ def forward( `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() outputs = self.model( input_ids, @@ -1328,6 +1363,7 @@ def forward( Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() outputs = self.model( input_ids, diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index 7d44bca09721b3..0f9e3ade73f207 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest from transformers import ModernBertConfig, is_torch_available @@ -118,7 +119,7 @@ def get_config(self): """ Returns a tiny configuration by default. """ - return ModernBertConfig( + config = ModernBertConfig( vocab_size=self.vocab_size, pad_token_id=self.pad_token_id, hidden_size=self.hidden_size, @@ -135,6 +136,15 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, ) + if test := os.environ.get("PYTEST_CURRENT_TEST", False): + test_name = test.split(":")[-1].split(" ")[0] + + # If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error + # that compilation doesn't work. Users can then set compile=False when loading the model, + # much like here. We're testing whether it works once they've done that. + if test_name == "test_retain_grad_hidden_states_attentions": + config.compile = False + return config def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): model = ModernBertModel(config=config) From 28fc79e881f096e06c94a53fffeac30d2d066dda Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 14:35:19 +0100 Subject: [PATCH 48/88] Update ModernBertConfig docstring --- .../modernbert/configuration_modernbert.py | 97 ++++++++++++------ .../models/modernbert/modular_modernbert.py | 99 +++++++++++++------ 2 files changed, 137 insertions(+), 59 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 7279dc67dff745..cb36eafed1af3a 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -29,32 +29,24 @@ class ModernBertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the ModernBert-base. - e.g. [answerdotai/modernbert-base](https://huggingface.co/answerdotai/modernbert-base) + defaults will yield a similar configuration to that of the ModernBERT-base. + e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. + Args: - vocab_size (`int`, *optional*, defaults to 256000): + vocab_size (`int`, *optional*, defaults to 50368): Vocabulary size of the ModernBert model. 
Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`ModernBertModel`] - hidden_size (`int`, *optional*, defaults to 2304): + hidden_size (`int`, *optional*, defaults to 768): Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 9216): + intermediate_size (`int`, *optional*, defaults to 1152): Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 26): + num_hidden_layers (`int`, *optional*, defaults to 22): Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 8): + num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 4): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - head_dim (`int`, *optional*, defaults to 256): - The attention head dimension. hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): The non-linear activation function (function or string) in the decoder. Will default to `"gelu"` if not specified. @@ -62,34 +54,81 @@ class ModernBertConfig(PretrainedConfig): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): + initializer_cutoff_factor (`float`, *optional*, defaults to 2.0): + The cutoff factor for the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. - pad_token_id (`int`, *optional*, defaults to 0): + norm_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the normalization layers. + pad_token_id (`int`, *optional*, defaults to 50283): Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): + eos_token_id (`int`, *optional*, defaults to 50282): End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): + bos_token_id (`int`, *optional*, defaults to 50281): Beginning of stream token id. + cls_token_id (`int`, *optional*, defaults to 50281): + Classification token id. + sep_token_id (`int`, *optional*, defaults to 50282): + Separation token id. tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + global_rope_theta (`float`, *optional*, defaults to 160000.0): + The base period of the global RoPE embeddings. 
+ attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores - final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + global_attn_every_n_layers (`int`, *optional*, defaults to 3): + The number of layers between global attention layers. + local_attention (`int`, *optional*, defaults to 128): + The window size for local attention. + local_rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the local RoPE embeddings. + embedding_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layers. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + unpad_inputs (`bool`, *optional*): + Whether to unpad the inputs in the forward pass. If set to `None`, then it will be set to `True` if the + attention implementation is `flash_attention_2` or `flex_attention`. Otherwise it will be set to `False`. + Unpadded inputs can be used to speed up the forward pass for `flash_attention_2` and `flex_attention`. + unpad_no_grad (`bool`, *optional*, defaults to `True`): + Whether to use `no_grad` when unpadding the inputs. + decoder_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the decoder layers. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the classifier. + classifier_pooling (`str`, *optional*, defaults to `"cls"`): + The pooling method for the classifier. Should be either `"cls"` or `"mean"`. + classifier_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the classifier. + classifier_activation (`str`, *optional*, defaults to `"gelu"`): + The activation function for the classifier. + deterministic_flash_attn (`bool`, *optional*, defaults to `False`): + Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic. + sparse_prediction (`bool`, *optional*, defaults to `False`): + Whether to use sparse prediction for the masked language model instead of returning the full dense logits. + sparse_pred_ignore_index (`int`, *optional*, defaults to -100): + The index to ignore for the sparse prediction. + compile (`bool`, *optional*): + Whether to compile the model. If `None`, then parts of the model will be compiled if 1) `triton` is + installed, 2) the model is not on MPS, 3) the model is not shared between devices, and 4) the model is not + resized after initialization. If `True`, then the model may be faster in some scenarios. 
+ + Examples: ```python >>> from transformers import ModernBertModel, ModernBertConfig - >>> # Initializing a ModernBert modernbert-base style configuration + + >>> # Initializing a ModernBert style configuration >>> configuration = ModernBertConfig() + >>> # Initializing a model from the modernbert-base style configuration >>> model = ModernBertModel(configuration) + >>> # Accessing the model configuration >>> configuration = model.config ```""" diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 57baafb4a03010..85e99e1fba538a 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -52,7 +52,7 @@ if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention -_CHECKPOINT_FOR_DOC = "answerdotai/modernbert-base" +_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" logger = logging.get_logger(__name__) @@ -61,32 +61,24 @@ class ModernBertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the ModernBert-base. - e.g. [answerdotai/modernbert-base](https://huggingface.co/answerdotai/modernbert-base) + defaults will yield a similar configuration to that of the ModernBERT-base. + e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. + Args: - vocab_size (`int`, *optional*, defaults to 256000): + vocab_size (`int`, *optional*, defaults to 50368): Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`ModernBertModel`] - hidden_size (`int`, *optional*, defaults to 2304): + hidden_size (`int`, *optional*, defaults to 768): Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 9216): + intermediate_size (`int`, *optional*, defaults to 1152): Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 26): + num_hidden_layers (`int`, *optional*, defaults to 22): Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 8): + num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 4): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - head_dim (`int`, *optional*, defaults to 256): - The attention head dimension. 
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): The non-linear activation function (function or string) in the decoder. Will default to `"gelu"` if not specified. @@ -94,34 +86,81 @@ class ModernBertConfig(PretrainedConfig): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): + initializer_cutoff_factor (`float`, *optional*, defaults to 2.0): + The cutoff factor for the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. - pad_token_id (`int`, *optional*, defaults to 0): + norm_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the normalization layers. + pad_token_id (`int`, *optional*, defaults to 50283): Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): + eos_token_id (`int`, *optional*, defaults to 50282): End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): + bos_token_id (`int`, *optional*, defaults to 50281): Beginning of stream token id. + cls_token_id (`int`, *optional*, defaults to 50281): + Classification token id. + sep_token_id (`int`, *optional*, defaults to 50282): + Separation token id. tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + global_rope_theta (`float`, *optional*, defaults to 160000.0): + The base period of the global RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores - final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + global_attn_every_n_layers (`int`, *optional*, defaults to 3): + The number of layers between global attention layers. + local_attention (`int`, *optional*, defaults to 128): + The window size for local attention. + local_rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the local RoPE embeddings. + embedding_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layers. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + unpad_inputs (`bool`, *optional*): + Whether to unpad the inputs in the forward pass. If set to `None`, then it will be set to `True` if the + attention implementation is `flash_attention_2` or `flex_attention`. 
Otherwise it will be set to `False`. + Unpadded inputs can be used to speed up the forward pass for `flash_attention_2` and `flex_attention`. + unpad_no_grad (`bool`, *optional*, defaults to `True`): + Whether to use `no_grad` when unpadding the inputs. + decoder_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the decoder layers. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the classifier. + classifier_pooling (`str`, *optional*, defaults to `"cls"`): + The pooling method for the classifier. Should be either `"cls"` or `"mean"`. + classifier_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the classifier. + classifier_activation (`str`, *optional*, defaults to `"gelu"`): + The activation function for the classifier. + deterministic_flash_attn (`bool`, *optional*, defaults to `False`): + Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic. + sparse_prediction (`bool`, *optional*, defaults to `False`): + Whether to use sparse prediction for the masked language model instead of returning the full dense logits. + sparse_pred_ignore_index (`int`, *optional*, defaults to -100): + The index to ignore for the sparse prediction. + compile (`bool`, *optional*): + Whether to compile the model. If `None`, then parts of the model will be compiled if 1) `triton` is + installed, 2) the model is not on MPS, 3) the model is not shared between devices, and 4) the model is not + resized after initialization. If `True`, then the model may be faster in some scenarios. + + Examples: ```python >>> from transformers import ModernBertModel, ModernBertConfig - >>> # Initializing a ModernBert modernbert-base style configuration + + >>> # Initializing a ModernBert style configuration >>> configuration = ModernBertConfig() + >>> # Initializing a model from the modernbert-base style configuration >>> model = ModernBertModel(configuration) + >>> # Accessing the model configuration >>> configuration = model.config ```""" From 612befa822f94122c6227caaea8e9c6a89d36985 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 15:15:02 +0100 Subject: [PATCH 49/88] Satisfy some consistency checks, add unfinished docs --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/modernbert.md | 87 +++++++++++++++++++ docs/source/en/perf_infer_gpu_one.md | 2 + .../modernbert/configuration_modernbert.py | 5 -- .../models/modernbert/modular_modernbert.py | 5 -- 5 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 docs/source/en/model_doc/modernbert.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d87906159ce34f..9c947581745f8e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -495,6 +495,8 @@ - local: model_doc/mobilebert title: MobileBERT - local: model_doc/mpnet + title: ModernBert + - local: model_doc/modernbert title: MPNet - local: model_doc/mpt title: MPT diff --git a/docs/source/en/model_doc/modernbert.md b/docs/source/en/model_doc/modernbert.md new file mode 100644 index 00000000000000..b1e94f62105e36 --- /dev/null +++ b/docs/source/en/model_doc/modernbert.md @@ -0,0 +1,87 @@ + + +# ModernBert + +
+ +Models + + +
+ +## Overview + +The ModernBert model was proposed in []() by ... + +It builds on BERT and modifies ... + +The abstract from the paper is the following: + +** + +The original code can be found [here](). + +## Usage tips + +- This implementation is similar to [`BertModel`] ... +- ModernBert doesn't have `token_type_ids`, so you don't need to indicate which token belongs to which segment. +- ModernBert is similar to BERT but with ... + +## Resources + +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with RoBERTa. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. + + + +... + + + +- [Masked language modeling task guide](../tasks/masked_language_modeling) + + +## ModernBertConfig + +[[autodoc]] ModernBertConfig + + + + +## ModernBertModel + +[[autodoc]] ModernBertModel + - forward + +## ModernBertForMaskedLM + +[[autodoc]] ModernBertForMaskedLM + - forward + +## ModernBertForSequenceClassification + +[[autodoc]] ModernBertForSequenceClassification + - forward + +## ModernBertForTokenClassification + +[[autodoc]] ModernBertForTokenClassification + - forward + + + diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 4d7852a66307e2..0709a6f0556fc6 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -73,6 +73,7 @@ FlashAttention-2 is currently supported for the following architectures: * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert) * [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) @@ -261,6 +262,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert) * [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index cb36eafed1af3a..6a278e4f90ac7a 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -70,9 +70,6 @@ class ModernBertConfig(PretrainedConfig): Classification token id. 
sep_token_id (`int`, *optional*, defaults to 50282): Separation token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the - model has a output word embedding layer. global_rope_theta (`float`, *optional*, defaults to 160000.0): The base period of the global RoPE embeddings. attention_bias (`bool`, *optional*, defaults to `False`): @@ -154,7 +151,6 @@ def __init__( bos_token_id=50281, cls_token_id=50281, sep_token_id=50282, - tie_word_embeddings=True, global_rope_theta=160000.0, attention_bias=False, attention_dropout=0.0, @@ -183,7 +179,6 @@ def __init__( eos_token_id=eos_token_id, cls_token_id=cls_token_id, sep_token_id=sep_token_id, - tie_word_embeddings=tie_word_embeddings, **kwargs, ) self.vocab_size = vocab_size diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 85e99e1fba538a..d2707b8613e974 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -102,9 +102,6 @@ class ModernBertConfig(PretrainedConfig): Classification token id. sep_token_id (`int`, *optional*, defaults to 50282): Separation token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the - model has a output word embedding layer. global_rope_theta (`float`, *optional*, defaults to 160000.0): The base period of the global RoPE embeddings. attention_bias (`bool`, *optional*, defaults to `False`): @@ -186,7 +183,6 @@ def __init__( bos_token_id=50281, cls_token_id=50281, sep_token_id=50282, - tie_word_embeddings=True, global_rope_theta=160000.0, attention_bias=False, attention_dropout=0.0, @@ -215,7 +211,6 @@ def __init__( eos_token_id=eos_token_id, cls_token_id=cls_token_id, sep_token_id=sep_token_id, - tie_word_embeddings=tie_word_embeddings, **kwargs, ) self.vocab_size = vocab_size From f4e280abd747b6b4a58e904d96ccfa6774d0299d Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 16:00:57 +0100 Subject: [PATCH 50/88] Only set compile to False if there's more than 1 device --- src/transformers/models/modernbert/modeling_modernbert.py | 2 +- src/transformers/models/modernbert/modular_modernbert.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 93abd0b7ee1d7c..61771064a3cefb 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -778,7 +778,7 @@ def _maybe_set_compile(self): if self.config.compile is False: return - if hasattr(self, "hf_device_map"): + if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: if self.config.compile: logger.warning_once( "If `accelerate` split the model across devices, `torch.compile` will not work. 
" diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index d2707b8613e974..41928429063e44 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -924,7 +924,7 @@ def _maybe_set_compile(self): if self.config.compile is False: return - if hasattr(self, "hf_device_map"): + if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: if self.config.compile: logger.warning_once( "If `accelerate` split the model across devices, `torch.compile` will not work. " From bc149676ab026c8476e20e7fd7eafa303473484c Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 17:02:02 +0100 Subject: [PATCH 51/88] Add docstrings for public ModernBert classes --- .../models/modernbert/modeling_modernbert.py | 136 +++++++++++++++++- .../models/modernbert/modular_modernbert.py | 128 +++++++++++++++++ 2 files changed, 263 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 61771064a3cefb..f1f490d461e10c 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -32,7 +32,15 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput from ...modeling_utils import PreTrainedModel -from ...utils import is_flash_attn_2_available, is_torch_greater_or_equal, logging +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_torch_greater_or_equal, + logging, + replace_return_docstrings, +) from ...utils.import_utils import is_triton_available from .configuration_modernbert import ModernBertConfig @@ -49,6 +57,9 @@ logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" +_CONFIG_FOR_DOC = "ModernBertConfig" + class ApplyRotaryEmbUnpad(torch.autograd.Function): @staticmethod @@ -723,6 +734,27 @@ def _pad_modernbert_output( return padded_inputs +MODERNBERT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ModernBertConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) class ModernBertPreTrainedModel(PreTrainedModel): config_class = ModernBertConfig base_model_prefix = "model" @@ -891,6 +923,68 @@ def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): return block_mask +MODERNBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch. Used to pad the output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences. Used to pad the output tensors. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ + Returns: +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) class ModernBertModel(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) @@ -909,6 +1003,13 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: torch.LongTensor = None, @@ -1019,6 +1120,10 @@ def forward( ) +@add_start_docstrings( + "The ModernBert Model with a decoder head on top that is used for masked language modeling.", + MODERNBERT_START_DOCSTRING, +) class ModernBertForMaskedLM(ModernBertPreTrainedModel): _tied_weights_keys = ["decoder.weight"] @@ -1045,6 +1150,13 @@ def set_output_embeddings(self, new_embeddings: nn.Linear): def compiled_head(self, output: torch.Tensor) -> torch.Tensor: return self.decoder(self.head(output)) + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.Tensor], @@ -1128,6 +1240,10 @@ def forward( ) +@add_start_docstrings( + "The ModernBert Model with a sequence classification head on top that performs pooling.", + MODERNBERT_START_DOCSTRING, +) class ModernBertForSequenceClassification(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) @@ -1141,6 +1257,13 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.Tensor], @@ -1219,6 +1342,10 @@ def forward( ) +@add_start_docstrings( + "The ModernBert Model with a token classification head on top, e.g. 
for Named Entity Recognition (NER) tasks.", + MODERNBERT_START_DOCSTRING, +) class ModernBertForTokenClassification(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) @@ -1231,6 +1358,13 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.Tensor], diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 41928429063e44..c17eea22bd84f2 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -34,9 +34,13 @@ ) from ...modeling_utils import PreTrainedModel from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_torch_greater_or_equal, logging, + replace_return_docstrings, ) from ...utils.import_utils import is_triton_available from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb @@ -53,6 +57,7 @@ from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention _CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" +_CONFIG_FOR_DOC = "ModernBertConfig" logger = logging.get_logger(__name__) @@ -869,6 +874,27 @@ def forward( return self.drop(self.norm(self.act(self.dense(hidden_states)))) +MODERNBERT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ModernBertConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) class ModernBertPreTrainedModel(PreTrainedModel): config_class = ModernBertConfig base_model_prefix = "model" @@ -1037,6 +1063,68 @@ def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): return block_mask +MODERNBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch. Used to pad the output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences. Used to pad the output tensors. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ + Returns: +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) class ModernBertModel(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) @@ -1055,6 +1143,13 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: torch.LongTensor = None, @@ -1165,6 +1260,10 @@ def forward( ) +@add_start_docstrings( + "The ModernBert Model with a decoder head on top that is used for masked language modeling.", + MODERNBERT_START_DOCSTRING, +) class ModernBertForMaskedLM(ModernBertPreTrainedModel): _tied_weights_keys = ["decoder.weight"] @@ -1191,6 +1290,13 @@ def set_output_embeddings(self, new_embeddings: nn.Linear): def compiled_head(self, output: torch.Tensor) -> torch.Tensor: return self.decoder(self.head(output)) + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.Tensor], @@ -1274,6 +1380,10 @@ def forward( ) +@add_start_docstrings( + "The ModernBert Model with a sequence classification head on top that performs pooling.", + MODERNBERT_START_DOCSTRING, +) class ModernBertForSequenceClassification(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) @@ -1287,6 +1397,13 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.Tensor], @@ -1365,6 +1482,10 @@ def forward( ) +@add_start_docstrings( + "The ModernBert Model with a token classification head on top, e.g. 
for Named Entity Recognition (NER) tasks.", + MODERNBERT_START_DOCSTRING, +) class ModernBertForTokenClassification(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) @@ -1377,6 +1498,13 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.Tensor], From 0f17fb94bc9a2ca39c4801d60c23701924374c55 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 17:38:12 +0100 Subject: [PATCH 52/88] Dont replace docstring returns - ends up being duplicate --- src/transformers/models/modernbert/modeling_modernbert.py | 7 ------- src/transformers/models/modernbert/modular_modernbert.py | 7 ------- 2 files changed, 14 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index f1f490d461e10c..407ba5e80b6502 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -39,7 +39,6 @@ is_flash_attn_2_available, is_torch_greater_or_equal, logging, - replace_return_docstrings, ) from ...utils.import_utils import is_triton_available from .configuration_modernbert import ModernBertConfig @@ -976,8 +975,6 @@ def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- - Returns: """ @@ -1004,7 +1001,6 @@ def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutput, @@ -1151,7 +1147,6 @@ def compiled_head(self, output: torch.Tensor) -> torch.Tensor: return self.decoder(self.head(output)) @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=MaskedLMOutput, @@ -1258,7 +1253,6 @@ def __init__(self, config: ModernBertConfig): self.post_init() @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=SequenceClassifierOutput, @@ -1359,7 +1353,6 @@ def __init__(self, config: ModernBertConfig): self.post_init() @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index c17eea22bd84f2..0b9ec85a045951 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -40,7 +40,6 @@ is_flash_attn_2_available, is_torch_greater_or_equal, logging, - replace_return_docstrings, ) from ...utils.import_utils import is_triton_available from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb @@ -1116,8 +1115,6 @@ def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- - Returns: """ @@ -1144,7 +1141,6 @@ def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutput, @@ -1291,7 +1287,6 @@ def compiled_head(self, output: torch.Tensor) -> torch.Tensor: return self.decoder(self.head(output)) @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=MaskedLMOutput, @@ -1398,7 +1393,6 @@ def __init__(self, config: ModernBertConfig): self.post_init() @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=SequenceClassifierOutput, @@ -1499,7 +1493,6 @@ def __init__(self, config: ModernBertConfig): self.post_init() @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, From 25b12b4a9398732377546ef8af232b4125aa33a4 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 18:16:00 +0100 Subject: [PATCH 53/88] Fix mistake in toctree --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 719bde7556292f..4c862ee2d68d41 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -497,9 +497,9 @@ - local: model_doc/mobilebert title: MobileBERT - local: model_doc/mpnet - title: ModernBert - - local: model_doc/modernbert title: MPNet + - local: model_doc/modernbert + title: ModernBert - local: model_doc/mpt title: MPT - local: model_doc/mra From f312eefd3f83f052fa5077a5950b3c0d16f24960 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 18:18:31 +0100 Subject: [PATCH 54/88] Reformat toctree --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4c862ee2d68d41..d22ec0bbc5ccb8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -496,10 +496,10 @@ title: mLUKE - local: model_doc/mobilebert title: MobileBERT - - local: model_doc/mpnet - title: MPNet - local: model_doc/modernbert title: ModernBert + - local: model_doc/mpnet + title: MPNet - local: model_doc/mpt title: MPT - local: model_doc/mra From 1e367df9b9e0a92189c562445d42d09b859c22d1 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 20:05:18 +0100 Subject: [PATCH 55/88] Patched FlexAttention, SDPA, Eager with Local Attention --- .../models/modernbert/modeling_modernbert.py | 57 ++++++++++++------- .../models/modernbert/modular_modernbert.py | 57 ++++++++++++------- 2 files changed, 76 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 407ba5e80b6502..a75b995f05b73c 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ 
-306,6 +306,7 @@ def eager_attention_forward( qkv: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], bs: int, dim: int, output_attentions: Optional[bool] = False, @@ -320,9 +321,21 @@ def eager_attention_forward( scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_weights = attn_weights + causal_mask + if attention_mask is not None: + expanded_mask = _prepare_4d_attention_mask(attention_mask, attn_weights.dtype, tgt_len=key.shape[-2]) + + if local_attention != (-1, -1): + # Create position indices + rows = torch.arange(expanded_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + # Combine with existing mask + expanded_mask.masked_fill_(window_mask.logical_not(), float("-inf")) + + attn_weights = attn_weights + expanded_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) @@ -385,6 +398,7 @@ def flex_attention_forward( rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, block_mask: "BlockMask", + local_attention: Tuple[int, int], max_seqlen: int, bs: int, dim: int, @@ -401,7 +415,7 @@ def flex_attention_forward( query, key, value, - block_mask=block_mask, + block_mask=block_mask if local_attention != (-1, -1) else None, enable_gqa=False, scale=None, return_lse=False, @@ -416,6 +430,7 @@ def sdpa_attention_forward( qkv: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], bs: int, dim: int, **_kwargs, @@ -427,7 +442,22 @@ def sdpa_attention_forward( query, key = apply_rotary_pos_emb(query, key, cos, sin) if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key.shape[-2]] + attention_mask = attention_mask[:, None, None, :].expand( + attention_mask.shape[0], 1, attention_mask.shape[1], attention_mask.shape[1] + ) + + if local_attention != (-1, -1): + # Create position indices + rows = torch.arange(attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + # Combine with existing mask + attention_mask = torch.logical_and(attention_mask, window_mask) + + attention_mask = attention_mask.to(torch.bool) attn_output = F.scaled_dot_product_attention( query, @@ -893,7 +923,6 @@ def offsets_to_sequence_ids_tensor(cls, offsets): counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts) - @torch.compile(dynamic=False) def create_attention_mask(self, sequence_ids, cu_seqlens, window_size): """ Creates a block mask combining sequence masking and local/or global attention masking. 
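The eager and SDPA branches in this hunk build the same distance-based sliding-window mask before combining it with the padding mask. A self-contained sketch of that construction follows; the function name and the fixed window value are illustrative, with the window standing in for `local_attention[0]`.

```python
# Illustrative standalone version of the sliding-window masking used by the
# eager/SDPA paths above; names and the window value are assumptions.
import torch

def sliding_window_mask(attention_mask: torch.Tensor, window: int) -> torch.Tensor:
    """(batch, seq_len) padding mask -> (batch, 1, seq_len, seq_len) boolean mask."""
    bs, seq_len = attention_mask.shape
    expanded = attention_mask[:, None, None, :].expand(bs, 1, seq_len, seq_len).bool()
    rows = torch.arange(seq_len).unsqueeze(0)
    distance = torch.abs(rows - rows.T)  # |query_pos - key_pos| for every pair
    window_mask = (distance <= window).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
    return torch.logical_and(expanded, window_mask)

mask = sliding_window_mask(torch.ones(2, 16, dtype=torch.long), window=4)
print(mask.shape)  # torch.Size([2, 1, 16, 16])
```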
@@ -1053,23 +1082,12 @@ def forward( hidden_states = self.embeddings(input_ids) - # expand attention_mask - if self.config._attn_implementation != "flash_attention_2" and attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - # create block mask + block_mask = None if self.config._attn_implementation == "flex_attention": sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens) - - if self.config.local_attention != (-1, -1): - window_size = self.config.local_attention // 2 - else: - window_size = max_seqlen - + window_size = self.config.local_attention // 2 block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size) - else: - block_mask = None for encoder_layer in self.layers: if output_hidden_states: @@ -1082,6 +1100,7 @@ def forward( attention_mask, position_ids, cu_seqlens, + block_mask, max_seqlen, output_attentions, ) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 0b9ec85a045951..d19419f6b9324b 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -518,6 +518,7 @@ def eager_attention_forward( qkv: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], bs: int, dim: int, output_attentions: Optional[bool] = False, @@ -532,9 +533,21 @@ def eager_attention_forward( scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_weights = attn_weights + causal_mask + if attention_mask is not None: + expanded_mask = _prepare_4d_attention_mask(attention_mask, attn_weights.dtype, tgt_len=key.shape[-2]) + + if local_attention != (-1, -1): + # Create position indices + rows = torch.arange(expanded_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + # Combine with existing mask + expanded_mask.masked_fill_(window_mask.logical_not(), float("-inf")) + + attn_weights = attn_weights + expanded_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) @@ -597,6 +610,7 @@ def flex_attention_forward( rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, block_mask: "BlockMask", + local_attention: Tuple[int, int], max_seqlen: int, bs: int, dim: int, @@ -613,7 +627,7 @@ def flex_attention_forward( query, key, value, - block_mask=block_mask, + block_mask=block_mask if local_attention != (-1, -1) else None, enable_gqa=False, scale=None, return_lse=False, @@ -628,6 +642,7 @@ def sdpa_attention_forward( qkv: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], bs: int, dim: int, **_kwargs, @@ -639,7 +654,22 @@ def sdpa_attention_forward( query, key = apply_rotary_pos_emb(query, key, cos, sin) if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key.shape[-2]] + attention_mask = attention_mask[:, None, None, :].expand( + attention_mask.shape[0], 1, 
attention_mask.shape[1], attention_mask.shape[1] + ) + + if local_attention != (-1, -1): + # Create position indices + rows = torch.arange(attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + # Combine with existing mask + attention_mask = torch.logical_and(attention_mask, window_mask) + + attention_mask = attention_mask.to(torch.bool) attn_output = F.scaled_dot_product_attention( query, @@ -1033,7 +1063,6 @@ def offsets_to_sequence_ids_tensor(cls, offsets): counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts) - @torch.compile(dynamic=False) def create_attention_mask(self, sequence_ids, cu_seqlens, window_size): """ Creates a block mask combining sequence masking and local/or global attention masking. @@ -1193,23 +1222,12 @@ def forward( hidden_states = self.embeddings(input_ids) - # expand attention_mask - if self.config._attn_implementation != "flash_attention_2" and attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - # create block mask + block_mask = None if self.config._attn_implementation == "flex_attention": sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens) - - if self.config.local_attention != (-1, -1): - window_size = self.config.local_attention // 2 - else: - window_size = max_seqlen - + window_size = self.config.local_attention // 2 block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size) - else: - block_mask = None for encoder_layer in self.layers: if output_hidden_states: @@ -1222,6 +1240,7 @@ def forward( attention_mask, position_ids, cu_seqlens, + block_mask, max_seqlen, output_attentions, ) From fb748ce728af891fbf56ea77fcea4f3aa2443ff3 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 21:27:23 +0100 Subject: [PATCH 56/88] Implement FA2 -> SDPA -> Eager attn_impl defaulting, crucial both to match the original performance, and to get the highest inference speed without requiring users to manually pick FA2 --- .../modernbert/configuration_modernbert.py | 9 -- .../models/modernbert/modeling_modernbert.py | 119 +++++++++++++++- .../models/modernbert/modular_modernbert.py | 128 ++++++++++++++++-- .../modernbert/test_modeling_modernbert.py | 8 ++ tests/test_modeling_common.py | 4 +- 5 files changed, 241 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 6a278e4f90ac7a..4906902c792fc4 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -88,10 +88,6 @@ class ModernBertConfig(PretrainedConfig): Whether to use bias in the MLP layers. mlp_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the MLP layers. - unpad_inputs (`bool`, *optional*): - Whether to unpad the inputs in the forward pass. If set to `None`, then it will be set to `True` if the - attention implementation is `flash_attention_2` or `flex_attention`. Otherwise it will be set to `False`. - Unpadded inputs can be used to speed up the forward pass for `flash_attention_2` and `flex_attention`. 
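From the user side, the dispatch added in this commit means no extra flag is needed to get the fastest available attention backend, while an explicitly requested backend still fails hard if it is unavailable. A hedged sketch, reusing the checkpoint assumption from earlier:

```python
# Hedged sketch of the FA2 -> SDPA -> eager defaulting described above. With no
# attn_implementation requested, the model picks the best available backend;
# unpadded inputs are only used on the flash_attention_2 / flex_attention paths.
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base")
print(model.config._attn_implementation)  # e.g. "flash_attention_2" or "sdpa"

# Pinning a backend raises if it is not installed/supported on this machine.
model_sdpa = AutoModelForMaskedLM.from_pretrained(
    "answerdotai/ModernBERT-base", attn_implementation="sdpa"
)
```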
unpad_no_grad (`bool`, *optional*, defaults to `True`): Whether to use `no_grad` when unpadding the inputs. decoder_bias (`bool`, *optional*, defaults to `True`): @@ -160,7 +156,6 @@ def __init__( embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, - unpad_inputs=None, unpad_no_grad=True, decoder_bias=True, classifier_dropout=0.0, @@ -201,7 +196,6 @@ def __init__( self.embedding_dropout = embedding_dropout self.mlp_bias = mlp_bias self.mlp_dropout = mlp_dropout - self.unpad_inputs = unpad_inputs self.unpad_no_grad = unpad_no_grad self.decoder_bias = decoder_bias self.classifier_dropout = classifier_dropout @@ -217,6 +211,3 @@ def __init__( raise ValueError( f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' ) - - if unpad_inputs is None: - self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index a75b995f05b73c..78d6178b537e40 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -21,7 +21,7 @@ # limitations under the License. import math -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -835,6 +835,117 @@ def init_weight(module: nn.Module, std: float): elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): init_weight(module.classifier, stds["final_out"]) + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): + """ + Automatically checks and dispatches to a default attention implementation. In order of priority: + 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). + 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) + 3. FA2, if available. + 4. SDPA, if available. + 5. Eager attention. + """ + # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. + # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). + # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) + requested_attn_implementation = None + if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: + if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: + raise ValueError( + f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.' + ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' + ) + + if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [ + "eager", + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. 
The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' + if cls._supports_flash_attn_2: + message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + if cls._supports_flex_attn: + message += ( + ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' + ) + raise ValueError(message + ".") + + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. + requested_attn_implementation = config._attn_implementation_internal + + if use_flash_attention_2: + logger.warning_once( + 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' + ) + cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + + # If FA2 is requested and it fails, we throw an error. + # If None is requested, we try to enable FA2, but if it fails, we fall back to the next implementation. + if requested_attn_implementation in [None, "flash_attention_2"]: + try: + config = cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + except (ValueError, ImportError) as e: + if requested_attn_implementation == "flash_attention_2": + raise e + else: + config._attn_implementation_autoset = True + return config + + # If SDPA is requested and it fails, we throw an error. + # If None is requested, we try to enable SDPA after FA2 fails, but if it fails, we fall back to the next implementation. + if requested_attn_implementation in [None, "sdpa"]: + try: + config = cls._check_and_enable_sdpa(config, hard_check_only=False) + except (ValueError, ImportError) as e: + if requested_attn_implementation == "sdpa": + raise e + else: + if ( + torch.version.hip is not None + and config._attn_implementation == "sdpa" + and torch.cuda.device_count() > 1 + ): + logger.warning_once( + "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." + ) + torch.backends.cuda.enable_flash_sdp(False) + config._attn_implementation_autoset = True + return config + + # If flex_attention is requested and it fails, we throw an error. + # This implementation is not used by default, so we only try to enable it if it is requested. + if requested_attn_implementation == "flex_attention": + config = cls._check_and_enable_flex_attn(config, hard_check_only=False) + config._attn_implementation_autoset = True + return config + + # If eager is requested, or if None is requested but FA2 and SDPA fail, we set it and return the config. 
+ config._attn_implementation = "eager" + config._attn_implementation_autoset = True + return config + def _maybe_set_compile(self): if self.config.compile is False: return @@ -1068,7 +1179,7 @@ def forward( attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) repad = False - if self.config.unpad_inputs: + if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if self.config.unpad_no_grad: @@ -1190,7 +1301,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() - if self.config.unpad_inputs: + if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: if indices is None and cu_seqlens is None and max_seqlen is None: batch_size, seq_len = input_ids.shape[:2] if self.config.unpad_no_grad: @@ -1237,7 +1348,7 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.config.unpad_inputs: + if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: if self.config.unpad_no_grad: logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index d19419f6b9324b..f08708a3a448db 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -15,7 +15,7 @@ # limitations under the License. import math -from typing import Literal, Optional, Tuple, Union +from typing import Dict, Literal, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -124,10 +124,6 @@ class ModernBertConfig(PretrainedConfig): Whether to use bias in the MLP layers. mlp_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the MLP layers. - unpad_inputs (`bool`, *optional*): - Whether to unpad the inputs in the forward pass. If set to `None`, then it will be set to `True` if the - attention implementation is `flash_attention_2` or `flex_attention`. Otherwise it will be set to `False`. - Unpadded inputs can be used to speed up the forward pass for `flash_attention_2` and `flex_attention`. unpad_no_grad (`bool`, *optional*, defaults to `True`): Whether to use `no_grad` when unpadding the inputs. decoder_bias (`bool`, *optional*, defaults to `True`): @@ -196,7 +192,6 @@ def __init__( embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, - unpad_inputs=None, unpad_no_grad=True, decoder_bias=True, classifier_dropout=0.0, @@ -237,7 +232,6 @@ def __init__( self.embedding_dropout = embedding_dropout self.mlp_bias = mlp_bias self.mlp_dropout = mlp_dropout - self.unpad_inputs = unpad_inputs self.unpad_no_grad = unpad_no_grad self.decoder_bias = decoder_bias self.classifier_dropout = classifier_dropout @@ -254,9 +248,6 @@ def __init__( f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' 
) - if unpad_inputs is None: - self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"} - def _unpad_modernbert_input( inputs: torch.Tensor, @@ -975,6 +966,117 @@ def init_weight(module: nn.Module, std: float): elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): init_weight(module.classifier, stds["final_out"]) + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): + """ + Automatically checks and dispatches to a default attention implementation. In order of priority: + 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). + 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) + 3. FA2, if available. + 4. SDPA, if available. + 5. Eager attention. + """ + # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. + # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). + # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) + requested_attn_implementation = None + if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: + if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: + raise ValueError( + f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.' + ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' + ) + + if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [ + "eager", + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' + if cls._supports_flash_attn_2: + message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + if cls._supports_flex_attn: + message += ( + ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' + ) + raise ValueError(message + ".") + + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. + requested_attn_implementation = config._attn_implementation_internal + + if use_flash_attention_2: + logger.warning_once( + 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' 
+ ) + cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + + # If FA2 is requested and it fails, we throw an error. + # If None is requested, we try to enable FA2, but if it fails, we fall back to the next implementation. + if requested_attn_implementation in [None, "flash_attention_2"]: + try: + config = cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + except (ValueError, ImportError) as e: + if requested_attn_implementation == "flash_attention_2": + raise e + else: + config._attn_implementation_autoset = True + return config + + # If SDPA is requested and it fails, we throw an error. + # If None is requested, we try to enable SDPA after FA2 fails, but if it fails, we fall back to the next implementation. + if requested_attn_implementation in [None, "sdpa"]: + try: + config = cls._check_and_enable_sdpa(config, hard_check_only=False) + except (ValueError, ImportError) as e: + if requested_attn_implementation == "sdpa": + raise e + else: + if ( + torch.version.hip is not None + and config._attn_implementation == "sdpa" + and torch.cuda.device_count() > 1 + ): + logger.warning_once( + "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." + ) + torch.backends.cuda.enable_flash_sdp(False) + config._attn_implementation_autoset = True + return config + + # If flex_attention is requested and it fails, we throw an error. + # This implementation is not used by default, so we only try to enable it if it is requested. + if requested_attn_implementation == "flex_attention": + config = cls._check_and_enable_flex_attn(config, hard_check_only=False) + config._attn_implementation_autoset = True + return config + + # If eager is requested, or if None is requested but FA2 and SDPA fail, we set it and return the config. 
+ config._attn_implementation = "eager" + config._attn_implementation_autoset = True + return config + def _maybe_set_compile(self): if self.config.compile is False: return @@ -1208,7 +1310,7 @@ def forward( attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) repad = False - if self.config.unpad_inputs: + if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if self.config.unpad_no_grad: @@ -1330,7 +1432,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() - if self.config.unpad_inputs: + if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: if indices is None and cu_seqlens is None and max_seqlen is None: batch_size, seq_len = input_ids.shape[:2] if self.config.unpad_no_grad: @@ -1377,7 +1479,7 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.config.unpad_inputs: + if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: if self.config.unpad_no_grad: logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index 0f9e3ade73f207..5094ca7963cb76 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -144,6 +144,14 @@ def get_config(self): # much like here. We're testing whether it works once they've done that. if test_name == "test_retain_grad_hidden_states_attentions": config.compile = False + # Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager + # as the others don't support outputted attentions + if test_name in ( + "test_attention_outputs", + "test_hidden_states_output", + "test_retain_grad_hidden_states_attentions", + ): + config._attn_implementation = "eager" return config def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9e3b9d741b44a9..f7cc90e86745ac 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4048,7 +4048,9 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch_dtype, attn_implementation="sdpa" + ) model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) model_eager = model_class.from_pretrained( From 051233f9f76b74eedbe2bed0aca03daec2fe4083 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 21:44:45 +0100 Subject: [PATCH 57/88] Patch test edge case with Idefics3 not working with 'attn_implementation="sdpa"' --- tests/test_modeling_common.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f7cc90e86745ac..d08e25546ce10a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4048,9 +4048,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): with tempfile.TemporaryDirectory() as tmpdirname: 
model.save_pretrained(tmpdirname) - model_sdpa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch_dtype, attn_implementation="sdpa" - ) + try: + model_sdpa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch_dtype, attn_implementation="sdpa" + ) + except ValueError: + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) model_eager = model_class.from_pretrained( From 6c01711e06aeeba88b327e4befaeafc439565a5a Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Dec 2024 21:50:46 +0100 Subject: [PATCH 58/88] Repad all_hidden_states as well --- src/transformers/models/modernbert/modeling_modernbert.py | 4 ++++ src/transformers/models/modernbert/modular_modernbert.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 78d6178b537e40..e60f026b3c730f 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -1236,6 +1236,10 @@ def forward( if repad: hidden_states = self._pad_outputs(hidden_states, indices, batch_size, seq_len) + if all_hidden_states is not None: + all_hidden_states = tuple( + self._pad_outputs(hs, indices, batch_size, seq_len) for hs in all_hidden_states + ) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index f08708a3a448db..33ea7e2ed7fa43 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1367,6 +1367,10 @@ def forward( if repad: hidden_states = self._pad_outputs(hidden_states, indices, batch_size, seq_len) + if all_hidden_states is not None: + all_hidden_states = tuple( + self._pad_outputs(hs, indices, batch_size, seq_len) for hs in all_hidden_states + ) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) From 5f7c5663ed4b58dbf21e2c3ed88d1963a953b88c Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 17 Dec 2024 23:52:47 -0600 Subject: [PATCH 59/88] rename config.compile to reference_compile --- .../modernbert/configuration_modernbert.py | 13 +++--- .../models/modernbert/modeling_modernbert.py | 28 +++++++------ .../models/modernbert/modular_modernbert.py | 41 ++++++++++--------- 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 4906902c792fc4..dfd759df4413b0 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -106,10 +106,11 @@ class ModernBertConfig(PretrainedConfig): Whether to use sparse prediction for the masked language model instead of returning the full dense logits. sparse_pred_ignore_index (`int`, *optional*, defaults to -100): The index to ignore for the sparse prediction. - compile (`bool`, *optional*): - Whether to compile the model. If `None`, then parts of the model will be compiled if 1) `triton` is - installed, 2) the model is not on MPS, 3) the model is not shared between devices, and 4) the model is not - resized after initialization. 
If `True`, then the model may be faster in some scenarios. + reference_compile (`bool`, *optional*): + Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of + the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not + shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may + be faster in some scenarios. Examples: @@ -165,7 +166,7 @@ def __init__( deterministic_flash_attn=False, sparse_prediction=False, sparse_pred_ignore_index=-100, - compile=None, + reference_compile=None, **kwargs, ): super().__init__( @@ -205,7 +206,7 @@ def __init__( self.deterministic_flash_attn = deterministic_flash_attn self.sparse_prediction = sparse_prediction self.sparse_pred_ignore_index = sparse_pred_ignore_index - self.compile = compile + self.reference_compile = reference_compile if self.classifier_pooling not in ["cls", "mean"]: raise ValueError( diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index e60f026b3c730f..ed37f4aadfacff 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -213,7 +213,7 @@ def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: hidden_states = ( self.compiled_embeddings(input_ids) - if self.config.compile + if self.config.reference_compile else self.drop(self.norm(self.tok_embeddings(input_ids))) ) return hidden_states @@ -639,7 +639,9 @@ def forward( ) hidden_states = hidden_states + attn_outputs[0] mlp_output = ( - self.compiled_mlp(hidden_states) if self.config.compile else self.mlp(self.mlp_norm(hidden_states)) + self.compiled_mlp(hidden_states) + if self.config.reference_compile + else self.mlp(self.mlp_norm(hidden_states)) ) hidden_states = hidden_states + mlp_output @@ -947,37 +949,37 @@ def _autoset_attn_implementation( return config def _maybe_set_compile(self): - if self.config.compile is False: + if self.config.reference_compile is False: return if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: - if self.config.compile: + if self.config.reference_compile: logger.warning_once( "If `accelerate` split the model across devices, `torch.compile` will not work. " "Falling back to non-compiled mode." ) - self.config.compile = False + self.config.reference_compile = False if self.device.type == "mps": - if self.config.compile: + if self.config.reference_compile: logger.warning_once( "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " "Falling back to non-compiled mode." ) - self.config.compile = False + self.config.reference_compile = False - if self.config.compile is None: - self.config.compile = is_triton_available() + if self.config.reference_compile is None: + self.config.reference_compile = is_triton_available() def resize_token_embeddings(self, *args, **kwargs): model_embeds = super().resize_token_embeddings(*args, **kwargs) - if self.config.compile in {True, None}: - if self.config.compile: + if self.config.reference_compile in {True, None}: + if self.config.reference_compile: logger.warning_once( "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." 
) - self.config.compile = False + self.config.reference_compile = False return model_embeds @@ -1344,7 +1346,7 @@ def forward( logits = ( self.compiled_head(last_hidden_state) - if self.config.compile + if self.config.reference_compile else self.decoder(self.head(last_hidden_state)) ) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 33ea7e2ed7fa43..39d217885ea4bf 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -142,10 +142,11 @@ class ModernBertConfig(PretrainedConfig): Whether to use sparse prediction for the masked language model instead of returning the full dense logits. sparse_pred_ignore_index (`int`, *optional*, defaults to -100): The index to ignore for the sparse prediction. - compile (`bool`, *optional*): - Whether to compile the model. If `None`, then parts of the model will be compiled if 1) `triton` is - installed, 2) the model is not on MPS, 3) the model is not shared between devices, and 4) the model is not - resized after initialization. If `True`, then the model may be faster in some scenarios. + reference_compile (`bool`, *optional*): + Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of + the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not + shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may + be faster in some scenarios. Examples: @@ -201,7 +202,7 @@ def __init__( deterministic_flash_attn=False, sparse_prediction=False, sparse_pred_ignore_index=-100, - compile=None, + reference_compile=None, **kwargs, ): super().__init__( @@ -241,7 +242,7 @@ def __init__( self.deterministic_flash_attn = deterministic_flash_attn self.sparse_prediction = sparse_prediction self.sparse_pred_ignore_index = sparse_pred_ignore_index - self.compile = compile + self.reference_compile = reference_compile if self.classifier_pooling not in ["cls", "mean"]: raise ValueError( @@ -474,7 +475,7 @@ def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: hidden_states = ( self.compiled_embeddings(input_ids) - if self.config.compile + if self.config.reference_compile else self.drop(self.norm(self.tok_embeddings(input_ids))) ) return hidden_states @@ -842,7 +843,9 @@ def forward( ) hidden_states = hidden_states + attn_outputs[0] mlp_output = ( - self.compiled_mlp(hidden_states) if self.config.compile else self.mlp(self.mlp_norm(hidden_states)) + self.compiled_mlp(hidden_states) + if self.config.reference_compile + else self.mlp(self.mlp_norm(hidden_states)) ) hidden_states = hidden_states + mlp_output @@ -1078,37 +1081,37 @@ def _autoset_attn_implementation( return config def _maybe_set_compile(self): - if self.config.compile is False: + if self.config.reference_compile is False: return if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: - if self.config.compile: + if self.config.reference_compile: logger.warning_once( "If `accelerate` split the model across devices, `torch.compile` will not work. " "Falling back to non-compiled mode." 
) - self.config.compile = False + self.config.reference_compile = False if self.device.type == "mps": - if self.config.compile: + if self.config.reference_compile: logger.warning_once( "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " "Falling back to non-compiled mode." ) - self.config.compile = False + self.config.reference_compile = False - if self.config.compile is None: - self.config.compile = is_triton_available() + if self.config.reference_compile is None: + self.config.reference_compile = is_triton_available() def resize_token_embeddings(self, *args, **kwargs): model_embeds = super().resize_token_embeddings(*args, **kwargs) - if self.config.compile in {True, None}: - if self.config.compile: + if self.config.reference_compile in {True, None}: + if self.config.reference_compile: logger.warning_once( "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." ) - self.config.compile = False + self.config.reference_compile = False return model_embeds @@ -1475,7 +1478,7 @@ def forward( logits = ( self.compiled_head(last_hidden_state) - if self.config.compile + if self.config.reference_compile else self.decoder(self.head(last_hidden_state)) ) From c8a80e7838c015c0c47597bf13a4f369c17907db Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Wed, 18 Dec 2024 01:28:08 -0600 Subject: [PATCH 60/88] disable flex_attention since it crashes --- .../models/modernbert/modeling_modernbert.py | 40 ++----------------- .../models/modernbert/modular_modernbert.py | 5 ++- 2 files changed, 7 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index ed37f4aadfacff..e08d12658c8a43 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -52,7 +52,7 @@ RotaryEmbedding = object if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + from torch.nn.attention.flex_attention import BlockMask, create_block_mask logger = logging.get_logger(__name__) @@ -392,39 +392,6 @@ def flash_attention_forward( return (attn.view(bs, dim),) -def flex_attention_forward( - module: "ModernBertAttention", - qkv: torch.Tensor, - rotary_emb: ModernBertUnpaddedRotaryEmbedding, - cu_seqlens: torch.Tensor, - block_mask: "BlockMask", - local_attention: Tuple[int, int], - max_seqlen: int, - bs: int, - dim: int, - **_kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - query, key, value = qkv.unbind(dim=1) - - query = query.transpose(0, 1).unsqueeze(0) - key = key.transpose(0, 1).unsqueeze(0) - value = value.transpose(0, 1).unsqueeze(0) - - attn_output = flex_attention( - query, - key, - value, - block_mask=block_mask if local_attention != (-1, -1) else None, - enable_gqa=False, - scale=None, - return_lse=False, - ) - - attn_output = attn_output.squeeze(0).transpose(0, 1).contiguous() - return (attn_output.view(bs, dim),) - - def sdpa_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, @@ -472,7 +439,7 @@ def sdpa_attention_forward( MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, - "flex_attention": flex_attention_forward, + # "flex_attention": flex_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } @@ -793,7 +760,7 @@ class 
ModernBertPreTrainedModel(PreTrainedModel): _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + _supports_flex_attn = False def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor @@ -939,6 +906,7 @@ def _autoset_attn_implementation( # If flex_attention is requested and it fails, we throw an error. # This implementation is not used by default, so we only try to enable it if it is requested. if requested_attn_implementation == "flex_attention": + raise NotImplementedError("Flex attention is not supported for ModernBert") config = cls._check_and_enable_flex_attn(config, hard_check_only=False) config._attn_implementation_autoset = True return config diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 39d217885ea4bf..b8ad298a972e0d 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -676,7 +676,7 @@ def sdpa_attention_forward( MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, - "flex_attention": flex_attention_forward, + # "flex_attention": flex_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } @@ -925,7 +925,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + _supports_flex_attn = False def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor @@ -1071,6 +1071,7 @@ def _autoset_attn_implementation( # If flex_attention is requested and it fails, we throw an error. # This implementation is not used by default, so we only try to enable it if it is requested. if requested_attn_implementation == "flex_attention": + raise NotImplementedError("Flex attention is not supported for ModernBert") config = cls._check_and_enable_flex_attn(config, hard_check_only=False) config._attn_implementation_autoset = True return config From 8962f052d5274ab62c5241cf01905bf31f764ed3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benjamin=20Clavi=C3=A9?= Date: Wed, 18 Dec 2024 09:33:02 +0100 Subject: [PATCH 61/88] Update modernbert.md --- docs/source/en/model_doc/modernbert.md | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/modernbert.md b/docs/source/en/model_doc/modernbert.md index b1e94f62105e36..5286ec864065e5 100644 --- a/docs/source/en/model_doc/modernbert.md +++ b/docs/source/en/model_doc/modernbert.md @@ -27,15 +27,24 @@ rendered properly in your Markdown viewer. ## Overview -The ModernBert model was proposed in []() by ... +The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](#) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard and Iacopo Poli. -It builds on BERT and modifies ... +It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).
+ +It builds on BERT and implements many modern architectural improvements which have been developed since its original release, such as: +- [Rotary Positional Embeddings](https://huggingface.co/blog/designing-positional-encoding) to support sequences of up to 8192 tokens. +- [Unpadding](https://arxiv.org/abs/2208.08124) to ensure no compute is wasted on padding tokens, speeding up processing time for batches with mixed-length sequences. +- [GeGLU](https://arxiv.org/abs/2002.05202) replacing the original MLP layers with GeGLU layers, shown to improve performance. +- [Alternating Attention](https://arxiv.org/abs/2004.05150v2) where most attention layers employ a sliding window of 128 tokens, with Global Attention only used every 3 layers. +- [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up processing. +- A model designed following the recommendations of [The Case for Co-Designing Model Architectures with Hardware](https://arxiv.org/abs/2401.14489), ensuring maximum efficiency across inference GPUs. +- Modern training data scales (2 trillion tokens) and mixtures (including code and math data). The abstract from the paper is the following: -** +*Encoder-only transformer models such as BERT offer a great performance-size tradeoff for retrieval and classification tasks with respect to larger decoder-only models. Despite being the workhorse of numerous production pipelines, there have been limited Pareto improvements to BERT since its release. In this paper, we introduce ModernBERT, bringing modern model optimizations to encoder-only models and representing a major Pareto improvement over older encoders. Trained on 2 trillion tokens with a native 8192 sequence length, ModernBERT models exhibit state-of-the-art results on a large pool of evaluations encompassing diverse classification tasks and both single and multi-vector retrieval on different domains (including code). In addition to strong downstream performance, ModernBERT is also the most speed and memory efficient encoder and is designed for inference on common GPUs.* -The original code can be found [here](). +The original code can be found [here](https://github.com/answerdotai/modernbert).
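As an illustrative aside rather than part of the patch: a minimal sketch of how the masked-language-modeling path added in this PR could be exercised. The `answerdotai/ModernBERT-base` checkpoint name is taken from `_CHECKPOINT_FOR_DOC` in the modeling file; its availability on the Hub and the `AutoModelForMaskedLM` mapping are assumptions here.

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# checkpoint name taken from _CHECKPOINT_FOR_DOC in this PR; availability is an assumption
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# pick the highest-scoring token at the [MASK] position
mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = logits[0, mask_index].argmax(dim=-1)
print(tokenizer.decode(predicted_id))
```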
## Usage tips From 7e89f4dd3b7428ab55f28c12058b7ef6c2aea08d Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Wed, 18 Dec 2024 09:41:35 +0000 Subject: [PATCH 62/88] Using dtype min to mask in eager --- src/transformers/models/modernbert/modeling_modernbert.py | 2 +- src/transformers/models/modernbert/modular_modernbert.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index e08d12658c8a43..62dff9aca0f534 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -333,7 +333,7 @@ def eager_attention_forward( # Create sliding window mask (1 for positions within window, 0 outside) window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) # Combine with existing mask - expanded_mask.masked_fill_(window_mask.logical_not(), float("-inf")) + expanded_mask.masked_fill_(window_mask.logical_not(), torch.finfo(attn_weights.dtype).min) attn_weights = attn_weights + expanded_mask diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index b8ad298a972e0d..17dd6af56d35b2 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -537,7 +537,7 @@ def eager_attention_forward( # Create sliding window mask (1 for positions within window, 0 outside) window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) # Combine with existing mask - expanded_mask.masked_fill_(window_mask.logical_not(), float("-inf")) + expanded_mask.masked_fill_(window_mask.logical_not(), torch.finfo(attn_weights.dtype).min) attn_weights = attn_weights + expanded_mask From 0742a1d05bcbf2edd849529004a5e76c85b67727 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 18 Dec 2024 10:49:16 +0100 Subject: [PATCH 63/88] Fully remove flex attention for now It's only compatible with the nightly torch 2.6, so we'll leave it be for now. It's also slower than eager/sdpa. 
Also, update compile -> reference_compile in one more case --- .../models/modernbert/modeling_modernbert.py | 59 ++---------- .../models/modernbert/modular_modernbert.py | 92 ++----------------- .../modernbert/test_modeling_modernbert.py | 2 +- 3 files changed, 13 insertions(+), 140 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 62dff9aca0f534..629ec68d1b6fd5 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -37,7 +37,6 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_torch_greater_or_equal, logging, ) from ...utils.import_utils import is_triton_available @@ -51,9 +50,6 @@ else: RotaryEmbedding = object -if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import BlockMask, create_block_mask - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" @@ -439,7 +435,6 @@ def sdpa_attention_forward( MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, - # "flex_attention": flex_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } @@ -484,7 +479,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): rope_theta = config.local_rope_theta max_position_embeddings = config.local_attention - if config._attn_implementation in {"flash_attention_2", "flex_attention"}: + if config._attn_implementation == "flash_attention_2": self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) @@ -525,7 +520,7 @@ def forward( qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] - if self.config._attn_implementation in ("flash_attention_2", "flex_attention"): + if self.config._attn_implementation == "flash_attention_2": qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) else: qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) @@ -582,7 +577,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, - block_mask: Optional["BlockMask"] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: @@ -599,7 +593,6 @@ def forward( self.attn_norm(hidden_states), position_ids=position_ids, cu_seqlens=cu_seqlens, - block_mask=block_mask, max_seqlen=max_seqlen, attention_mask=attention_mask, output_attentions=output_attentions, @@ -903,13 +896,9 @@ def _autoset_attn_implementation( config._attn_implementation_autoset = True return config - # If flex_attention is requested and it fails, we throw an error. - # This implementation is not used by default, so we only try to enable it if it is requested. + # Flex attention is not supported for ModernBert if requested_attn_implementation == "flex_attention": raise NotImplementedError("Flex attention is not supported for ModernBert") - config = cls._check_and_enable_flex_attn(config, hard_check_only=False) - config._attn_implementation_autoset = True - return config # If eager is requested, or if None is requested but FA2 and SDPA fail, we set it and return the config. 
config._attn_implementation = "eager" @@ -1004,33 +993,6 @@ def offsets_to_sequence_ids_tensor(cls, offsets): counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts) - def create_attention_mask(self, sequence_ids, cu_seqlens, window_size): - """ - Creates a block mask combining sequence masking and local/or global attention masking. - """ - - def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): - # only allow attention within the same sequence - same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx] - - # get position within the sequence - q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]] - kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]] - - # sliding window within each sequence - in_window = (q_pos - kv_pos).abs() <= window_size - - return same_seq & in_window - - block_mask = create_block_mask( - sliding_window_seq_mask_mod, - B=None, - H=None, - Q_LEN=cu_seqlens[-1], - KV_LEN=cu_seqlens[-1], - ) - return block_mask - MODERNBERT_INPUTS_DOCSTRING = r""" Args: @@ -1149,7 +1111,7 @@ def forward( attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) repad = False - if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: + if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if self.config.unpad_no_grad: @@ -1163,13 +1125,6 @@ def forward( hidden_states = self.embeddings(input_ids) - # create block mask - block_mask = None - if self.config._attn_implementation == "flex_attention": - sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens) - window_size = self.config.local_attention // 2 - block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size) - for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1181,7 +1136,6 @@ def forward( attention_mask, position_ids, cu_seqlens, - block_mask, max_seqlen, output_attentions, ) @@ -1191,7 +1145,6 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, - block_mask=block_mask, max_seqlen=max_seqlen, output_attentions=output_attentions, ) @@ -1275,7 +1228,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() - if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: + if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: batch_size, seq_len = input_ids.shape[:2] if self.config.unpad_no_grad: @@ -1322,7 +1275,7 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: + if self.config._attn_implementation == "flash_attention_2": if self.config.unpad_no_grad: logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 17dd6af56d35b2..52121133639c4d 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -38,7 +38,6 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_torch_greater_or_equal, logging, ) from 
...utils.import_utils import is_triton_available @@ -52,9 +51,6 @@ else: RotaryEmbedding = object -if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention - _CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" _CONFIG_FOR_DOC = "ModernBertConfig" @@ -596,39 +592,6 @@ def flash_attention_forward( return (attn.view(bs, dim),) -def flex_attention_forward( - module: "ModernBertAttention", - qkv: torch.Tensor, - rotary_emb: ModernBertUnpaddedRotaryEmbedding, - cu_seqlens: torch.Tensor, - block_mask: "BlockMask", - local_attention: Tuple[int, int], - max_seqlen: int, - bs: int, - dim: int, - **_kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - query, key, value = qkv.unbind(dim=1) - - query = query.transpose(0, 1).unsqueeze(0) - key = key.transpose(0, 1).unsqueeze(0) - value = value.transpose(0, 1).unsqueeze(0) - - attn_output = flex_attention( - query, - key, - value, - block_mask=block_mask if local_attention != (-1, -1) else None, - enable_gqa=False, - scale=None, - return_lse=False, - ) - - attn_output = attn_output.squeeze(0).transpose(0, 1).contiguous() - return (attn_output.view(bs, dim),) - - def sdpa_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, @@ -676,7 +639,6 @@ def sdpa_attention_forward( MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, - # "flex_attention": flex_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } @@ -721,7 +683,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): rope_theta = config.local_rope_theta max_position_embeddings = config.local_attention - if config._attn_implementation in {"flash_attention_2", "flex_attention"}: + if config._attn_implementation == "flash_attention_2": self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) @@ -762,7 +724,7 @@ def forward( qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] - if self.config._attn_implementation in ("flash_attention_2", "flex_attention"): + if self.config._attn_implementation == "flash_attention_2": qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) else: qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) @@ -819,7 +781,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, - block_mask: Optional["BlockMask"] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: @@ -836,7 +797,6 @@ def forward( self.attn_norm(hidden_states), position_ids=position_ids, cu_seqlens=cu_seqlens, - block_mask=block_mask, max_seqlen=max_seqlen, attention_mask=attention_mask, output_attentions=output_attentions, @@ -1068,13 +1028,9 @@ def _autoset_attn_implementation( config._attn_implementation_autoset = True return config - # If flex_attention is requested and it fails, we throw an error. - # This implementation is not used by default, so we only try to enable it if it is requested. 
+ # Flex attention is not supported for ModernBert if requested_attn_implementation == "flex_attention": raise NotImplementedError("Flex attention is not supported for ModernBert") - config = cls._check_and_enable_flex_attn(config, hard_check_only=False) - config._attn_implementation_autoset = True - return config # If eager is requested, or if None is requested but FA2 and SDPA fail, we set it and return the config. config._attn_implementation = "eager" @@ -1169,33 +1125,6 @@ def offsets_to_sequence_ids_tensor(cls, offsets): counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts) - def create_attention_mask(self, sequence_ids, cu_seqlens, window_size): - """ - Creates a block mask combining sequence masking and local/or global attention masking. - """ - - def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): - # only allow attention within the same sequence - same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx] - - # get position within the sequence - q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]] - kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]] - - # sliding window within each sequence - in_window = (q_pos - kv_pos).abs() <= window_size - - return same_seq & in_window - - block_mask = create_block_mask( - sliding_window_seq_mask_mod, - B=None, - H=None, - Q_LEN=cu_seqlens[-1], - KV_LEN=cu_seqlens[-1], - ) - return block_mask - MODERNBERT_INPUTS_DOCSTRING = r""" Args: @@ -1314,7 +1243,7 @@ def forward( attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) repad = False - if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: + if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if self.config.unpad_no_grad: @@ -1328,13 +1257,6 @@ def forward( hidden_states = self.embeddings(input_ids) - # create block mask - block_mask = None - if self.config._attn_implementation == "flex_attention": - sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens) - window_size = self.config.local_attention // 2 - block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size) - for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1346,7 +1268,6 @@ def forward( attention_mask, position_ids, cu_seqlens, - block_mask, max_seqlen, output_attentions, ) @@ -1356,7 +1277,6 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, - block_mask=block_mask, max_seqlen=max_seqlen, output_attentions=output_attentions, ) @@ -1440,7 +1360,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() - if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: + if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: batch_size, seq_len = input_ids.shape[:2] if self.config.unpad_no_grad: @@ -1487,7 +1407,7 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.config._attn_implementation in {"flash_attention_2", "flex_attention"}: + if self.config._attn_implementation == "flash_attention_2": if self.config.unpad_no_grad: logits = self._pad_outputs_no_grad(logits, indices, batch_size, seq_len) else: diff --git 
a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index 5094ca7963cb76..809478f90d1c24 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -143,7 +143,7 @@ def get_config(self): # that compilation doesn't work. Users can then set compile=False when loading the model, # much like here. We're testing whether it works once they've done that. if test_name == "test_retain_grad_hidden_states_attentions": - config.compile = False + config.reference_compile = False # Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager # as the others don't support outputted attentions if test_name in ( From 6c6cddb8d3a6bd3450e1742c4a17319e745a8103 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 18 Dec 2024 12:02:10 +0100 Subject: [PATCH 64/88] Call contiguous to allow for .view() --- .../models/modernbert/modeling_modernbert.py | 18 +++++++++++------- .../models/modernbert/modular_modernbert.py | 18 +++++++++++------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 629ec68d1b6fd5..66df07f5121571 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -422,13 +422,17 @@ def sdpa_attention_forward( attention_mask = attention_mask.to(torch.bool) - attn_output = F.scaled_dot_product_attention( - query, - key, - value, - dropout_p=module.attention_dropout if module.training else 0.0, - attn_mask=attention_mask, - ).transpose(1, 2) + attn_output = ( + F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=module.attention_dropout if module.training else 0.0, + attn_mask=attention_mask, + ) + .transpose(1, 2) + .contiguous() + ) attn_output = attn_output.view(bs, -1, dim) return (attn_output,) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 52121133639c4d..aa08586094dabe 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -626,13 +626,17 @@ def sdpa_attention_forward( attention_mask = attention_mask.to(torch.bool) - attn_output = F.scaled_dot_product_attention( - query, - key, - value, - dropout_p=module.attention_dropout if module.training else 0.0, - attn_mask=attention_mask, - ).transpose(1, 2) + attn_output = ( + F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=module.attention_dropout if module.training else 0.0, + attn_mask=attention_mask, + ) + .transpose(1, 2) + .contiguous() + ) attn_output = attn_output.view(bs, -1, dim) return (attn_output,) From e37e4ec5d30c5851ba45adf00581986b504fed20 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Wed, 18 Dec 2024 14:52:01 +0100 Subject: [PATCH 65/88] Copyright 2020 -> 2024 Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/model_doc/modernbert.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/modernbert.md b/docs/source/en/model_doc/modernbert.md index 5286ec864065e5..a504fd85cb9c60 100644 --- a/docs/source/en/model_doc/modernbert.md +++ b/docs/source/en/model_doc/modernbert.md @@ -1,4 +1,4 @@ - + +Paper page + ## Overview -The 
ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](#) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard and Iacopo Poli. +The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard and Iacopo Poli. It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta). From 44275fd738741895362f6bd8fe5ade45f5c6f98d Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 19 Dec 2024 10:17:38 +0100 Subject: [PATCH 85/88] Remove unpad_no_grad, always pad/unpad without gradients --- .../modernbert/configuration_modernbert.py | 4 ---- .../models/modernbert/modeling_modernbert.py | 19 +++------------ .../models/modernbert/modular_modernbert.py | 23 +++---------------- 3 files changed, 6 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index a458850bed8845..93a7a67885ff9b 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -87,8 +87,6 @@ class ModernBertConfig(PretrainedConfig): Whether to use bias in the MLP layers. mlp_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the MLP layers. - unpad_no_grad (`bool`, *optional*, defaults to `True`): - Whether to use `no_grad` when unpadding the inputs. decoder_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the decoder layers.
classifier_pooling (`str`, *optional*, defaults to `"cls"`): @@ -157,7 +155,6 @@ def __init__( embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, - unpad_no_grad=True, decoder_bias=True, classifier_pooling: Literal["cls", "mean"] = "cls", classifier_dropout=0.0, @@ -197,7 +194,6 @@ def __init__( self.embedding_dropout = embedding_dropout self.mlp_bias = mlp_bias self.mlp_dropout = mlp_dropout - self.unpad_no_grad = unpad_no_grad self.decoder_bias = decoder_bias self.classifier_pooling = classifier_pooling self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index b106e7b169083d..77de321b5bcc6d 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -910,12 +910,7 @@ def forward( if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: repad = True - if self.config.unpad_no_grad: - with torch.no_grad(): - input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( - inputs=input_ids, attention_mask=attention_mask - ) - else: + with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask ) @@ -1088,12 +1083,7 @@ def forward( batch_size, seq_len = input_ids.shape[:2] if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) - if self.config.unpad_no_grad: - with torch.no_grad(): - input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( - inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels - ) - else: + with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) @@ -1135,10 +1125,7 @@ def forward( loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) if self.config._attn_implementation == "flash_attention_2": - if self.config.unpad_no_grad: - with torch.no_grad(): - logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) - else: + with torch.no_grad(): logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) if not return_dict: output = (logits,) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 77f97c2e101423..c336421fd394de 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -119,8 +119,6 @@ class ModernBertConfig(PretrainedConfig): Whether to use bias in the MLP layers. mlp_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the MLP layers. - unpad_no_grad (`bool`, *optional*, defaults to `True`): - Whether to use `no_grad` when unpadding the inputs. decoder_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the decoder layers. 
classifier_pooling (`str`, *optional*, defaults to `"cls"`): @@ -189,7 +187,6 @@ def __init__( embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, - unpad_no_grad=True, decoder_bias=True, classifier_pooling: Literal["cls", "mean"] = "cls", classifier_dropout=0.0, @@ -229,7 +226,6 @@ def __init__( self.embedding_dropout = embedding_dropout self.mlp_bias = mlp_bias self.mlp_dropout = mlp_dropout - self.unpad_no_grad = unpad_no_grad self.decoder_bias = decoder_bias self.classifier_pooling = classifier_pooling self.classifier_dropout = classifier_dropout @@ -1038,12 +1034,7 @@ def forward( if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: repad = True - if self.config.unpad_no_grad: - with torch.no_grad(): - input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( - inputs=input_ids, attention_mask=attention_mask - ) - else: + with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask ) @@ -1216,12 +1207,7 @@ def forward( batch_size, seq_len = input_ids.shape[:2] if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) - if self.config.unpad_no_grad: - with torch.no_grad(): - input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( - inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels - ) - else: + with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) @@ -1263,10 +1249,7 @@ def forward( loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) if self.config._attn_implementation == "flash_attention_2": - if self.config.unpad_no_grad: - with torch.no_grad(): - logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) - else: + with torch.no_grad(): logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) if not return_dict: output = (logits,) From d799d65bd9ebeab85e1d6a7f453841c5aa21b272 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 19 Dec 2024 10:28:12 +0100 Subject: [PATCH 86/88] local_attention_mask -> sliding_window_mask --- .../models/modernbert/modeling_modernbert.py | 68 ++++++------------- .../models/modernbert/modular_modernbert.py | 68 ++++++------------- 2 files changed, 38 insertions(+), 98 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 77de321b5bcc6d..cf32a411fd4dfc 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -300,7 +300,7 @@ def eager_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, attention_mask: torch.Tensor, - local_attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], local_attention: Tuple[int, int], bs: int, @@ -318,7 +318,7 @@ def eager_attention_forward( attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale if local_attention != (-1, -1): - attention_mask = local_attention_mask + attention_mask = sliding_window_mask attn_weights = attn_weights + attention_mask @@ -380,7 +380,7 @@ def sdpa_attention_forward( module: 
"ModernBertAttention", qkv: torch.Tensor, attention_mask: torch.Tensor, - local_attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], local_attention: Tuple[int, int], bs: int, @@ -394,7 +394,7 @@ def sdpa_attention_forward( query, key = apply_rotary_pos_emb(query, key, cos, sin) if local_attention != (-1, -1): - attention_mask = local_attention_mask + attention_mask = sliding_window_mask attn_output = ( F.scaled_dot_product_attention( @@ -476,25 +476,6 @@ def forward( output_attentions: Optional[bool] = False, **kwargs, ) -> torch.Tensor: - """Perform self-attention. - - There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2. - - The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the - Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute - attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not - sending pad tokens through ffs saves compute. - - Args: - hidden_states: (total_nnz, dim) - cu_seqlens: (batch + 1,) - max_seqlen: int - indices: (total_nnz,) - attn_mask: (batch, max_seqlen) - - Returns: - attention: (total_nnz, dim) - """ qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] @@ -539,25 +520,16 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - local_attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: - """Forward pass for a ModernBert layer, including both attention and MLP. - - Args: - hidden_states: (total_nnz, dim) - attention_mask: (batch, max_seqlen) - position_ids: (total_nnz,) - cu_seqlens: (batch + 1,) - max_seqlen: int - """ attn_outputs = self.attn( self.attn_norm(hidden_states), attention_mask=attention_mask, - local_attention_mask=local_attention_mask, + sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -815,7 +787,7 @@ def _pad_modernbert_output( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - local_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers perform global attention, while the rest perform local attention. This mask is used to avoid attending to far-away tokens in the local attention layers. 
@@ -877,7 +849,7 @@ def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, - local_attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, @@ -918,7 +890,7 @@ def forward( if position_ids is None: position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) - attention_mask, local_attention_mask = self._update_attention_mask( + attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions ) @@ -933,7 +905,7 @@ def forward( encoder_layer.__call__, hidden_states, attention_mask, - local_attention_mask, + sliding_window_mask, position_ids, cu_seqlens, max_seqlen, @@ -943,7 +915,7 @@ def forward( layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, - local_attention_mask=local_attention_mask, + sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -1003,11 +975,9 @@ def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) ) # Combine with existing mask - local_attention_mask = global_attention_mask.masked_fill( - window_mask.logical_not(), torch.finfo(self.dtype).min - ) + sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) - return global_attention_mask, local_attention_mask + return global_attention_mask, sliding_window_mask class ModernBertPredictionHead(nn.Module): @@ -1062,7 +1032,7 @@ def forward( self, input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - local_attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, @@ -1091,7 +1061,7 @@ def forward( outputs = self.model( input_ids, attention_mask=attention_mask, - local_attention_mask=local_attention_mask, + sliding_window_mask=sliding_window_mask, position_ids=position_ids, indices=indices, cu_seqlens=cu_seqlens, @@ -1199,7 +1169,7 @@ def forward( self, input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - local_attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, @@ -1224,7 +1194,7 @@ def forward( outputs = self.model( input_ids, attention_mask=attention_mask, - local_attention_mask=local_attention_mask, + sliding_window_mask=sliding_window_mask, position_ids=position_ids, indices=indices, cu_seqlens=cu_seqlens, @@ -1301,7 +1271,7 @@ def forward( self, input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - local_attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, @@ -1323,7 +1293,7 @@ def forward( outputs = self.model( input_ids, attention_mask=attention_mask, - local_attention_mask=local_attention_mask, + sliding_window_mask=sliding_window_mask, position_ids=position_ids, 
indices=indices, cu_seqlens=cu_seqlens, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index c336421fd394de..c8f05188acb38b 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -496,7 +496,7 @@ def eager_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, attention_mask: torch.Tensor, - local_attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], local_attention: Tuple[int, int], bs: int, @@ -514,7 +514,7 @@ def eager_attention_forward( attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale if local_attention != (-1, -1): - attention_mask = local_attention_mask + attention_mask = sliding_window_mask attn_weights = attn_weights + attention_mask @@ -576,7 +576,7 @@ def sdpa_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, attention_mask: torch.Tensor, - local_attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], local_attention: Tuple[int, int], bs: int, @@ -590,7 +590,7 @@ def sdpa_attention_forward( query, key = apply_rotary_pos_emb(query, key, cos, sin) if local_attention != (-1, -1): - attention_mask = local_attention_mask + attention_mask = sliding_window_mask attn_output = ( F.scaled_dot_product_attention( @@ -672,25 +672,6 @@ def forward( output_attentions: Optional[bool] = False, **kwargs, ) -> torch.Tensor: - """Perform self-attention. - - There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2. - - The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the - Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute - attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not - sending pad tokens through ffs saves compute. - - Args: - hidden_states: (total_nnz, dim) - cu_seqlens: (batch + 1,) - max_seqlen: int - indices: (total_nnz,) - attn_mask: (batch, max_seqlen) - - Returns: - attention: (total_nnz, dim) - """ qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] @@ -735,25 +716,16 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - local_attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: - """Forward pass for a ModernBert layer, including both attention and MLP. - - Args: - hidden_states: (total_nnz, dim) - attention_mask: (batch, max_seqlen) - position_ids: (total_nnz,) - cu_seqlens: (batch + 1,) - max_seqlen: int - """ attn_outputs = self.attn( self.attn_norm(hidden_states), attention_mask=attention_mask, - local_attention_mask=local_attention_mask, + sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -939,7 +911,7 @@ def resize_token_embeddings(self, *args, **kwargs): - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. 
- local_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers perform global attention, while the rest perform local attention. This mask is used to avoid attending to far-away tokens in the local attention layers. @@ -1001,7 +973,7 @@ def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, - local_attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, @@ -1042,7 +1014,7 @@ def forward( if position_ids is None: position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) - attention_mask, local_attention_mask = self._update_attention_mask( + attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions ) @@ -1057,7 +1029,7 @@ def forward( encoder_layer.__call__, hidden_states, attention_mask, - local_attention_mask, + sliding_window_mask, position_ids, cu_seqlens, max_seqlen, @@ -1067,7 +1039,7 @@ def forward( layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, - local_attention_mask=local_attention_mask, + sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -1127,11 +1099,9 @@ def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) ) # Combine with existing mask - local_attention_mask = global_attention_mask.masked_fill( - window_mask.logical_not(), torch.finfo(self.dtype).min - ) + sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) - return global_attention_mask, local_attention_mask + return global_attention_mask, sliding_window_mask class ModernBertPredictionHead(nn.Module): @@ -1186,7 +1156,7 @@ def forward( self, input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - local_attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, @@ -1215,7 +1185,7 @@ def forward( outputs = self.model( input_ids, attention_mask=attention_mask, - local_attention_mask=local_attention_mask, + sliding_window_mask=sliding_window_mask, position_ids=position_ids, indices=indices, cu_seqlens=cu_seqlens, @@ -1323,7 +1293,7 @@ def forward( self, input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - local_attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, @@ -1348,7 +1318,7 @@ def forward( outputs = self.model( input_ids, attention_mask=attention_mask, - local_attention_mask=local_attention_mask, + sliding_window_mask=sliding_window_mask, position_ids=position_ids, indices=indices, cu_seqlens=cu_seqlens, @@ -1425,7 +1395,7 @@ def forward( self, input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - 
@@ -1425,7 +1395,7 @@ def forward(
         self,
         input_ids: Optional[torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
-        local_attention_mask: Optional[torch.Tensor] = None,
+        sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
@@ -1447,7 +1417,7 @@ def forward(
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
-            local_attention_mask=local_attention_mask,
+            sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
             indices=indices,
             cu_seqlens=cu_seqlens,

From ed77867db7384c018beb4b352d5d75f47b8d2b72 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Thu, 19 Dec 2024 10:37:39 +0100
Subject: [PATCH 87/88] Revert "Use the pooling head in TokenClassification"

This reverts commit 99c38badd1dbce01d7aef41095fbf2f5cce87279.

There was no clear motivation for the larger head, and no evidence that it
improves token classification results.
---
 src/transformers/models/modernbert/modeling_modernbert.py | 4 ++--
 src/transformers/models/modernbert/modular_modernbert.py  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index cf32a411fd4dfc..a077362f158f9a 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -1255,7 +1255,7 @@ def __init__(self, config: ModernBertConfig):
         self.num_labels = config.num_labels

         self.model = ModernBertModel(config)
-        self.head = ModernBertPoolingHead(config)
+        self.drop = nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1306,7 +1306,7 @@ def forward(
         )
         last_hidden_state = outputs[0]

-        last_hidden_state = self.head(last_hidden_state, attention_mask, pool=False)
+        last_hidden_state = self.drop(last_hidden_state)
         logits = self.classifier(last_hidden_state)

         loss = None
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index c8f05188acb38b..a4b2888b8e3018 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -1379,7 +1379,7 @@ def __init__(self, config: ModernBertConfig):
         self.num_labels = config.num_labels

         self.model = ModernBertModel(config)
-        self.head = ModernBertPoolingHead(config)
+        self.drop = nn.Dropout(config.classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
@@ -1430,7 +1430,7 @@ def forward(
         )
         last_hidden_state = outputs[0]

-        last_hidden_state = self.head(last_hidden_state, attention_mask, pool=False)
+        last_hidden_state = self.drop(last_hidden_state)
         logits = self.classifier(last_hidden_state)

         loss = None
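
After this revert the token-classification head is just dropout followed by a per-token linear layer. A minimal sketch of the resulting computation, with toy sizes and standalone modules rather than the actual model classes:

    import torch
    import torch.nn as nn

    hidden_size, num_labels, classifier_dropout = 768, 9, 0.1  # toy values

    drop = nn.Dropout(classifier_dropout)
    classifier = nn.Linear(hidden_size, num_labels)

    last_hidden_state = torch.randn(2, 8, hidden_size)   # (batch, seq_len, hidden)
    logits = classifier(drop(last_hidden_state))          # (batch, seq_len, num_labels)
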
From 92e17c667aae8e9473237e1ec9a843f16eb58e70 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Thu, 19 Dec 2024 10:40:23 +0100
Subject: [PATCH 88/88] Simplify pooling, 2 options via if-else

---
 .../modernbert/configuration_modernbert.py   |  5 +++
 .../models/modernbert/modeling_modernbert.py | 27 ++++------------
 .../models/modernbert/modular_modernbert.py  | 32 +++++++------------
 3 files changed, 24 insertions(+), 40 deletions(-)

diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py
index 93a7a67885ff9b..13e9edf067efc4 100644
--- a/src/transformers/models/modernbert/configuration_modernbert.py
+++ b/src/transformers/models/modernbert/configuration_modernbert.py
@@ -204,5 +204,10 @@ def __init__(
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile

+        if self.classifier_pooling not in ["cls", "mean"]:
+            raise ValueError(
+                f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
+            )
+

 __all__ = ["ModernBertConfig"]
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index a077362f158f9a..db8d98893f96fe 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -1109,20 +1109,6 @@ def forward(
         )


-def cls_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-    return hidden_states[:, 0]
-
-
-def mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-    return (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
-
-
-MODERNBERT_POOLING_FUNCTION = {
-    "cls": cls_pooling,
-    "mean": mean_pooling,
-}
-
-
 class ModernBertPoolingHead(nn.Module):
     def __init__(self, config: ModernBertConfig):
         super().__init__()
@@ -1131,13 +1117,14 @@ def __init__(self, config: ModernBertConfig):
         self.act = ACT2FN[config.classifier_activation]
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.drop = torch.nn.Dropout(config.classifier_dropout)
-        self.pooling = MODERNBERT_POOLING_FUNCTION[config.classifier_pooling]

-    def forward(
-        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool: Optional[bool] = True
-    ) -> torch.Tensor:
-        if pool:
-            hidden_states = self.pooling(hidden_states, attention_mask)
+    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        if self.config.classifier_pooling == "cls":
+            hidden_states = hidden_states[:, 0]
+        elif self.config.classifier_pooling == "mean":
+            hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+                dim=1, keepdim=True
+            )

         return self.drop(self.norm(self.act(self.dense(hidden_states))))

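The if/else above replaces the pooling-function lookup: "cls" keeps the first token's hidden state, while "mean" averages the non-padding tokens. In isolation the two branches behave roughly like this (toy tensors only, without the head's dense/norm/dropout stack):

    import torch

    last_hidden_state = torch.randn(2, 8, 768)   # (batch, seq_len, hidden)
    attention_mask = torch.ones(2, 8)             # 1 = real token, 0 = padding
    attention_mask[1, 5:] = 0                     # pretend the second sequence is padded

    cls_pooled = last_hidden_state[:, 0]          # (batch, hidden): first token only
    mean_pooled = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
        dim=1, keepdim=True
    )                                             # (batch, hidden): padding excluded from the average
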
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index a4b2888b8e3018..3c23f9178b1b51 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -236,6 +236,11 @@ def __init__(
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile

+        if self.classifier_pooling not in ["cls", "mean"]:
+            raise ValueError(
+                f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
+            )
+

 def _unpad_modernbert_input(
     inputs: torch.Tensor,
@@ -1233,20 +1238,6 @@ def forward(
         )


-def cls_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-    return hidden_states[:, 0]
-
-
-def mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-    return (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
-
-
-MODERNBERT_POOLING_FUNCTION = {
-    "cls": cls_pooling,
-    "mean": mean_pooling,
-}
-
-
 class ModernBertPoolingHead(nn.Module):
     def __init__(self, config: ModernBertConfig):
         super().__init__()
@@ -1255,13 +1246,14 @@ def __init__(self, config: ModernBertConfig):
         self.act = ACT2FN[config.classifier_activation]
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.drop = torch.nn.Dropout(config.classifier_dropout)
-        self.pooling = MODERNBERT_POOLING_FUNCTION[config.classifier_pooling]

-    def forward(
-        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool: Optional[bool] = True
-    ) -> torch.Tensor:
-        if pool:
-            hidden_states = self.pooling(hidden_states, attention_mask)
+    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        if self.config.classifier_pooling == "cls":
+            hidden_states = hidden_states[:, 0]
+        elif self.config.classifier_pooling == "mean":
+            hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+                dim=1, keepdim=True
+            )

         return self.drop(self.norm(self.act(self.dense(hidden_states))))
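
With the validation added in this patch, an unsupported classifier_pooling value fails when the config is built rather than deep inside the pooling head. A hedged usage sketch, assuming a transformers build that includes this patch and that `classifier_pooling` is accepted as a ModernBertConfig keyword argument (as the config diff suggests):

    from transformers import ModernBertConfig

    config = ModernBertConfig(classifier_pooling="mean")   # "cls" and "mean" are the accepted values

    try:
        ModernBertConfig(classifier_pooling="max")          # anything else now raises immediately
    except ValueError as err:
        print(err)
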