From 78d4ea79a939ae73e7e2c993c13d25a56008bd76 Mon Sep 17 00:00:00 2001 From: Jack Morris Date: Fri, 20 Dec 2024 10:54:05 -0800 Subject: [PATCH 1/3] update modular_modernbert -- add inputs_embeds param to ModernBertModel --- .../models/modernbert/modular_modernbert.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index dac356146f3015..399f5569ac7b26 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1010,6 +1010,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + inputs_embeds: torch.Tensor = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1024,7 +1025,12 @@ def forward( self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) if batch_size is None and seq_len is None: - batch_size, seq_len = input_ids.shape[:2] + if inputs_embeds is not None: + device = inputs_embeds.device + batch_size, seq_len = inputs_embeds.shape[:2] + else: + device = input_ids.device + batch_size, seq_len = input_ids.shape[:2] if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) @@ -1033,19 +1039,27 @@ def forward( if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: repad = True - with torch.no_grad(): - input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( - inputs=input_ids, attention_mask=attention_mask + if inputs_embeds is None: + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask + ) + else: + inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=inputs_embeds, attention_mask=attention_mask ) else: if position_ids is None: - position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) + position_ids = torch.arange(seq_len, device=device).unsqueeze(0) attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions ) - - hidden_states = self.embeddings(input_ids) + + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embeddings(input_ids) for encoder_layer in self.layers: if output_hidden_states: From 10c4995c87778ba849a8523fd566eddb72d82b63 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 30 Dec 2024 13:51:08 +0100 Subject: [PATCH 2/3] Fix implementation issues; extend to other classes; docstring First of all, inputs_embeds shouldn't fully replace `self.embeddings(input_ids)`, because that call also applies layer normalization and dropout. So now both input_ids and inputs_embeds are passed to ModernBertEmbeddings, much like how BertEmbeddings is implemented. I also added `inputs_embeds` to the docstring and propagated the changes to the other model classes. I also introduced an error that is raised if input_ids and inputs_embeds are both provided, or if neither is. Lastly, I fixed an issue where the device of the default attention_mask was derived solely from input_ids.
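For illustration, a minimal usage sketch of the intended behavior (the checkpoint name and the `reference_compile=False` override are assumptions for the example, not part of this patch): with raw token embeddings taken from `get_input_embeddings()`, the `inputs_embeds` path should produce the same hidden states as the `input_ids` path, since normalization and dropout are still applied inside `ModernBertEmbeddings`.

```python
# Illustrative sketch only; assumes the "answerdotai/ModernBERT-base" checkpoint and that
# `reference_compile=False` can be passed as a config override to `from_pretrained`.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = AutoModel.from_pretrained("answerdotai/ModernBERT-base", reference_compile=False).eval()

encoded = tokenizer("ModernBERT now accepts inputs_embeds.", return_tensors="pt")
# Raw token embeddings; ModernBertEmbeddings still applies LayerNorm + dropout internally.
token_embeds = model.get_input_embeddings()(encoded["input_ids"])

with torch.no_grad():
    out_ids = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
    out_embeds = model(inputs_embeds=token_embeds, attention_mask=encoded["attention_mask"])

# Both paths should match closely in eval mode.
print(torch.allclose(out_ids.last_hidden_state, out_embeds.last_hidden_state, atol=1e-6))
```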
--- .../models/modernbert/modeling_modernbert.py | 83 +++++++++++++------ .../models/modernbert/modular_modernbert.py | 69 +++++++++------ .../modernbert/test_modeling_modernbert.py | 6 +- 3 files changed, 108 insertions(+), 50 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 237fba6f645fa5..1c6542d7b7a73a 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -205,12 +205,17 @@ def __init__(self, config: ModernBertConfig): def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) - def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: - hidden_states = ( - self.compiled_embeddings(input_ids) - if self.config.reference_compile - else self.drop(self.norm(self.tok_embeddings(input_ids))) - ) + def forward( + self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = self.drop(self.norm(inputs_embeds)) + else: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.reference_compile + else self.drop(self.norm(self.tok_embeddings(input_ids))) + ) return hidden_states @@ -794,6 +799,10 @@ def _pad_modernbert_output( config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): Indices of the non-padding tokens in the input sequence. Used for unpadding the output. 
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): @@ -845,10 +854,11 @@ def set_input_embeddings(self, value): ) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, @@ -864,35 +874,49 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None self._maybe_set_compile() - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) if batch_size is None and seq_len is None: - batch_size, seq_len = input_ids.shape[:2] + if inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) repad = False if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: repad = True - with torch.no_grad(): - input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( - inputs=input_ids, attention_mask=attention_mask + if inputs_embeds is None: + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask + ) + else: + inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=inputs_embeds, attention_mask=attention_mask ) else: if position_ids is None: - position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) + position_ids = torch.arange(seq_len, device=device).unsqueeze(0) attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions ) - hidden_states = self.embeddings(input_ids) + hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) for encoder_layer in self.layers: if output_hidden_states: @@ -1028,10 +1052,11 @@ def compiled_head(self, output: torch.Tensor) -> torch.Tensor: ) def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, @@ -1051,16 +1076,22 @@ def forward( batch_size, seq_len = input_ids.shape[:2] if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) - with torch.no_grad(): - input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = 
_unpad_modernbert_input( - inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + if inputs_embeds is None: + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + else: + inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) outputs = self.model( - input_ids, + input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, + inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -1133,10 +1164,11 @@ def __init__(self, config: ModernBertConfig): ) def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, @@ -1158,10 +1190,11 @@ def forward( self._maybe_set_compile() outputs = self.model( - input_ids, + input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, + inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -1244,10 +1277,11 @@ def __init__(self, config: ModernBertConfig): ) def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, @@ -1266,10 +1300,11 @@ def forward( self._maybe_set_compile() outputs = self.model( - input_ids, + input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, + inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 399f5569ac7b26..5e3c44386fb759 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -464,12 +464,15 @@ def __init__(self, config: ModernBertConfig): def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) - def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: - hidden_states = ( - self.compiled_embeddings(input_ids) - if self.config.reference_compile - else self.drop(self.norm(self.tok_embeddings(input_ids))) - ) + def forward(self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = self.drop(self.norm(inputs_embeds)) + else: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.reference_compile + else 
self.drop(self.norm(self.tok_embeddings(input_ids))) + ) return hidden_states @@ -947,6 +950,10 @@ def resize_token_embeddings(self, *args, **kwargs): config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): Indices of the non-padding tokens in the input sequence. Used for unpadding the output. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): @@ -998,10 +1005,11 @@ def set_input_embeddings(self, value): ) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, @@ -1010,7 +1018,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - inputs_embeds: torch.Tensor = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1018,22 +1025,26 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None self._maybe_set_compile() - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) if batch_size is None and seq_len is None: if inputs_embeds is not None: - device = inputs_embeds.device batch_size, seq_len = inputs_embeds.shape[:2] else: - device = input_ids.device batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) repad = False if self.config._attn_implementation == "flash_attention_2": @@ -1056,10 +1067,7 @@ def forward( attention_mask, output_attentions=output_attentions ) - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embeddings(input_ids) + hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) for encoder_layer in self.layers: if output_hidden_states: @@ -1195,10 +1203,11 @@ def compiled_head(self, output: torch.Tensor) -> torch.Tensor: ) def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, + 
inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, @@ -1218,16 +1227,22 @@ def forward( batch_size, seq_len = input_ids.shape[:2] if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) - with torch.no_grad(): - input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( - inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + if inputs_embeds is None: + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + else: + inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( + inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) outputs = self.model( - input_ids, + input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, + inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -1300,10 +1315,11 @@ def __init__(self, config: ModernBertConfig): ) def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, @@ -1325,10 +1341,11 @@ def forward( self._maybe_set_compile() outputs = self.model( - input_ids, + input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, + inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -1411,10 +1428,11 @@ def __init__(self, config: ModernBertConfig): ) def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, @@ -1433,10 +1451,11 @@ def forward( self._maybe_set_compile() outputs = self.model( - input_ids, + input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, + inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index 4fce0cd86352f0..febeefe8ac17d4 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -146,7 +146,11 @@ def get_config(self): # If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error # that compilation doesn't work. Users can then set compile=False when loading the model, # much like here. We're testing whether it works once they've done that. 
- if test_name == "test_retain_grad_hidden_states_attentions": + + # If we're testing `test_inputs_embeds_matches_input_ids`, then we'd like to test with `reference_compile` + # set to False, otherwise the input_ids with compiled input embeddings will not match the inputs_embeds + # with atol=1e-8 and rtol=1e-5 + if test_name in ("test_retain_grad_hidden_states_attentions", "test_inputs_embeds_matches_input_ids"): config.reference_compile = False # Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager # as the others don't support outputted attentions From 1aa3f4f450846037cc609a8c34328132543bb455 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 30 Dec 2024 14:17:37 +0100 Subject: [PATCH 3/3] Propagate inputs_embeds to ModernBertForMaskedLM correctly Also reintroduce inputs_embeds test --- .../models/modernbert/modeling_modernbert.py | 11 +++++++++-- .../models/modernbert/modular_modernbert.py | 19 ++++++++++++++----- .../modernbert/test_modeling_modernbert.py | 6 +----- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 1c6542d7b7a73a..d03128541e2334 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -1073,9 +1073,16 @@ def forward( if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: - batch_size, seq_len = input_ids.shape[:2] + if batch_size is None and seq_len is None: + if inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) + if inputs_embeds is None: with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 5e3c44386fb759..992ae3475afa29 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -464,7 +464,9 @@ def __init__(self, config: ModernBertConfig): def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) - def forward(self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None + ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = self.drop(self.norm(inputs_embeds)) else: @@ -1037,7 +1039,7 @@ def forward( self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) if batch_size is None and seq_len is None: - if inputs_embeds is not None: + if inputs_embeds is not None: batch_size, seq_len = inputs_embeds.shape[:2] else: batch_size, seq_len = input_ids.shape[:2] @@ -1066,7 +1068,7 @@ def forward( attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions ) - + hidden_states = 
self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) for encoder_layer in self.layers: @@ -1224,9 +1226,16 @@ def forward( if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: - batch_size, seq_len = input_ids.shape[:2] + if batch_size is None and seq_len is None: + if inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + batch_size, seq_len = input_ids.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) + if inputs_embeds is None: with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index febeefe8ac17d4..08e77505b5b787 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -146,7 +146,7 @@ def get_config(self): # If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error # that compilation doesn't work. Users can then set compile=False when loading the model, # much like here. We're testing whether it works once they've done that. - + # If we're testing `test_inputs_embeds_matches_input_ids`, then we'd like to test with `reference_compile` # set to False, otherwise the input_ids with compiled input embeddings will not match the inputs_embeds # with atol=1e-8 and rtol=1e-5 @@ -298,10 +298,6 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - @unittest.skip("ModernBert doesn't use `inputs_embeds` as input.") - def test_inputs_embeds(self): - pass - def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
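As an end-to-end sketch of what PATCH 3/3 enables, `ModernBertForMaskedLM` can now be driven from `inputs_embeds` alone. The checkpoint name and the `attn_implementation="eager"` / `reference_compile=False` loading flags below are assumptions for the example, not part of the patches.

```python
# Illustrative sketch only; model id and loading flags are assumptions, not part of the diff.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = AutoModelForMaskedLM.from_pretrained(
    "answerdotai/ModernBERT-base", attn_implementation="eager", reference_compile=False
).eval()

encoded = tokenizer("The capital of France is [MASK].", return_tensors="pt")
# Build inputs_embeds from the embedding table and drop input_ids entirely.
token_embeds = model.get_input_embeddings()(encoded["input_ids"])

with torch.no_grad():
    logits = model(inputs_embeds=token_embeds, attention_mask=encoded["attention_mask"]).logits

mask_positions = (encoded["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_ids = logits[0, mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))  # expected to decode to something like " Paris"
```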