diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index 9e4677c1fdd099..c3dcb26ad883dd 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -205,12 +205,17 @@ def __init__(self, config: ModernBertConfig):
     def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
         return self.drop(self.norm(self.tok_embeddings(input_ids)))
 
-    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
-        hidden_states = (
-            self.compiled_embeddings(input_ids)
-            if self.config.reference_compile
-            else self.drop(self.norm(self.tok_embeddings(input_ids)))
-        )
+    def forward(
+        self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = self.drop(self.norm(inputs_embeds))
+        else:
+            hidden_states = (
+                self.compiled_embeddings(input_ids)
+                if self.config.reference_compile
+                else self.drop(self.norm(self.tok_embeddings(input_ids)))
+            )
         return hidden_states
 
 
@@ -791,6 +796,10 @@ def _pad_modernbert_output(
             config.n_positions - 1]`.
 
             [What are position IDs?](../glossary#position-ids)
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
         indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
             Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
         cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
@@ -842,10 +851,11 @@ def set_input_embeddings(self, value):
     )
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[int] = None,
@@ -861,35 +871,49 @@ def forward(
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
 
         self._maybe_set_compile()
-        self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+
+        if input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
 
         if batch_size is None and seq_len is None:
-            batch_size, seq_len = input_ids.shape[:2]
+            if inputs_embeds is not None:
+                batch_size, seq_len = inputs_embeds.shape[:2]
+            else:
+                batch_size, seq_len = input_ids.shape[:2]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
+            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
 
         repad = False
         if self.config._attn_implementation == "flash_attention_2":
             if indices is None and cu_seqlens is None and max_seqlen is None:
                 repad = True
-                with torch.no_grad():
-                    input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
-                        inputs=input_ids, attention_mask=attention_mask
+                if inputs_embeds is None:
+                    with torch.no_grad():
+                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+                            inputs=input_ids, attention_mask=attention_mask
+                        )
+                else:
+                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+                        inputs=inputs_embeds, attention_mask=attention_mask
                     )
         else:
             if position_ids is None:
-                position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
+                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
 
             attention_mask, sliding_window_mask = self._update_attention_mask(
                 attention_mask, output_attentions=output_attentions
             )
 
-        hidden_states = self.embeddings(input_ids)
+        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
 
         for encoder_layer in self.layers:
             if output_hidden_states:
@@ -1025,10 +1049,11 @@ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
     )
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
@@ -1045,19 +1070,32 @@ def forward(
 
         if self.config._attn_implementation == "flash_attention_2":
             if indices is None and cu_seqlens is None and max_seqlen is None:
-                batch_size, seq_len = input_ids.shape[:2]
+                if batch_size is None and seq_len is None:
+                    if inputs_embeds is not None:
+                        batch_size, seq_len = inputs_embeds.shape[:2]
+                    else:
+                        batch_size, seq_len = input_ids.shape[:2]
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+
                 if attention_mask is None:
-                    attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
-                with torch.no_grad():
-                    input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
-                        inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
+                if inputs_embeds is None:
+                    with torch.no_grad():
+                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+                        )
+                else:
+                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                     )
 
         outputs = self.model(
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
@@ -1130,10 +1168,11 @@ def __init__(self, config: ModernBertConfig):
     )
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
@@ -1155,10 +1194,11 @@ def forward(
         self._maybe_set_compile()
 
         outputs = self.model(
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
@@ -1241,10 +1281,11 @@ def __init__(self, config: ModernBertConfig):
     )
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
@@ -1263,10 +1304,11 @@ def forward(
         self._maybe_set_compile()
 
         outputs = self.model(
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
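For context when reviewing the modeling changes above, here is a minimal sketch of how the new `inputs_embeds` path is meant to be exercised from user code. The checkpoint name `answerdotai/ModernBERT-base` is just the public base checkpoint used as an example, and `reference_compile=False` is only passed to keep the two paths numerically comparable (mirroring the test-side change later in this diff); it is not a requirement of the API.

```python
import torch
from transformers import AutoTokenizer, ModernBertModel

# Sketch: load a ModernBERT checkpoint; reference_compile=False keeps the
# embedding path uncompiled so the two paths below stay numerically close.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = ModernBertModel.from_pretrained("answerdotai/ModernBERT-base", reference_compile=False)

inputs = tokenizer("ModernBERT now accepts inputs_embeds.", return_tensors="pt")

# Path 1: the usual token-id path.
out_ids = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Path 2: build the embeddings ourselves and pass them directly.
# Exactly one of input_ids / inputs_embeds may be given; otherwise the new
# XOR check in ModernBertModel.forward raises a ValueError.
embeds = model.get_input_embeddings()(inputs["input_ids"])
out_embeds = model(inputs_embeds=embeds, attention_mask=inputs["attention_mask"])

print(torch.allclose(out_ids.last_hidden_state, out_embeds.last_hidden_state, atol=1e-6))
```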
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index b747940f3b0623..aaad42a7e67c93 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -464,12 +464,17 @@ def __init__(self, config: ModernBertConfig):
     def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
         return self.drop(self.norm(self.tok_embeddings(input_ids)))
 
-    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
-        hidden_states = (
-            self.compiled_embeddings(input_ids)
-            if self.config.reference_compile
-            else self.drop(self.norm(self.tok_embeddings(input_ids)))
-        )
+    def forward(
+        self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = self.drop(self.norm(inputs_embeds))
+        else:
+            hidden_states = (
+                self.compiled_embeddings(input_ids)
+                if self.config.reference_compile
+                else self.drop(self.norm(self.tok_embeddings(input_ids)))
+            )
         return hidden_states
 
 
@@ -944,6 +949,10 @@ def resize_token_embeddings(self, *args, **kwargs):
             config.n_positions - 1]`.
 
            [What are position IDs?](../glossary#position-ids)
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
         indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
             Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
         cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
@@ -995,10 +1004,11 @@ def set_input_embeddings(self, value):
     )
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[int] = None,
@@ -1014,35 +1024,49 @@ def forward(
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
 
         self._maybe_set_compile()
-        self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+
+        if input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
 
         if batch_size is None and seq_len is None:
-            batch_size, seq_len = input_ids.shape[:2]
+            if inputs_embeds is not None:
+                batch_size, seq_len = inputs_embeds.shape[:2]
+            else:
+                batch_size, seq_len = input_ids.shape[:2]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
+            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
 
         repad = False
         if self.config._attn_implementation == "flash_attention_2":
             if indices is None and cu_seqlens is None and max_seqlen is None:
                 repad = True
-                with torch.no_grad():
-                    input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
-                        inputs=input_ids, attention_mask=attention_mask
+                if inputs_embeds is None:
+                    with torch.no_grad():
+                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+                            inputs=input_ids, attention_mask=attention_mask
+                        )
+                else:
+                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+                        inputs=inputs_embeds, attention_mask=attention_mask
                    )
         else:
             if position_ids is None:
-                position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
+                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
 
             attention_mask, sliding_window_mask = self._update_attention_mask(
                 attention_mask, output_attentions=output_attentions
             )
 
-        hidden_states = self.embeddings(input_ids)
+        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
 
         for encoder_layer in self.layers:
             if output_hidden_states:
@@ -1178,10 +1202,11 @@ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
     )
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
@@ -1198,19 +1223,32 @@ def forward(
 
         if self.config._attn_implementation == "flash_attention_2":
             if indices is None and cu_seqlens is None and max_seqlen is None:
-                batch_size, seq_len = input_ids.shape[:2]
+                if batch_size is None and seq_len is None:
+                    if inputs_embeds is not None:
+                        batch_size, seq_len = inputs_embeds.shape[:2]
+                    else:
+                        batch_size, seq_len = input_ids.shape[:2]
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+
                 if attention_mask is None:
-                    attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
-                with torch.no_grad():
-                    input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
-                        inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
+                if inputs_embeds is None:
+                    with torch.no_grad():
+                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+                        )
+                else:
+                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                     )
 
         outputs = self.model(
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
@@ -1283,10 +1321,11 @@ def __init__(self, config: ModernBertConfig):
     )
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
@@ -1308,10 +1347,11 @@ def forward(
         self._maybe_set_compile()
 
         outputs = self.model(
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
@@ -1394,10 +1434,11 @@ def __init__(self, config: ModernBertConfig):
     )
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
@@ -1416,10 +1457,11 @@ def forward(
         self._maybe_set_compile()
 
         outputs = self.model(
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
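The flash-attention-2 branches above route `inputs_embeds` through `_unpad_modernbert_input` exactly as they already did for `input_ids`. As a rough, standalone illustration of what unpadding means for an embedding tensor (a toy sketch of the idea, not the library helper, which also handles `position_ids` and `labels`): the padded `(batch, seq_len, hidden)` tensor is flattened to `(total_tokens, hidden)` by keeping only non-padding positions, with cumulative sequence lengths produced for the varlen attention kernel.

```python
import torch

def unpad_embeddings_sketch(inputs_embeds: torch.Tensor, attention_mask: torch.Tensor):
    """Toy version of the unpadding step: keep only non-padding rows.

    inputs_embeds: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1.
    Returns (total_tokens, hidden) embeddings, flat indices, cu_seqlens, max_seqlen.
    """
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)  # tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    hidden = inputs_embeds.shape[-1]
    unpadded = inputs_embeds.reshape(-1, hidden)[indices]  # drop padded positions
    return unpadded, indices, cu_seqlens, max_seqlen

# Example: batch of 2, second sequence has 2 padding tokens.
embeds = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
unpadded, indices, cu_seqlens, max_seqlen = unpad_embeddings_sketch(embeds, mask)
print(unpadded.shape, cu_seqlens.tolist(), max_seqlen)  # torch.Size([8, 8]) [0, 5, 8] 5
```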
diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py
index 4fce0cd86352f0..08e77505b5b787 100644
--- a/tests/models/modernbert/test_modeling_modernbert.py
+++ b/tests/models/modernbert/test_modeling_modernbert.py
@@ -146,7 +146,11 @@ def get_config(self):
         # If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error
         # that compilation doesn't work. Users can then set compile=False when loading the model,
         # much like here. We're testing whether it works once they've done that.
-        if test_name == "test_retain_grad_hidden_states_attentions":
+
+        # If we're testing `test_inputs_embeds_matches_input_ids`, then we'd like to test with `reference_compile`
+        # set to False, otherwise the input_ids with compiled input embeddings will not match the inputs_embeds
+        # with atol=1e-8 and rtol=1e-5
+        if test_name in ("test_retain_grad_hidden_states_attentions", "test_inputs_embeds_matches_input_ids"):
             config.reference_compile = False
         # Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager
         # as the others don't support outputted attentions
@@ -294,10 +298,6 @@ def test_initialization(self):
                     msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                 )
 
-    @unittest.skip("ModernBert doesn't use `inputs_embeds` as input.")
-    def test_inputs_embeds(self):
-        pass
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
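Removing the `unittest.skip` above opts ModernBERT back into the common `inputs_embeds` tests, which roughly assert that feeding `get_input_embeddings()(input_ids)` as `inputs_embeds` reproduces the `input_ids` output. Below is a condensed, hedged sketch of that equivalence check; the tiny config values are made up for illustration, and `reference_compile=False` mirrors the test comment above (the compiled embedding path can differ at tight tolerances).

```python
import torch
from transformers import ModernBertConfig, ModernBertModel

# Tiny illustrative config so the sketch runs quickly; values are not meaningful.
config = ModernBertConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=48,
    num_hidden_layers=2,
    num_attention_heads=4,
    reference_compile=False,  # keep both paths on the uncompiled embedding code
)
model = ModernBertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 10))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out_from_ids = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    inputs_embeds = model.get_input_embeddings()(input_ids)
    out_from_embeds = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask).last_hidden_state

# Same tolerances the test comment above refers to.
torch.testing.assert_close(out_from_ids, out_from_embeds, atol=1e-8, rtol=1e-5)
```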