Add ModernBERT to Transformers #35158

Merged
merged 91 commits into from Dec 19, 2024
Changes from 1 commit
Commits (91)
6b5a823
initial cut of modernbert for transformers
warner-benjamin Dec 9, 2024
dafb203
small bug fixes
warner-benjamin Dec 10, 2024
df13def
fixes
warner-benjamin Dec 11, 2024
d09eabf
Update import
tomaarsen Dec 11, 2024
8c3afea
Use compiled mlp->mlp_norm to match research implementation
tomaarsen Dec 11, 2024
a40aaa9
Propagate changes in modular to modeling
tomaarsen Dec 11, 2024
9f0b8ca
Replace duplicate attn_out_dropout in favor of attention_dropout
tomaarsen Dec 11, 2024
900d8ec
Update BOS to CLS and EOS to SEP
tomaarsen Dec 11, 2024
caf8901
Set default classifier bias to False, matching research repo
tomaarsen Dec 11, 2024
8276602
Update tie_word_embeddings description
tomaarsen Dec 11, 2024
79e4bbb
Fix _init_weights for ForMaskedLM
tomaarsen Dec 11, 2024
b59bad9
Match base_model_prefix
tomaarsen Dec 11, 2024
e7bef53
Add compiled_head to match research repo outputs
tomaarsen Dec 11, 2024
120578b
Fix imports for ModernBertForMaskedLM
tomaarsen Dec 11, 2024
142ff11
Just use "gelu" default outright for classifier
tomaarsen Dec 11, 2024
b44abdc
Fix config name typo: initalizer -> initializer
tomaarsen Dec 11, 2024
3de8ebf
Remove some unused parameters in docstring. Still lots to edit there!
tomaarsen Dec 11, 2024
7a05b3f
Compile the embeddings forward
tomaarsen Dec 12, 2024
88b0ecf
Add drafts for ForSequenceClassification/ForTokenClassification
tomaarsen Dec 12, 2024
5e3d61d
Add initial SDPA support (not exactly equivalent to FA2 yet!)
tomaarsen Dec 12, 2024
2a3d378
Only use attention dropout if training
tomaarsen Dec 12, 2024
a2051d6
Add initial eager attention support (also not equivalent to FA2 yet!)
tomaarsen Dec 12, 2024
124f1fd
Add initial tests, output_attentions, output_hidden_states, prune_heads
tomaarsen Dec 13, 2024
38f959b
Remove kwargs from ModernBertForMaskedLM
tomaarsen Dec 13, 2024
f716943
Remove/adjust/skip improper tests; warn if padding but no attn mask
tomaarsen Dec 13, 2024
f41adaa
Run formatting etc.
tomaarsen Dec 13, 2024
d06654a
Run python utils/custom_init_isort.py
tomaarsen Dec 14, 2024
f9301f4
FlexAttention with unpadded sequences (matches FA2 within bf16 numerics)
staghado Dec 15, 2024
a356708
Reformat init_weights based on review
tomaarsen Dec 16, 2024
f83fdc0
self -> module in attention forwards
tomaarsen Dec 16, 2024
b444c15
Remove if config.tie_word_embeddings
tomaarsen Dec 16, 2024
5aaf273
Reformat output projection on a different line
tomaarsen Dec 16, 2024
0a8d044
Remove pruning
tomaarsen Dec 16, 2024
382e481
Remove assert
tomaarsen Dec 16, 2024
5d05e8e
Call contiguous() to simplify paths
tomaarsen Dec 16, 2024
98508c7
Remove prune_qkv_linear_layer
tomaarsen Dec 16, 2024
2c076c8
Format code
tomaarsen Dec 16, 2024
986c6fe
Keep as kwargs, only use if needed
tomaarsen Dec 16, 2024
5cd39ad
Remove unused codepaths & related config options
tomaarsen Dec 16, 2024
2d606b9
Remove 3d attn_mask test; fix token classification tuple output
tomaarsen Dec 16, 2024
8eb87e8
Reorder: attention_mask above position_ids, fixes gradient checkpointing
tomaarsen Dec 16, 2024
5d83c56
Merge branch 'main' into pr-35158
tomaarsen Dec 16, 2024
3a24af4
Fix usage if no FA2 or torch v2.5+
tomaarsen Dec 16, 2024
37a6030
Make torch.compile/triton optional
tomaarsen Dec 17, 2024
b3b4028
Separate pooling options into separate functions (cls, mean) - cls as…
tomaarsen Dec 17, 2024
b241a7e
Simplify _pad_modernbert_output, remove unused labels path
tomaarsen Dec 17, 2024
66f4603
Update tied weights to remove decoder.weight, simplify decoder loading
tomaarsen Dec 17, 2024
3eb786b
Adaptively set config.compile based on hf_device_map/device/resize, etc.
tomaarsen Dec 17, 2024
093b601
Merge branch 'main' of https://github.com/huggingface/transformers in…
tomaarsen Dec 17, 2024
28fc79e
Update ModernBertConfig docstring
tomaarsen Dec 17, 2024
612befa
Satisfy some consistency checks, add unfinished docs
tomaarsen Dec 17, 2024
ae32e8b
Merge branch 'main' of https://github.com/huggingface/transformers in…
tomaarsen Dec 17, 2024
f4e280a
Only set compile to False if there's more than 1 device
tomaarsen Dec 17, 2024
bc14967
Add docstrings for public ModernBert classes
tomaarsen Dec 17, 2024
0f17fb9
Dont replace docstring returns - ends up being duplicate
tomaarsen Dec 17, 2024
25b12b4
Fix mistake in toctree
tomaarsen Dec 17, 2024
f312eef
Reformat toctree
tomaarsen Dec 17, 2024
1e367df
Patched FlexAttention, SDPA, Eager with Local Attention
tomaarsen Dec 17, 2024
fb748ce
Implement FA2 -> SDPA -> Eager attn_impl defaulting, crucial
tomaarsen Dec 17, 2024
051233f
Patch test edge case with Idefics3 not working with 'attn_implementat…
tomaarsen Dec 17, 2024
6c01711
Repad all_hidden_states as well
tomaarsen Dec 17, 2024
5f7c566
rename config.compile to reference_compile
warner-benjamin Dec 18, 2024
c8a80e7
disable flex_attention since it crashes
warner-benjamin Dec 18, 2024
8962f05
Update modernbert.md
bclavie Dec 18, 2024
7e89f4d
Using dtype min to mask in eager
NohTow Dec 18, 2024
0742a1d
Fully remove flex attention for now
tomaarsen Dec 18, 2024
6c6cddb
Call contiguous to allow for .view()
tomaarsen Dec 18, 2024
e37e4ec
Copyright 2020 -> 2024
tomaarsen Dec 18, 2024
9afc480
Update/simplify __init__ structure
tomaarsen Dec 18, 2024
aa1bdb4
Remove "... if dropout_prob > 0 else identity"
tomaarsen Dec 18, 2024
659807f
re-use existing pad/unpad functions instead of creating new ones
staghado Dec 18, 2024
7955e39
remove flexattention method
staghado Dec 18, 2024
4145119
Compute attention_mask and local_attention_mask once in modeling
tomaarsen Dec 18, 2024
0e572d5
Simplify sequence classification prediction heads, only CLS now
tomaarsen Dec 18, 2024
e5dca63
Simplify module.training in eager attn
tomaarsen Dec 18, 2024
bf11173
Also export ModernBertPreTrainedModel
tomaarsen Dec 18, 2024
54ed5db
Update the documentation with links to finetuning scripts
tomaarsen Dec 18, 2024
a1bfae8
Explain local_attention_mask parameter in docstring
tomaarsen Dec 18, 2024
df7658a
Simplify _autoset_attn_implementation, rely on super()
tomaarsen Dec 18, 2024
b3404ed
Keep "in" to initialize Prediction head
tomaarsen Dec 18, 2024
e057bc2
add back mean pooling
warner-benjamin Dec 18, 2024
99c38ba
Use the pooling head in TokenClassification
warner-benjamin Dec 18, 2024
5114ed7
update copyright
warner-benjamin Dec 18, 2024
175fb95
Reset config._attn_implementation_internal on failure
tomaarsen Dec 18, 2024
8cedfc5
Allow optional attention_mask in ForMaskedLM head
warner-benjamin Dec 18, 2024
2380729
fix failing run_slow tests
warner-benjamin Dec 18, 2024
7686134
Add links to the paper
tomaarsen Dec 19, 2024
44275fd
Remove unpad_no_grad, always pad/unpad without gradients
tomaarsen Dec 19, 2024
d799d65
local_attention_mask -> sliding_window_mask
tomaarsen Dec 19, 2024
ed77867
Revert "Use the pooling head in TokenClassification"
tomaarsen Dec 19, 2024
92e17c6
Simplify pooling, 2 options via if-else
tomaarsen Dec 19, 2024
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
@@ -450,6 +450,7 @@
"models.fuyu": ["FuyuConfig"],
"models.gemma": ["GemmaConfig"],
"models.gemma2": ["Gemma2Config"],
"models.modernbert": ["ModernBertConfig"],
"models.git": [
"GitConfig",
"GitProcessor",
@@ -2299,6 +2300,15 @@
"Gemma2PreTrainedModel",
]
)
_import_structure["models.modernbert"].extend(
[
"ModernBertForCausalLM",
"ModernBertForSequenceClassification",
"ModernBertForTokenClassification",
"ModernBertModel",
"ModernBertPreTrainedModel",
]
)
_import_structure["models.git"].extend(
[
"GitForCausalLM",
@@ -5328,6 +5338,7 @@
from .models.fuyu import FuyuConfig
from .models.gemma import GemmaConfig
from .models.gemma2 import Gemma2Config
from .models.modernbert import ModernBertConfig
from .models.git import (
GitConfig,
GitProcessor,
@@ -7056,6 +7067,13 @@
Gemma2Model,
Gemma2PreTrainedModel,
)
from .models.modernbert import (
ModernBertForCausalLM,
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
ModernBertModel,
ModernBertPreTrainedModel,
)
from .models.git import (
GitForCausalLM,
GitModel,
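With the lazy-import wiring above in place, the new classes resolve directly from the top-level `transformers` package. A minimal sketch of the resulting usage (class names as registered in this commit; later commits in this PR introduce `ModernBertForMaskedLM`, see the commit list above):

```python
from transformers import ModernBertConfig, ModernBertModel

config = ModernBertConfig()      # defaults from configuration_modernbert.py further down in this diff
model = ModernBertModel(config)  # randomly initialized encoder, no pretrained weights
print(model.config.model_type)   # "modernbert"
```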
17 changes: 17 additions & 0 deletions src/transformers/loss/loss_utils.py
@@ -47,6 +47,22 @@ def ForCausalLMLoss(
    return loss


def ForMaskedLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    labels = labels.view(-1)

    # Enable model parallelism
    labels = labels.to(logits.device)
    loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
    return loss


def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
    num_labels = config.num_labels
    if config.problem_type is None:
@@ -101,6 +117,7 @@ def ForTokenClassification(logits, labels, config, **kwargs):

LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForMaskedLM": ForMaskedLMLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
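The new `ForMaskedLMLoss` mirrors `ForCausalLMLoss` without the label shifting: upcast the logits, flatten, move the labels onto the logits device, and let cross-entropy skip the `-100` positions. A standalone sketch of the same computation in plain PyTorch (a hypothetical helper, not the `fixed_cross_entropy` internals):

```python
import torch
import torch.nn.functional as F


def masked_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    logits = logits.float()                     # upcast to avoid precision issues
    logits = logits.view(-1, vocab_size)        # (batch * seq_len, vocab)
    labels = labels.view(-1).to(logits.device)  # (batch * seq_len,)
    return F.cross_entropy(logits, labels, ignore_index=-100)


batch, seq_len, vocab = 2, 8, 50368
logits = torch.randn(batch, seq_len, vocab)
labels = torch.full((batch, seq_len), -100)  # -100 everywhere except the masked positions
labels[:, 3] = 42
print(masked_lm_loss(logits, labels, vocab))
```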
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -96,6 +96,7 @@
fuyu,
gemma,
gemma2,
modernbert,
git,
glm,
glpn,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -113,6 +113,7 @@
("fuyu", "FuyuConfig"),
("gemma", "GemmaConfig"),
("gemma2", "Gemma2Config"),
("modernbert", "ModernBertConfig"),
("git", "GitConfig"),
("glm", "GlmConfig"),
("glpn", "GLPNConfig"),
@@ -417,6 +418,7 @@
("fuyu", "Fuyu"),
("gemma", "Gemma"),
("gemma2", "Gemma2"),
("modernbert", "ModernBERT"),
("git", "GIT"),
("glm", "GLM"),
("glpn", "GLPN"),
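With the `model_type` and display-name entries above, the Auto classes can resolve ModernBERT checkpoints. A hedged sketch (the checkpoint id is taken from the docstring later in this diff and may not match the final released name):

```python
from transformers import AutoConfig

# Resolves via the ("modernbert", "ModernBertConfig") mapping added above.
config = AutoConfig.from_pretrained("answerdotai/modernbert-base")
print(type(config).__name__)  # ModernBertConfig
```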
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -110,6 +110,7 @@
("funnel", ("FunnelModel", "FunnelBaseModel")),
("gemma", "GemmaModel"),
("gemma2", "Gemma2Model"),
("modernbert", "ModernBertModel"),
("git", "GitModel"),
("glm", "GlmModel"),
("glpn", "GLPNModel"),
@@ -487,6 +488,7 @@
("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"),
("gemma2", "Gemma2ForCausalLM"),
("modernbert", "ModernBertForCausalLM"),
("git", "GitForCausalLM"),
("glm", "GlmForCausalLM"),
("gpt-sw3", "GPT2LMHeadModel"),
@@ -945,6 +947,7 @@
("funnel", "FunnelForSequenceClassification"),
("gemma", "GemmaForSequenceClassification"),
("gemma2", "Gemma2ForSequenceClassification"),
("modernbert", "ModernBertForSequenceClassification"),
("glm", "GlmForSequenceClassification"),
("gpt-sw3", "GPT2ForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
@@ -1136,6 +1139,7 @@
("funnel", "FunnelForTokenClassification"),
("gemma", "GemmaForTokenClassification"),
("gemma2", "Gemma2ForTokenClassification"),
("modernbert", "ModernBertForTokenClassification"),
("glm", "GlmForTokenClassification"),
("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
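These entries let the task-specific Auto classes dispatch to the new ModernBERT heads. For example, for sequence classification (a sketch only; the checkpoint id is an assumption and the classification head is freshly initialized):

```python
from transformers import AutoModelForSequenceClassification

# Dispatches to ModernBertForSequenceClassification via the mapping added above.
model = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/modernbert-base", num_labels=2
)
```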
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -203,6 +203,7 @@
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
("modernbert", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
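ModernBERT ships no slow tokenizer, so the mapping points `AutoTokenizer` at the fast `GPTNeoXTokenizerFast` when the `tokenizers` backend is installed. A sketch (checkpoint id assumed, as above):

```python
from transformers import AutoTokenizer

# Falls back to GPTNeoXTokenizerFast via the mapping above when the checkpoint's
# tokenizer_config does not name a tokenizer class itself.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/modernbert-base")
print(tokenizer("Paris is the capital of France.")["input_ids"])
```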
61 changes: 61 additions & 0 deletions src/transformers/models/modernbert/__init__.py
@@ -0,0 +1,61 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_modernbert": ["ModernBertConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_modernbert"] = [
        "ModernBertForCausalLM",
        "ModernBertModel",
        "ModernBertPreTrainedModel",
        "ModernBertForSequenceClassification",
        "ModernBertForTokenClassification",
    ]

if TYPE_CHECKING:
    from .configuration_modernbert import ModernBertConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_modernbert import (
            ModernBertForCausalLM,
            ModernBertForSequenceClassification,
            ModernBertForTokenClassification,
            ModernBertModel,
            ModernBertPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
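The `_LazyModule` indirection means `modeling_modernbert` is only imported when one of its symbols is first accessed, and the torch guard keeps configuration-only installs working. A rough illustration of that behaviour (an assumption about the lazy-loading pattern, not a test from this PR):

```python
import transformers.models.modernbert as modernbert

# Only the lazy shim is loaded at this point; touching an attribute triggers the
# real import of configuration_modernbert (and modeling_modernbert if torch is available).
print(type(modernbert).__name__)  # _LazyModule
config_cls = modernbert.ModernBertConfig
print(config_cls.model_type)      # "modernbert"
```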
184 changes: 184 additions & 0 deletions src/transformers/models/modernbert/configuration_modernbert.py
@@ -0,0 +1,184 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_modernbert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig


class ModernBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ModernBert-7B.
    e.g. [answerdotai/modernbert-base](https://huggingface.co/answerdotai/modernbert-base)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ModernBertModel`]
        hidden_size (`int`, *optional*, defaults to 2304):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 9216):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 26):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores
        sliding_window (`int`, *optional*, defaults to 4096): in ModernBert, every other layer uses sliding window attention. This is the
            size of the sliding window.
        final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.

    ```python
    >>> from transformers import ModernBertModel, ModernBertConfig
    >>> # Initializing a ModernBert modernbert-7b style configuration
    >>> configuration = ModernBertConfig()
    >>> # Initializing a model from the modernbert-7b style configuration
    >>> model = ModernBertModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "modernbert"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50368,
        hidden_size=768,
        intermediate_size=1152,
        num_hidden_layers=22,
        num_attention_heads=12,
        hidden_activation="gelu_python",
        max_position_embeddings=8192,
        initializer_range=0.02,
        initalizer_cutoff_factor=2.0,
        norm_eps=1e-5,
        norm_bias=False,
        pad_token_id=50283,
        eos_token_id=50281,
        bos_token_id=50282,
        cls_token_id=50281,
        sep_token_id=50282,
        tie_word_embeddings=True,
        global_rope_theta=160000.0,
        attention_bias=False,
        attention_dropout=0.0,
        attn_out_dropout=0.1,
        global_attn_every_n_layers=3,
        local_attention=128,
        local_rope_theta=10000.0,
        skip_first_prenorm=True,
        embedding_norm=True,
        embedding_dropout=0.0,
        mlp_bias=False,
        mlp_dropout=0.0,
        unpad_inputs=True,
        unpad_no_grad=True,
        decoder_bias=True,
        classifier_dropout=0.0,
        classifier_pooling="mean",
        classifier_norm=True,
        classifier_bias=True,
        classifier_activation=None,
        deterministic_flash_attn=False,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.initializer_range = initializer_range
        self.initalizer_cutoff_factor = initalizer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias
        self.global_rope_theta = global_rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.attn_out_dropout = attn_out_dropout
        self.hidden_activation = hidden_activation
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.local_attention = local_attention
        self.local_rope_theta = local_rope_theta
        self.skip_first_prenorm = skip_first_prenorm
        self.embedding_norm = embedding_norm
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout
        self.unpad_inputs = unpad_inputs
        self.unpad_no_grad = unpad_no_grad
        self.decoder_bias = decoder_bias
        self.classifier_dropout = classifier_dropout
        self.classifier_pooling = classifier_pooling
        self.classifier_bias = classifier_bias
        self.classifier_norm = classifier_norm
        self.classifier_activation = classifier_activation if classifier_activation is not None else hidden_activation
        self.deterministic_flash_attn = deterministic_flash_attn
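As defined in this commit, the constructor defaults follow the base-sized research configuration (768 hidden size, 22 layers, global attention every third layer) rather than the values in the docstring above, which the commit notes acknowledge still needs editing. A quick check, assuming the import wiring from earlier in this diff:

```python
from transformers import ModernBertConfig

cfg = ModernBertConfig()
print(cfg.hidden_size, cfg.num_hidden_layers)               # 768 22
print(cfg.global_attn_every_n_layers, cfg.local_attention)  # 3 128
print(cfg.classifier_activation)  # None falls back to hidden_activation: "gelu_python"
```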