Add ModernBERT to Transformers #35158

Merged
merged 91 commits into from
Dec 19, 2024
Commits
6b5a823
initial cut of modernbert for transformers
warner-benjamin Dec 9, 2024
dafb203
small bug fixes
warner-benjamin Dec 10, 2024
df13def
fixes
warner-benjamin Dec 11, 2024
d09eabf
Update import
tomaarsen Dec 11, 2024
8c3afea
Use compiled mlp->mlp_norm to match research implementation
tomaarsen Dec 11, 2024
a40aaa9
Propagate changes in modular to modeling
tomaarsen Dec 11, 2024
9f0b8ca
Replace duplicate attn_out_dropout in favor of attention_dropout
tomaarsen Dec 11, 2024
900d8ec
Update BOS to CLS and EOS to SEP
tomaarsen Dec 11, 2024
caf8901
Set default classifier bias to False, matching research repo
tomaarsen Dec 11, 2024
8276602
Update tie_word_embeddings description
tomaarsen Dec 11, 2024
79e4bbb
Fix _init_weights for ForMaskedLM
tomaarsen Dec 11, 2024
b59bad9
Match base_model_prefix
tomaarsen Dec 11, 2024
e7bef53
Add compiled_head to match research repo outputs
tomaarsen Dec 11, 2024
120578b
Fix imports for ModernBertForMaskedLM
tomaarsen Dec 11, 2024
142ff11
Just use "gelu" default outright for classifier
tomaarsen Dec 11, 2024
b44abdc
Fix config name typo: initalizer -> initializer
tomaarsen Dec 11, 2024
3de8ebf
Remove some unused parameters in docstring. Still lots to edit there!
tomaarsen Dec 11, 2024
7a05b3f
Compile the embeddings forward
tomaarsen Dec 12, 2024
88b0ecf
Add drafts for ForSequenceClassification/ForTokenClassification
tomaarsen Dec 12, 2024
5e3d61d
Add initial SDPA support (not exactly equivalent to FA2 yet!)
tomaarsen Dec 12, 2024
2a3d378
Only use attention dropout if training
tomaarsen Dec 12, 2024
a2051d6
Add initial eager attention support (also not equivalent to FA2 yet!)
tomaarsen Dec 12, 2024
124f1fd
Add initial tests, output_attentions, output_hidden_states, prune_heads
tomaarsen Dec 13, 2024
38f959b
Remove kwargs from ModernBertForMaskedLM
tomaarsen Dec 13, 2024
f716943
Remove/adjust/skip improper tests; warn if padding but no attn mask
tomaarsen Dec 13, 2024
f41adaa
Run formatting etc.
tomaarsen Dec 13, 2024
d06654a
Run python utils/custom_init_isort.py
tomaarsen Dec 14, 2024
f9301f4
FlexAttention with unpadded sequences(matches FA2 within bf16 numerics)
staghado Dec 15, 2024
a356708
Reformat init_weights based on review
tomaarsen Dec 16, 2024
f83fdc0
self -> module in attention forwards
tomaarsen Dec 16, 2024
b444c15
Remove if config.tie_word_embeddings
tomaarsen Dec 16, 2024
5aaf273
Reformat output projection on a different line
tomaarsen Dec 16, 2024
0a8d044
Remove pruning
tomaarsen Dec 16, 2024
382e481
Remove assert
tomaarsen Dec 16, 2024
5d05e8e
Call contiguous() to simplify paths
tomaarsen Dec 16, 2024
98508c7
Remove prune_qkv_linear_layer
tomaarsen Dec 16, 2024
2c076c8
Format code
tomaarsen Dec 16, 2024
986c6fe
Keep as kwargs, only use if needed
tomaarsen Dec 16, 2024
5cd39ad
Remove unused codepaths & related config options
tomaarsen Dec 16, 2024
2d606b9
Remove 3d attn_mask test; fix token classification tuple output
tomaarsen Dec 16, 2024
8eb87e8
Reorder: attention_mask above position_ids, fixes gradient checkpointing
tomaarsen Dec 16, 2024
5d83c56
Merge branch 'main' into pr-35158
tomaarsen Dec 16, 2024
3a24af4
Fix usage if no FA2 or torch v2.5+
tomaarsen Dec 16, 2024
37a6030
Make torch.compile/triton optional
tomaarsen Dec 17, 2024
b3b4028
Separate pooling options into separate functions (cls, mean) - cls as…
tomaarsen Dec 17, 2024
b241a7e
Simplify _pad_modernbert_output, remove unused labels path
tomaarsen Dec 17, 2024
66f4603
Update tied weights to remove decoder.weight, simplify decoder loading
tomaarsen Dec 17, 2024
3eb786b
Adaptively set config.compile based on hf_device_map/device/resize, etc.
tomaarsen Dec 17, 2024
093b601
Merge branch 'main' of https://github.com/huggingface/transformers in…
tomaarsen Dec 17, 2024
28fc79e
Update ModernBertConfig docstring
tomaarsen Dec 17, 2024
612befa
Satisfy some consistency checks, add unfinished docs
tomaarsen Dec 17, 2024
ae32e8b
Merge branch 'main' of https://github.com/huggingface/transformers in…
tomaarsen Dec 17, 2024
f4e280a
Only set compile to False if there's more than 1 device
tomaarsen Dec 17, 2024
bc14967
Add docstrings for public ModernBert classes
tomaarsen Dec 17, 2024
0f17fb9
Dont replace docstring returns - ends up being duplicate
tomaarsen Dec 17, 2024
25b12b4
Fix mistake in toctree
tomaarsen Dec 17, 2024
f312eef
Reformat toctree
tomaarsen Dec 17, 2024
1e367df
Patched FlexAttention, SDPA, Eager with Local Attention
tomaarsen Dec 17, 2024
fb748ce
Implement FA2 -> SDPA -> Eager attn_impl defaulting, crucial
tomaarsen Dec 17, 2024
051233f
Patch test edge case with Idefics3 not working with 'attn_implementat…
tomaarsen Dec 17, 2024
6c01711
Repad all_hidden_states as well
tomaarsen Dec 17, 2024
5f7c566
rename config.compile to reference_compile
warner-benjamin Dec 18, 2024
c8a80e7
disable flex_attention since it crashes
warner-benjamin Dec 18, 2024
8962f05
Update modernbert.md
bclavie Dec 18, 2024
7e89f4d
Using dtype min to mask in eager
NohTow Dec 18, 2024
0742a1d
Fully remove flex attention for now
tomaarsen Dec 18, 2024
6c6cddb
Call contiguous to allow for .view()
tomaarsen Dec 18, 2024
e37e4ec
Copyright 2020 -> 2024
tomaarsen Dec 18, 2024
9afc480
Update/simplify __init__ structure
tomaarsen Dec 18, 2024
aa1bdb4
Remove "... if dropout_prob > 0 else identity"
tomaarsen Dec 18, 2024
659807f
re-use existing pad/unpad functions instead of creating new ones
staghado Dec 18, 2024
7955e39
remove flexattention method
staghado Dec 18, 2024
4145119
Compute attention_mask and local_attention_mask once in modeling
tomaarsen Dec 18, 2024
0e572d5
Simplify sequence classification prediction heads, only CLS now
tomaarsen Dec 18, 2024
e5dca63
Simplify module.training in eager attn
tomaarsen Dec 18, 2024
bf11173
Also export ModernBertPreTrainedModel
tomaarsen Dec 18, 2024
54ed5db
Update the documentation with links to finetuning scripts
tomaarsen Dec 18, 2024
a1bfae8
Explain local_attention_mask parameter in docstring
tomaarsen Dec 18, 2024
df7658a
Simplify _autoset_attn_implementation, rely on super()
tomaarsen Dec 18, 2024
b3404ed
Keep "in" to initialize Prediction head
tomaarsen Dec 18, 2024
e057bc2
add back mean pooling
warner-benjamin Dec 18, 2024
99c38ba
Use the pooling head in TokenClassification
warner-benjamin Dec 18, 2024
5114ed7
update copyright
warner-benjamin Dec 18, 2024
175fb95
Reset config._attn_implementation_internal on failure
tomaarsen Dec 18, 2024
8cedfc5
Allow optional attention_mask in ForMaskedLM head
warner-benjamin Dec 18, 2024
2380729
fix failing run_slow tests
warner-benjamin Dec 18, 2024
7686134
Add links to the paper
tomaarsen Dec 19, 2024
44275fd
Remove unpad_no_grad, always pad/unpad without gradients
tomaarsen Dec 19, 2024
d799d65
local_attention_mask -> sliding_window_mask
tomaarsen Dec 19, 2024
ed77867
Revert "Use the pooling head in TokenClassification"
tomaarsen Dec 19, 2024
92e17c6
Simplify pooling, 2 options via if-else
tomaarsen Dec 19, 2024
Reset config._attn_implementation_internal on failure
tomaarsen committed Dec 18, 2024

commit 175fb95b8926fae4227364d7eb3c8d093980ef10
@@ -668,7 +668,7 @@ def _autoset_attn_implementation(
                     check_device_map=check_device_map,
                 )
             except (ValueError, ImportError):
-                pass
+                config._attn_implementation_internal = None
         return super()._autoset_attn_implementation(
             config,
             use_flash_attention_2=use_flash_attention_2,
2 changes: 1 addition & 1 deletion src/transformers/models/modernbert/modular_modernbert.py
@@ -868,7 +868,7 @@ def _autoset_attn_implementation(
                     check_device_map=check_device_map,
                 )
             except (ValueError, ImportError):
-                pass
+                config._attn_implementation_internal = None
         return super()._autoset_attn_implementation(
             config,
             use_flash_attention_2=use_flash_attention_2,
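
Why the reset matters can be illustrated with a small standalone sketch. This is not the Transformers implementation: `Config`, `check_flash_attention`, and `default_autoset` are hypothetical stand-ins, and the sketch assumes that the fallback path only picks a backend when `_attn_implementation_internal` is still `None`. Under that assumption, leaving the field set after a failed check (the old `pass`) would keep the unavailable backend selected, while resetting it lets the fallback choose.

```python
from typing import Optional


class Config:
    """Hypothetical stand-in for a model config; only one field matters here."""

    def __init__(self, attn_implementation: Optional[str] = None) -> None:
        self._attn_implementation_internal = attn_implementation


def check_flash_attention(config: Config, available: bool) -> Config:
    """Hypothetical availability check: raises when the preferred backend is missing."""
    if not available:
        raise ImportError("flash_attention_2 is not available")
    return config


def default_autoset(config: Config) -> Config:
    """Hypothetical fallback (assumption): picks a backend only if none is set."""
    if config._attn_implementation_internal is None:
        config._attn_implementation_internal = "sdpa"
    return config


def autoset_attn_implementation(
    config: Config, flash_available: bool, reset_on_failure: bool = True
) -> Config:
    # Only try the preferred backend when the user did not request one.
    if config._attn_implementation_internal is None:
        config._attn_implementation_internal = "flash_attention_2"
        try:
            return check_flash_attention(config, flash_available)
        except (ValueError, ImportError):
            # reset_on_failure=True mirrors the new behaviour in the diff;
            # reset_on_failure=False mirrors the old `pass`.
            if reset_on_failure:
                config._attn_implementation_internal = None
    return default_autoset(config)


if __name__ == "__main__":
    fixed = autoset_attn_implementation(Config(), flash_available=False)
    stale = autoset_attn_implementation(
        Config(), flash_available=False, reset_on_failure=False
    )
    print(fixed._attn_implementation_internal)
    print(stale._attn_implementation_internal)
```

With the reset in place, a failed check leaves the config in the same state as if no preference had been tried, so the fallback logic does not need to know about the attempt.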