diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 435b482df599cf..d22ec0bbc5ccb8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -496,6 +496,8 @@ title: mLUKE - local: model_doc/mobilebert title: MobileBERT + - local: model_doc/modernbert + title: ModernBert - local: model_doc/mpnet title: MPNet - local: model_doc/mpt diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 3bd1c286d43240..fa9d597a43d7f0 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -231,6 +231,7 @@ Flax), PyTorch, and/or TensorFlow. | [MobileNetV2](model_doc/mobilenet_v2) | ✅ | ❌ | ❌ | | [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ | | [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ | +| [ModernBERT](model_doc/modernbert) | ✅ | ❌ | ❌ | | [Moshi](model_doc/moshi) | ✅ | ❌ | ❌ | | [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ | | [MPT](model_doc/mpt) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/modernbert.md b/docs/source/en/model_doc/modernbert.md new file mode 100644 index 00000000000000..ab09f38ff12154 --- /dev/null +++ b/docs/source/en/model_doc/modernbert.md @@ -0,0 +1,91 @@ + + +# ModernBert + +
+ +Models + + +Paper page + +
+ +## Overview + +The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli. + +It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta). + +It builds on BERT and implements many modern architectural improvements which have been developed since its original release, such as: +- [Rotary Positional Embeddings](https://huggingface.co/blog/designing-positional-encoding) to support sequences of up to 8192 tokens. +- [Unpadding](https://arxiv.org/abs/2208.08124) to ensure no compute is wasted on padding tokens, speeding up processing time for batches with mixed-length sequences. +- [GeGLU](https://arxiv.org/abs/2002.05202) Replacing the original MLP layers with GeGLU layers, shown to improve performance. +- [Alternating Attention](https://arxiv.org/abs/2004.05150v2) where most attention layers employ a sliding window of 128 tokens, with Global Attention only used every 3 layers. +- [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up processing. +- A model designed following recent [The Case for Co-Designing Model Architectures with Hardware](https://arxiv.org/abs/2401.14489), ensuring maximum efficiency across inference GPUs. +- Modern training data scales (2 trillion tokens) and mixtures (including code ande math data) + +The abstract from the paper is the following: + +*Encoder-only transformer models such as BERT offer a great performance-size tradeoff for retrieval and classification tasks with respect to larger decoder-only models. Despite being the workhorse of numerous production pipelines, there have been limited Pareto improvements to BERT since its release. In this paper, we introduce ModernBERT, bringing modern model optimizations to encoder-only models and representing a major Pareto improvement over older encoders. Trained on 2 trillion tokens with a native 8192 sequence length, ModernBERT models exhibit state-of-the-art results on a large pool of evaluations encompassing diverse classification tasks and both single and multi-vector retrieval on different domains (including code). In addition to strong downstream performance, ModernBERT is also the most speed and memory efficient encoder and is designed for inference on common GPUs.* + +The original code can be found [here](https://github.com/answerdotai/modernbert). + +## Resources + +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ModernBert. + + + +- A script on how to [finetune for text similarity or information retrieval with Sentence Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_st.py). 🌎 +- A script on how to [finetune for information retrieval with PyLate](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_pylate.py). 🌎 + + + +- [Masked language modeling task guide](../tasks/masked_language_modeling) + + +## ModernBertConfig + +[[autodoc]] ModernBertConfig + + + + +## ModernBertModel + +[[autodoc]] ModernBertModel + - forward + +## ModernBertForMaskedLM + +[[autodoc]] ModernBertForMaskedLM + - forward + +## ModernBertForSequenceClassification + +[[autodoc]] ModernBertForSequenceClassification + - forward + +## ModernBertForTokenClassification + +[[autodoc]] ModernBertForTokenClassification + - forward + + + diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index cbb498070d69e5..d849c571cc6ad0 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -73,6 +73,7 @@ FlashAttention-2 is currently supported for the following architectures: * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert) * [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) @@ -263,6 +264,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert) * [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 920dc334dbb2a4..4ad57faf1cb98d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -605,6 +605,7 @@ "models.mobilenet_v2": ["MobileNetV2Config"], "models.mobilevit": ["MobileViTConfig"], "models.mobilevitv2": ["MobileViTV2Config"], + "models.modernbert": ["ModernBertConfig"], "models.moshi": [ "MoshiConfig", "MoshiDepthConfig", @@ -2861,6 +2862,15 @@ "MobileViTV2PreTrainedModel", ] ) + _import_structure["models.modernbert"].extend( + [ + "ModernBertForMaskedLM", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", + "ModernBertModel", + "ModernBertPreTrainedModel", + ] + ) _import_structure["models.moshi"].extend( [ "MoshiForCausalLM", @@ -5556,6 +5566,7 @@ from .models.mobilevitv2 import ( MobileViTV2Config, ) + from .models.modernbert import ModernBertConfig from .models.moshi import ( MoshiConfig, MoshiDepthConfig, @@ -7546,6 +7557,13 @@ MobileViTV2Model, MobileViTV2PreTrainedModel, ) + from .models.modernbert import ( + ModernBertForMaskedLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ModernBertModel, + ModernBertPreTrainedModel, + ) from .models.moshi import ( MoshiForCausalLM, MoshiForConditionalGeneration, diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index efa23d24e360b4..7f6aaaa44264ca 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -47,6 +47,22 @@ def ForCausalLMLoss( return loss +def ForMaskedLMLoss( + logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs +): + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + + # Flatten the tokens + logits = logits.view(-1, vocab_size) + labels = labels.view(-1) + # Enable model parallelism + + labels = labels.to(logits.device) + loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs) + return loss + + def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs): num_labels = config.num_labels if config.problem_type is None: @@ -101,6 +117,7 @@ def ForTokenClassification(logits, labels, config, **kwargs): LOSS_MAPPING = { "ForCausalLM": ForCausalLMLoss, + "ForMaskedLM": ForMaskedLMLoss, "ForQuestionAnswering": ForQuestionAnsweringLoss, "ForSequenceClassification": ForSequenceClassificationLoss, "ForTokenClassification": ForTokenClassification, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5eb74fab5abe71..b37dacb37587ba 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -166,6 +166,7 @@ mobilenet_v2, mobilevit, mobilevitv2, + modernbert, moshi, mpnet, mpt, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d7d8281c2e3f03..9dc6247511c9fe 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -186,6 +186,7 @@ ("mobilenet_v2", "MobileNetV2Config"), ("mobilevit", "MobileViTConfig"), ("mobilevitv2", "MobileViTV2Config"), + ("modernbert", "ModernBertConfig"), ("moshi", "MoshiConfig"), ("mpnet", "MPNetConfig"), ("mpt", "MptConfig"), @@ -508,6 +509,7 @@ ("mobilenet_v2", "MobileNetV2"), ("mobilevit", "MobileViT"), ("mobilevitv2", "MobileViTV2"), + ("modernbert", "ModernBERT"), ("moshi", "Moshi"), ("mpnet", "MPNet"), ("mpt", "MPT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 5d41ad42beea7e..89ae437d908f34 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -175,6 +175,7 @@ ("mobilenet_v2", "MobileNetV2Model"), ("mobilevit", "MobileViTModel"), ("mobilevitv2", "MobileViTV2Model"), + ("modernbert", "ModernBertModel"), ("moshi", "MoshiModel"), ("mpnet", "MPNetModel"), ("mpt", "MptModel"), @@ -836,6 +837,7 @@ ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForMaskedLM"), ("mobilebert", "MobileBertForMaskedLM"), + ("modernbert", "ModernBertForMaskedLM"), ("mpnet", "MPNetForMaskedLM"), ("mra", "MraForMaskedLM"), ("mvp", "MvpForConditionalGeneration"), @@ -990,6 +992,7 @@ ("mistral", "MistralForSequenceClassification"), ("mixtral", "MixtralForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"), + ("modernbert", "ModernBertForSequenceClassification"), ("mpnet", "MPNetForSequenceClassification"), ("mpt", "MptForSequenceClassification"), ("mra", "MraForSequenceClassification"), @@ -1176,6 +1179,7 @@ ("mistral", "MistralForTokenClassification"), ("mixtral", "MixtralForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), + ("modernbert", "ModernBertForTokenClassification"), ("mpnet", "MPNetForTokenClassification"), ("mpt", "MptForTokenClassification"), ("mra", "MraForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 1cdebde8cd904f..350c230f142c15 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -313,6 +313,7 @@ ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), + ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/modernbert/__init__.py b/src/transformers/models/modernbert/__init__.py new file mode 100644 index 00000000000000..18317742981909 --- /dev/null +++ b/src/transformers/models/modernbert/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_modernbert import * + from .modeling_modernbert import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py new file mode 100644 index 00000000000000..13e9edf067efc4 --- /dev/null +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -0,0 +1,213 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_modernbert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +from ...configuration_utils import PretrainedConfig + + +class ModernBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ModernBERT-base. + e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50368): + Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ModernBertModel`] + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 1152): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 22): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu"` + if not specified. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_cutoff_factor (`float`, *optional*, defaults to 2.0): + The cutoff factor for the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + norm_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the normalization layers. + pad_token_id (`int`, *optional*, defaults to 50283): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 50282): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 50281): + Beginning of stream token id. + cls_token_id (`int`, *optional*, defaults to 50281): + Classification token id. + sep_token_id (`int`, *optional*, defaults to 50282): + Separation token id. + global_rope_theta (`float`, *optional*, defaults to 160000.0): + The base period of the global RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + global_attn_every_n_layers (`int`, *optional*, defaults to 3): + The number of layers between global attention layers. + local_attention (`int`, *optional*, defaults to 128): + The window size for local attention. + local_rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the local RoPE embeddings. + embedding_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layers. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + decoder_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the decoder layers. + classifier_pooling (`str`, *optional*, defaults to `"cls"`): + The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the + CLS token doesn't attend to all tokens on long sequences. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the classifier. + classifier_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the classifier. + classifier_activation (`str`, *optional*, defaults to `"gelu"`): + The activation function for the classifier. + deterministic_flash_attn (`bool`, *optional*, defaults to `False`): + Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic. + sparse_prediction (`bool`, *optional*, defaults to `False`): + Whether to use sparse prediction for the masked language model instead of returning the full dense logits. + sparse_pred_ignore_index (`int`, *optional*, defaults to -100): + The index to ignore for the sparse prediction. + reference_compile (`bool`, *optional*): + Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of + the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not + shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may + be faster in some scenarios. + + Examples: + + ```python + >>> from transformers import ModernBertModel, ModernBertConfig + + >>> # Initializing a ModernBert style configuration + >>> configuration = ModernBertConfig() + + >>> # Initializing a model from the modernbert-base style configuration + >>> model = ModernBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "modernbert" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50368, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=22, + num_attention_heads=12, + hidden_activation="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + initializer_cutoff_factor=2.0, + norm_eps=1e-5, + norm_bias=False, + pad_token_id=50283, + eos_token_id=50282, + bos_token_id=50281, + cls_token_id=50281, + sep_token_id=50282, + global_rope_theta=160000.0, + attention_bias=False, + attention_dropout=0.0, + global_attn_every_n_layers=3, + local_attention=128, + local_rope_theta=10000.0, + embedding_dropout=0.0, + mlp_bias=False, + mlp_dropout=0.0, + decoder_bias=True, + classifier_pooling: Literal["cls", "mean"] = "cls", + classifier_dropout=0.0, + classifier_bias=False, + classifier_activation="gelu", + deterministic_flash_attn=False, + sparse_prediction=False, + sparse_pred_ignore_index=-100, + reference_compile=None, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.initializer_cutoff_factor = initializer_cutoff_factor + self.norm_eps = norm_eps + self.norm_bias = norm_bias + self.global_rope_theta = global_rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.global_attn_every_n_layers = global_attn_every_n_layers + self.local_attention = local_attention + self.local_rope_theta = local_rope_theta + self.embedding_dropout = embedding_dropout + self.mlp_bias = mlp_bias + self.mlp_dropout = mlp_dropout + self.decoder_bias = decoder_bias + self.classifier_pooling = classifier_pooling + self.classifier_dropout = classifier_dropout + self.classifier_bias = classifier_bias + self.classifier_activation = classifier_activation + self.deterministic_flash_attn = deterministic_flash_attn + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.reference_compile = reference_compile + + if self.classifier_pooling not in ["cls", "mean"]: + raise ValueError( + f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' + ) + + +__all__ = ["ModernBertConfig"] diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py new file mode 100644 index 00000000000000..db8d98893f96fe --- /dev/null +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -0,0 +1,1322 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_modernbert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, +) +from ...utils.import_utils import is_triton_available +from .configuration_modernbert import ModernBertConfig + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary +else: + RotaryEmbedding = object + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" +_CONFIG_FOR_DOC = "ModernBertConfig" + + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, 3, nheads, headdim) + qkv = qkv.contiguous() + total_nnz, _three, _nheads, headdim = qkv.shape + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return qkv + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + do = do.contiguous() + total_nnz, _three, _nheads, headdim = do.shape + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (total_nnz, dim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """ + The rotary position embeddings applied directly to unpadded sequences. + """ + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache wll be recomputed during the forward pass. + """ + super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Apply rotary embedding *inplace* to qkv. + qkv: (total_nnz, 3, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + + qkv = apply_rotary_unpadded( + qkv, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return qkv + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + + +class ModernBertEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = nn.Dropout(config.embedding_dropout) + + @torch.compile(dynamic=True) + def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.drop(self.norm(self.tok_embeddings(input_ids))) + + def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.reference_compile + else self.drop(self.norm(self.tok_embeddings(input_ids))) + ) + return hidden_states + + +class ModernBertMLP(nn.Module): + """Applies the GLU at the end of each ModernBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) + self.act = ACT2FN[config.hidden_activation] + self.drop = nn.Dropout(config.mlp_dropout) + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class ModernBertRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def eager_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + output_attentions: Optional[bool] = False, + **_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + scale = module.head_dim**-0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bs, -1, dim) + if output_attentions: + return (attn_output, attn_weights) + return (attn_output,) + + +def flash_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + max_seqlen: int, + local_attention: Tuple[int, int], + bs: int, + dim: int, + target_dtype: torch.dtype = torch.bfloat16, + **_kwargs, +) -> Tuple[torch.Tensor]: + # (total_seqlen, 3, nheads, headdim) + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported) + orig_dtype = qkv.dtype + qkv = qkv.to(target_dtype) + + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + return (attn.view(bs, dim),) + + +def sdpa_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_output = ( + F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=module.attention_dropout if module.training else 0.0, + attn_mask=attention_mask, + ) + .transpose(1, 2) + .contiguous() + ) + attn_output = attn_output.view(bs, -1, dim) + return (attn_output,) + + +MODERNBERT_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class ModernBertAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences. + + If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. + If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, + which requires padding and unpadding inputs, adding some overhead. + + See `forward` method for additional details. + """ + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.deterministic_flash_attn = config.deterministic_flash_attn + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + else: + self.local_attention = (-1, -1) + + rope_theta = config.global_rope_theta + max_position_embeddings = config.max_position_embeddings + if self.local_attention != (-1, -1): + if config.local_rope_theta is not None: + rope_theta = config.local_rope_theta + max_position_embeddings = config.local_attention + + if config._attn_implementation == "flash_attention_2": + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta + ) + else: + self.rotary_emb = ModernBertRotaryEmbedding( + dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta + ) + + self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> torch.Tensor: + qkv = self.Wqkv(hidden_states) + + bs = hidden_states.shape[0] + if self.config._attn_implementation == "flash_attention_2": + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + else: + qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) + + attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + self, + qkv=qkv, + rotary_emb=self.rotary_emb, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = attn_outputs[0] + hidden_states = self.out_drop(self.Wo(hidden_states)) + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +class ModernBertEncoderLayer(nn.Module): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + if layer_id == 0: + self.attn_norm = nn.Identity() + else: + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.mlp = ModernBertMLP(config) + + @torch.compile(dynamic=True) + def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(self.mlp_norm(hidden_states)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs = self.attn( + self.attn_norm(hidden_states), + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + attn_outputs[0] + mlp_output = ( + self.compiled_mlp(hidden_states) + if self.config.reference_compile + else self.mlp(self.mlp_norm(hidden_states)) + ) + hidden_states = hidden_states + mlp_output + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +MODERNBERT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ModernBertConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertPreTrainedModel(PreTrainedModel): + config_class = ModernBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = False + + def _init_weights(self, module: nn.Module): + cutoff_factor = self.config.initializer_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + def init_weight(module: nn.Module, std: float): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + stds = { + "in": self.config.initializer_range, + "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), + "embedding": self.config.initializer_range, + "final_out": self.config.hidden_size**-0.5, + } + + if isinstance(module, ModernBertEmbeddings): + init_weight(module.tok_embeddings, stds["embedding"]) + elif isinstance(module, ModernBertMLP): + init_weight(module.Wi, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertAttention): + init_weight(module.Wqkv, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertPredictionHead): + init_weight(module.dense, stds["in"]) + elif isinstance(module, ModernBertPoolingHead): + init_weight(module.dense, stds["out"]) + elif isinstance(module, ModernBertForMaskedLM): + init_weight(module.decoder, stds["out"]) + elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): + init_weight(module.classifier, stds["final_out"]) + + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): + # If the user didn't specify anything, try to use flash_attention_2 if available. + # Otherwise we fall back to the default SDPA -> Eager from the super() method. + if config._attn_implementation_internal is None: + config._attn_implementation_internal = "flash_attention_2" + try: + return cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + except (ValueError, ImportError): + config._attn_implementation_internal = None + return super()._autoset_attn_implementation( + config, + use_flash_attention_2=use_flash_attention_2, + torch_dtype=torch_dtype, + device_map=device_map, + check_device_map=check_device_map, + ) + + def _maybe_set_compile(self): + if self.config.reference_compile is False: + return + + if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: + if self.config.reference_compile: + logger.warning_once( + "If `accelerate` split the model across devices, `torch.compile` will not work. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.device.type == "mps": + if self.config.reference_compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.config.reference_compile is None: + self.config.reference_compile = is_triton_available() + + def resize_token_embeddings(self, *args, **kwargs): + model_embeds = super().resize_token_embeddings(*args, **kwargs) + + if self.config.reference_compile in {True, None}: + if self.config.reference_compile: + logger.warning_once( + "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + return model_embeds + + +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, +) -> torch.Tensor: + """ + Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + + Returns: + padded_inputs: (batch, seqlen, ...) or (batch, seqlen) + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + return padded_inputs + + +MODERNBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch. Used to pad the output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences. Used to pad the output tensors. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertModel(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.layers = nn.ModuleList( + [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] + ) + self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + self._maybe_set_compile() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if batch_size is None and seq_len is None: + batch_size, seq_len = input_ids.shape[:2] + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + + repad = False + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + repad = True + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask + ) + else: + if position_ids is None: + position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) + + attention_mask, sliding_window_mask = self._update_attention_mask( + attention_mask, output_attentions=output_attentions + ) + + hidden_states = self.embeddings(input_ids) + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + sliding_window_mask, + position_ids, + cu_seqlens, + max_seqlen, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions and len(layer_outputs) > 1: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.final_norm(hidden_states) + + if repad: + hidden_states = _pad_modernbert_output( + inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len + ) + if all_hidden_states is not None: + all_hidden_states = tuple( + _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) + for hs in all_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor: + if output_attentions: + if self.config._attn_implementation == "sdpa": + logger.warning_once( + "Outputting attentions is only supported with the 'eager' attention implementation, " + 'not with "sdpa". Falling back to `attn_implementation="eager"`.' + ) + self.config._attn_implementation = "eager" + elif self.config._attn_implementation != "eager": + logger.warning_once( + "Outputting attentions is only supported with the eager attention implementation, " + f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' + " Setting `output_attentions=False`." + ) + + global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) + + # Create position indices + rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = ( + (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + ) + # Combine with existing mask + sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) + + return global_attention_mask, sliding_window_mask + + +class ModernBertPredictionHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +@add_start_docstrings( + "The ModernBert Model with a decoder head on top that is used for masked language modeling.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForMaskedLM(ModernBertPreTrainedModel): + _tied_weights_keys = ["decoder.weight"] + + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) + + self.sparse_prediction = self.config.sparse_prediction + self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear): + self.decoder = new_embeddings + + @torch.compile(dynamic=True) + def compiled_head(self, output: torch.Tensor) -> torch.Tensor: + return self.decoder(self.head(output)) + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + batch_size, seq_len = input_ids.shape[:2] + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.sparse_prediction and labels is not None: + # flatten labels and output first + labels = labels.view(-1) + last_hidden_state = last_hidden_state.view(labels.shape[0], -1) + + # then filter out the non-masked tokens + mask_tokens = labels != self.sparse_pred_ignore_index + last_hidden_state = last_hidden_state[mask_tokens] + labels = labels[mask_tokens] + + logits = ( + self.compiled_head(last_hidden_state) + if self.config.reference_compile + else self.decoder(self.head(last_hidden_state)) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) + + if self.config._attn_implementation == "flash_attention_2": + with torch.no_grad(): + logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class ModernBertPoolingHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = torch.nn.Dropout(config.classifier_dropout) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + if self.config.classifier_pooling == "cls": + hidden_states = hidden_states[:, 0] + elif self.config.classifier_pooling == "mean": + hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + return self.drop(self.norm(self.act(self.dense(hidden_states)))) + + +@add_start_docstrings( + "The ModernBert Model with a sequence classification head on top that performs pooling.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForSequenceClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPoolingHead(config) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + pooled_output = self.head(last_hidden_state, attention_mask) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForTokenClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.drop = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "ModernBertModel", + "ModernBertPreTrainedModel", + "ModernBertForMaskedLM", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", +] diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py new file mode 100644 index 00000000000000..3c23f9178b1b51 --- /dev/null +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -0,0 +1,1452 @@ +# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, Literal, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...configuration_utils import PretrainedConfig +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, +) +from ...utils.import_utils import is_triton_available +from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary +else: + RotaryEmbedding = object + +_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" +_CONFIG_FOR_DOC = "ModernBertConfig" + +logger = logging.get_logger(__name__) + + +class ModernBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ModernBERT-base. + e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50368): + Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ModernBertModel`] + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 1152): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 22): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu"` + if not specified. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_cutoff_factor (`float`, *optional*, defaults to 2.0): + The cutoff factor for the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + norm_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the normalization layers. + pad_token_id (`int`, *optional*, defaults to 50283): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 50282): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 50281): + Beginning of stream token id. + cls_token_id (`int`, *optional*, defaults to 50281): + Classification token id. + sep_token_id (`int`, *optional*, defaults to 50282): + Separation token id. + global_rope_theta (`float`, *optional*, defaults to 160000.0): + The base period of the global RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + global_attn_every_n_layers (`int`, *optional*, defaults to 3): + The number of layers between global attention layers. + local_attention (`int`, *optional*, defaults to 128): + The window size for local attention. + local_rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the local RoPE embeddings. + embedding_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layers. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + decoder_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the decoder layers. + classifier_pooling (`str`, *optional*, defaults to `"cls"`): + The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the + CLS token doesn't attend to all tokens on long sequences. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the classifier. + classifier_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the classifier. + classifier_activation (`str`, *optional*, defaults to `"gelu"`): + The activation function for the classifier. + deterministic_flash_attn (`bool`, *optional*, defaults to `False`): + Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic. + sparse_prediction (`bool`, *optional*, defaults to `False`): + Whether to use sparse prediction for the masked language model instead of returning the full dense logits. + sparse_pred_ignore_index (`int`, *optional*, defaults to -100): + The index to ignore for the sparse prediction. + reference_compile (`bool`, *optional*): + Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of + the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not + shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may + be faster in some scenarios. + + Examples: + + ```python + >>> from transformers import ModernBertModel, ModernBertConfig + + >>> # Initializing a ModernBert style configuration + >>> configuration = ModernBertConfig() + + >>> # Initializing a model from the modernbert-base style configuration + >>> model = ModernBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "modernbert" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50368, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=22, + num_attention_heads=12, + hidden_activation="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + initializer_cutoff_factor=2.0, + norm_eps=1e-5, + norm_bias=False, + pad_token_id=50283, + eos_token_id=50282, + bos_token_id=50281, + cls_token_id=50281, + sep_token_id=50282, + global_rope_theta=160000.0, + attention_bias=False, + attention_dropout=0.0, + global_attn_every_n_layers=3, + local_attention=128, + local_rope_theta=10000.0, + embedding_dropout=0.0, + mlp_bias=False, + mlp_dropout=0.0, + decoder_bias=True, + classifier_pooling: Literal["cls", "mean"] = "cls", + classifier_dropout=0.0, + classifier_bias=False, + classifier_activation="gelu", + deterministic_flash_attn=False, + sparse_prediction=False, + sparse_pred_ignore_index=-100, + reference_compile=None, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.initializer_cutoff_factor = initializer_cutoff_factor + self.norm_eps = norm_eps + self.norm_bias = norm_bias + self.global_rope_theta = global_rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.global_attn_every_n_layers = global_attn_every_n_layers + self.local_attention = local_attention + self.local_rope_theta = local_rope_theta + self.embedding_dropout = embedding_dropout + self.mlp_bias = mlp_bias + self.mlp_dropout = mlp_dropout + self.decoder_bias = decoder_bias + self.classifier_pooling = classifier_pooling + self.classifier_dropout = classifier_dropout + self.classifier_bias = classifier_bias + self.classifier_activation = classifier_activation + self.deterministic_flash_attn = deterministic_flash_attn + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.reference_compile = reference_compile + + if self.classifier_pooling not in ["cls", "mean"]: + raise ValueError( + f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' + ) + + +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, +) -> torch.Tensor: + """ + Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + + Returns: + padded_inputs: (batch, seqlen, ...) or (batch, seqlen) + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + return padded_inputs + + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, 3, nheads, headdim) + qkv = qkv.contiguous() + total_nnz, _three, _nheads, headdim = qkv.shape + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return qkv + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + do = do.contiguous() + total_nnz, _three, _nheads, headdim = do.shape + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (total_nnz, dim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """ + The rotary position embeddings applied directly to unpadded sequences. + """ + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache wll be recomputed during the forward pass. + """ + super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Apply rotary embedding *inplace* to qkv. + qkv: (total_nnz, 3, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + + qkv = apply_rotary_unpadded( + qkv, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return qkv + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + + +class ModernBertEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = nn.Dropout(config.embedding_dropout) + + @torch.compile(dynamic=True) + def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.drop(self.norm(self.tok_embeddings(input_ids))) + + def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.reference_compile + else self.drop(self.norm(self.tok_embeddings(input_ids))) + ) + return hidden_states + + +class ModernBertMLP(nn.Module): + """Applies the GLU at the end of each ModernBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) + self.act = ACT2FN[config.hidden_activation] + self.drop = nn.Dropout(config.mlp_dropout) + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): + pass + + +def eager_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + output_attentions: Optional[bool] = False, + **_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + scale = module.head_dim**-0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bs, -1, dim) + if output_attentions: + return (attn_output, attn_weights) + return (attn_output,) + + +def flash_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + max_seqlen: int, + local_attention: Tuple[int, int], + bs: int, + dim: int, + target_dtype: torch.dtype = torch.bfloat16, + **_kwargs, +) -> Tuple[torch.Tensor]: + # (total_seqlen, 3, nheads, headdim) + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported) + orig_dtype = qkv.dtype + qkv = qkv.to(target_dtype) + + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + return (attn.view(bs, dim),) + + +def sdpa_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_output = ( + F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=module.attention_dropout if module.training else 0.0, + attn_mask=attention_mask, + ) + .transpose(1, 2) + .contiguous() + ) + attn_output = attn_output.view(bs, -1, dim) + return (attn_output,) + + +MODERNBERT_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class ModernBertAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences. + + If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. + If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, + which requires padding and unpadding inputs, adding some overhead. + + See `forward` method for additional details. + """ + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.deterministic_flash_attn = config.deterministic_flash_attn + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + else: + self.local_attention = (-1, -1) + + rope_theta = config.global_rope_theta + max_position_embeddings = config.max_position_embeddings + if self.local_attention != (-1, -1): + if config.local_rope_theta is not None: + rope_theta = config.local_rope_theta + max_position_embeddings = config.local_attention + + if config._attn_implementation == "flash_attention_2": + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta + ) + else: + self.rotary_emb = ModernBertRotaryEmbedding( + dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta + ) + + self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> torch.Tensor: + qkv = self.Wqkv(hidden_states) + + bs = hidden_states.shape[0] + if self.config._attn_implementation == "flash_attention_2": + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + else: + qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) + + attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + self, + qkv=qkv, + rotary_emb=self.rotary_emb, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = attn_outputs[0] + hidden_states = self.out_drop(self.Wo(hidden_states)) + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +class ModernBertEncoderLayer(nn.Module): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + if layer_id == 0: + self.attn_norm = nn.Identity() + else: + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.mlp = ModernBertMLP(config) + + @torch.compile(dynamic=True) + def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(self.mlp_norm(hidden_states)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs = self.attn( + self.attn_norm(hidden_states), + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + attn_outputs[0] + mlp_output = ( + self.compiled_mlp(hidden_states) + if self.config.reference_compile + else self.mlp(self.mlp_norm(hidden_states)) + ) + hidden_states = hidden_states + mlp_output + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +MODERNBERT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ModernBertConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertPreTrainedModel(PreTrainedModel): + config_class = ModernBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = False + + def _init_weights(self, module: nn.Module): + cutoff_factor = self.config.initializer_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + def init_weight(module: nn.Module, std: float): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + stds = { + "in": self.config.initializer_range, + "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), + "embedding": self.config.initializer_range, + "final_out": self.config.hidden_size**-0.5, + } + + if isinstance(module, ModernBertEmbeddings): + init_weight(module.tok_embeddings, stds["embedding"]) + elif isinstance(module, ModernBertMLP): + init_weight(module.Wi, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertAttention): + init_weight(module.Wqkv, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertPredictionHead): + init_weight(module.dense, stds["in"]) + elif isinstance(module, ModernBertPoolingHead): + init_weight(module.dense, stds["out"]) + elif isinstance(module, ModernBertForMaskedLM): + init_weight(module.decoder, stds["out"]) + elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): + init_weight(module.classifier, stds["final_out"]) + + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): + # If the user didn't specify anything, try to use flash_attention_2 if available. + # Otherwise we fall back to the default SDPA -> Eager from the super() method. + if config._attn_implementation_internal is None: + config._attn_implementation_internal = "flash_attention_2" + try: + return cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + except (ValueError, ImportError): + config._attn_implementation_internal = None + return super()._autoset_attn_implementation( + config, + use_flash_attention_2=use_flash_attention_2, + torch_dtype=torch_dtype, + device_map=device_map, + check_device_map=check_device_map, + ) + + def _maybe_set_compile(self): + if self.config.reference_compile is False: + return + + if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: + if self.config.reference_compile: + logger.warning_once( + "If `accelerate` split the model across devices, `torch.compile` will not work. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.device.type == "mps": + if self.config.reference_compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.config.reference_compile is None: + self.config.reference_compile = is_triton_available() + + def resize_token_embeddings(self, *args, **kwargs): + model_embeds = super().resize_token_embeddings(*args, **kwargs) + + if self.config.reference_compile in {True, None}: + if self.config.reference_compile: + logger.warning_once( + "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + return model_embeds + + +MODERNBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch. Used to pad the output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences. Used to pad the output tensors. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertModel(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.layers = nn.ModuleList( + [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] + ) + self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + self._maybe_set_compile() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if batch_size is None and seq_len is None: + batch_size, seq_len = input_ids.shape[:2] + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + + repad = False + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + repad = True + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask + ) + else: + if position_ids is None: + position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) + + attention_mask, sliding_window_mask = self._update_attention_mask( + attention_mask, output_attentions=output_attentions + ) + + hidden_states = self.embeddings(input_ids) + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + sliding_window_mask, + position_ids, + cu_seqlens, + max_seqlen, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions and len(layer_outputs) > 1: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.final_norm(hidden_states) + + if repad: + hidden_states = _pad_modernbert_output( + inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len + ) + if all_hidden_states is not None: + all_hidden_states = tuple( + _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) + for hs in all_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor: + if output_attentions: + if self.config._attn_implementation == "sdpa": + logger.warning_once( + "Outputting attentions is only supported with the 'eager' attention implementation, " + 'not with "sdpa". Falling back to `attn_implementation="eager"`.' + ) + self.config._attn_implementation = "eager" + elif self.config._attn_implementation != "eager": + logger.warning_once( + "Outputting attentions is only supported with the eager attention implementation, " + f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' + " Setting `output_attentions=False`." + ) + + global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) + + # Create position indices + rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = ( + (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + ) + # Combine with existing mask + sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) + + return global_attention_mask, sliding_window_mask + + +class ModernBertPredictionHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +@add_start_docstrings( + "The ModernBert Model with a decoder head on top that is used for masked language modeling.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForMaskedLM(ModernBertPreTrainedModel): + _tied_weights_keys = ["decoder.weight"] + + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) + + self.sparse_prediction = self.config.sparse_prediction + self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear): + self.decoder = new_embeddings + + @torch.compile(dynamic=True) + def compiled_head(self, output: torch.Tensor) -> torch.Tensor: + return self.decoder(self.head(output)) + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + batch_size, seq_len = input_ids.shape[:2] + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.sparse_prediction and labels is not None: + # flatten labels and output first + labels = labels.view(-1) + last_hidden_state = last_hidden_state.view(labels.shape[0], -1) + + # then filter out the non-masked tokens + mask_tokens = labels != self.sparse_pred_ignore_index + last_hidden_state = last_hidden_state[mask_tokens] + labels = labels[mask_tokens] + + logits = ( + self.compiled_head(last_hidden_state) + if self.config.reference_compile + else self.decoder(self.head(last_hidden_state)) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) + + if self.config._attn_implementation == "flash_attention_2": + with torch.no_grad(): + logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class ModernBertPoolingHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = torch.nn.Dropout(config.classifier_dropout) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + if self.config.classifier_pooling == "cls": + hidden_states = hidden_states[:, 0] + elif self.config.classifier_pooling == "mean": + hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + return self.drop(self.norm(self.act(self.dense(hidden_states)))) + + +@add_start_docstrings( + "The ModernBert Model with a sequence classification head on top that performs pooling.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForSequenceClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPoolingHead(config) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + pooled_output = self.head(last_hidden_state, attention_mask) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForTokenClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.drop = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "ModernBertConfig", + "ModernBertModel", + "ModernBertPreTrainedModel", + "ModernBertForMaskedLM", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", +] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 823c51a290713d..682e89e62fc317 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6397,6 +6397,41 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ModernBertForMaskedLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MoshiForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 32a647594741dd..92823a4ee016c3 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -192,7 +192,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _tiktoken_available = _is_package_available("tiktoken") _blobfile_available = _is_package_available("blobfile") _liger_kernel_available = _is_package_available("liger_kernel") - +_triton_available = _is_package_available("triton") _torch_version = "N/A" _torch_available = False @@ -1243,6 +1243,10 @@ def is_liger_kernel_available(): return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0") +def is_triton_available(): + return _triton_available + + # docstyle-ignore AV_IMPORT_ERROR = """ {0} requires the PyAv library but it was not found in your environment. You can install it with: diff --git a/tests/models/modernbert/__init__.py b/tests/models/modernbert/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py new file mode 100644 index 00000000000000..4fce0cd86352f0 --- /dev/null +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -0,0 +1,367 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest + +import pytest + +from transformers import ModernBertConfig, is_torch_available +from transformers.models.auto import get_values +from transformers.testing_utils import ( + CaptureLogger, + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, + ModernBertForMaskedLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ModernBertModel, + logging, + ) + + +class ModernBertModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + pad_token_id=0, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_activation="gelu", + mlp_dropout=0.0, + attention_dropout=0.0, + embedding_dropout=0.0, + classifier_dropout=0.0, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.pad_token_id = pad_token_id + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.mlp_dropout = mlp_dropout + self.attention_dropout = attention_dropout + self.embedding_dropout = embedding_dropout + self.classifier_dropout = classifier_dropout + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + """ + Returns a tiny configuration by default. + """ + config = ModernBertConfig( + vocab_size=self.vocab_size, + pad_token_id=self.pad_token_id, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_activation=self.hidden_activation, + mlp_dropout=self.mlp_dropout, + attention_dropout=self.attention_dropout, + embedding_dropout=self.embedding_dropout, + classifier_dropout=self.classifier_dropout, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + if test := os.environ.get("PYTEST_CURRENT_TEST", False): + test_name = test.split(":")[-1].split(" ")[0] + + # If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error + # that compilation doesn't work. Users can then set compile=False when loading the model, + # much like here. We're testing whether it works once they've done that. + if test_name == "test_retain_grad_hidden_states_attentions": + config.reference_compile = False + # Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager + # as the others don't support outputted attentions + if test_name in ( + "test_attention_outputs", + "test_hidden_states_output", + "test_retain_grad_hidden_states_attentions", + ): + config._attn_implementation = "eager" + return config + + def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = ModernBertModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_masked_lm( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = ModernBertForMaskedLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_for_sequence_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = ModernBertForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_token_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = ModernBertForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + test_torchscript = False + + all_model_classes = ( + ( + ModernBertModel, + ModernBertForMaskedLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = () + pipeline_model_mapping = ( + { + "feature-extraction": ModernBertModel, + "fill-mask": ModernBertForMaskedLM, + "text-classification": ModernBertForSequenceClassification, + "token-classification": ModernBertForTokenClassification, + "zero-shot": ModernBertForSequenceClassification, + } + if is_torch_available() + else {} + ) + fx_compatible = False + test_head_masking = False + test_pruning = False + model_split_percents = [0.5, 0.8, 0.9] + + # special case for ForPreTraining model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if inputs_dict.get("output_attentions", False): + inputs_dict["output_attentions"] = True + + if return_labels: + if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING): + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + inputs_dict["next_sentence_label"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + + def setUp(self): + self.model_tester = ModernBertModelTester(self) + self.config_tester = ConfigTester(self, config_class=ModernBertConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + # The classifier.weight from ModernBertForSequenceClassification and ModernBertForTokenClassification + # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init + if param.requires_grad and not ( + name == "classifier.weight" + and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification] + ): + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + @unittest.skip("ModernBert doesn't use `inputs_embeds` as input.") + def test_inputs_embeds(self): + pass + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + def test_for_warning_if_padding_and_no_attention_mask(self): + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.model_tester.prepare_config_and_inputs() + + # Set pad tokens in the input_ids + input_ids[0, 0] = config.pad_token_id + + # Check for warnings if the attention_mask is missing. + logger = logging.get_logger("transformers.modeling_utils") + # clear cache so we can test the warning is emitted (from `warning_once`). + logger.warning_once.cache_clear() + + with CaptureLogger(logger) as cl: + model = ModernBertModel(config=config) + model.to(torch_device) + model.eval() + model(input_ids, attention_mask=None) + self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out) + + @unittest.skip("ModernBert doesn't use separate classes for SDPA, but a function instead.") + def test_sdpa_can_dispatch_non_composite_models(self): + pass + + @slow + def test_model_from_pretrained(self): + model_name = "google-bert/bert-base-uncased" + model = ModernBertModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest(reason="ModernBert flash attention does not support right padding") + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_conversion(self): + self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.") + + +@require_torch +class ModernBertModelIntegrationTest(unittest.TestCase): + """ + These still need to be written, once public models are available. + """ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3aaf18c945451f..d08e25546ce10a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3443,6 +3443,8 @@ def test_mismatched_shapes_have_properly_initialized_weights(self): "Data2VecAudioForSequenceClassification", "UniSpeechForSequenceClassification", "PvtForImageClassification", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", "TimmWrapperForImageClassification", ] special_param_names = [ @@ -4046,7 +4048,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + try: + model_sdpa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch_dtype, attn_implementation="sdpa" + ) + except ValueError: + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) model_eager = model_class.from_pretrained(