Add inputs_embeds param to ModernBertModel (#35373)
* update modular_modernbert -- add inputs_embeds param to ModernBertModel

* Fix implementation issues; extend to other classes; docstring

First of all, inputs_embeds shouldn't fully replace `self.embeddings(input_ids)`, because that call also applies layer normalization and dropout. So both input_ids and inputs_embeds are now passed to ModernBertEmbeddings, much like how BertEmbeddings is implemented.

I also added `inputs_embeds` to the docstring, and propagated the changes to the other model classes.

I also introduced an error that is raised if input_ids and inputs_embeds are both provided, or if neither is provided.

Lastly, I fixed an issue where the device used to build the default attention_mask was derived solely from input_ids.

* Propagate inputs_embeds to ModernBertForMaskedLM correctly

Also reintroduce the inputs_embeds test (a usage sketch of the new path is included below).
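As a quick illustration of the new parameter (not part of the commit itself), the snippet below feeds the output of the model's token embedding layer back in through `inputs_embeds`. The checkpoint name and the use of `get_input_embeddings()` are assumptions made for this sketch. Because the embeddings module still applies layer norm and dropout to `inputs_embeds`, both calls should produce matching hidden states in eval mode.

```python
import torch
from transformers import AutoTokenizer, ModernBertModel

# Sketch only: the checkpoint name below is an assumption, not part of this commit.
model_name = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = ModernBertModel.from_pretrained(model_name).eval()  # eval() disables dropout

inputs = tokenizer("Paris is the capital of France.", return_tensors="pt")

with torch.no_grad():
    # Usual path: the model converts input_ids to embeddings internally.
    out_from_ids = model(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    ).last_hidden_state

    # New path: precompute token embeddings and pass them via inputs_embeds.
    token_embeds = model.get_input_embeddings()(inputs["input_ids"])
    out_from_embeds = model(
        inputs_embeds=token_embeds, attention_mask=inputs["attention_mask"]
    ).last_hidden_state

# Layer norm and dropout are still applied inside ModernBertEmbeddings,
# so both paths agree (dropout is inactive in eval mode).
print(torch.allclose(out_from_ids, out_from_embeds, atol=1e-6))
```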

---------

Co-authored-by: Tom Aarsen <[email protected]>
jxmorris12 and tomaarsen authored Jan 9, 2025
1 parent 1b2f942 commit 832c619
Showing 3 changed files with 141 additions and 57 deletions.
94 changes: 68 additions & 26 deletions src/transformers/models/modernbert/modeling_modernbert.py
@@ -205,12 +205,17 @@ def __init__(self, config: ModernBertConfig):
def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
return self.drop(self.norm(self.tok_embeddings(input_ids)))

def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
hidden_states = (
self.compiled_embeddings(input_ids)
if self.config.reference_compile
else self.drop(self.norm(self.tok_embeddings(input_ids)))
)
def forward(
self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = self.drop(self.norm(inputs_embeds))
else:
hidden_states = (
self.compiled_embeddings(input_ids)
if self.config.reference_compile
else self.drop(self.norm(self.tok_embeddings(input_ids)))
)
return hidden_states


@@ -791,6 +796,10 @@ def _pad_modernbert_output(
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
@@ -842,10 +851,11 @@ def set_input_embeddings(self, value):
)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
@@ -861,35 +871,49 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None

self._maybe_set_compile()
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

if input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

if batch_size is None and seq_len is None:
batch_size, seq_len = input_ids.shape[:2]
if inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
else:
batch_size, seq_len = input_ids.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

repad = False
if self.config._attn_implementation == "flash_attention_2":
if indices is None and cu_seqlens is None and max_seqlen is None:
repad = True
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask
if inputs_embeds is None:
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask
)
else:
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=inputs_embeds, attention_mask=attention_mask
)
else:
if position_ids is None:
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

attention_mask, sliding_window_mask = self._update_attention_mask(
attention_mask, output_attentions=output_attentions
)

hidden_states = self.embeddings(input_ids)
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

for encoder_layer in self.layers:
if output_hidden_states:
@@ -1025,10 +1049,11 @@ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
@@ -1045,19 +1070,32 @@ def forward(

if self.config._attn_implementation == "flash_attention_2":
if indices is None and cu_seqlens is None and max_seqlen is None:
batch_size, seq_len = input_ids.shape[:2]
if batch_size is None and seq_len is None:
if inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
else:
batch_size, seq_len = input_ids.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

if inputs_embeds is None:
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
)
else:
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
)

outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
@@ -1130,10 +1168,11 @@ def __init__(self, config: ModernBertConfig):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
@@ -1155,10 +1194,11 @@ def forward(
self._maybe_set_compile()

outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
@@ -1241,10 +1281,11 @@ def __init__(self, config: ModernBertConfig):
)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
@@ -1263,10 +1304,11 @@ def forward(
self._maybe_set_compile()

outputs = self.model(
input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
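Below is a small sketch (not from the commit) of how the new argument validation in the diff behaves: exactly one of `input_ids` or `inputs_embeds` must be provided, otherwise a `ValueError` is raised. The tiny config values and the use of `get_input_embeddings()` are illustrative assumptions.

```python
import torch
from transformers import ModernBertConfig, ModernBertModel

# Tiny randomly initialized model, only to exercise the argument validation.
# The config values here are illustrative assumptions, not from the commit.
config = ModernBertConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = ModernBertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
inputs_embeds = model.get_input_embeddings()(input_ids)

with torch.no_grad():
    model(input_ids=input_ids)          # OK: only input_ids
    model(inputs_embeds=inputs_embeds)  # OK: only inputs_embeds

    for bad_kwargs in ({}, {"input_ids": input_ids, "inputs_embeds": inputs_embeds}):
        try:
            model(**bad_kwargs)  # neither or both -> error
        except ValueError as err:
            print(err)  # "You must specify exactly one of input_ids or inputs_embeds"
```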