Paligemma causal attention mask (#30967)
* PaliGemma working causal attention

* Formatting

* Style

* Docstrings + remove commented code

* Update docstring for PaliGemma Config

* PaliGemma - add separator ind to model/labels

* Refactor + docstring paligemma processor method

* Style

* return token type ids when tokenizing labels

* use token type ids when building causal mask

* add token type ids to tester

* remove separator from config

* fix style

* don't ignore separator

* add processor documentation

* simplify tokenization

* fix causal mask

* style

* fix label propagation, revert suffix naming

* fix style

* fix labels tokenization

* [run-slow]paligemma

* add eos if suffixes are present

* [run-slow]paligemma

* [run-slow]paligemma

* add missing tokens to fast version

* Apply suggestions from code review

Co-authored-by: Arthur <[email protected]>

* fix style

* [run-slow]paligemma

---------

Co-authored-by: Peter Robicheaux <[email protected]>
Co-authored-by: Arthur <[email protected]>
3 people committed May 22, 2024
1 parent e5b788a commit 8282db5
Showing 3 changed files with 113 additions and 47 deletions.
53 changes: 41 additions & 12 deletions src/transformers/models/paligemma/modeling_paligemma.py
@@ -282,9 +282,14 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
    ):
        _, _, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        dtype, device = inputs_embeds.dtype, inputs_embeds.device
        min_dtype = torch.finfo(dtype).min

        scaled_image_features = image_features / (self.config.hidden_size**0.5)
        final_embedding = torch.zeros(
            batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
@@ -305,24 +310,43 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
            image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
        )
        final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
        if attention_mask is not None:
            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
        else:
            position_ids = None

        final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
        final_attention_mask_4d = final_attention_mask_4d.float().expand(
            -1, self.config.text_config.num_key_value_heads, -1, -1
        )

        # position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
        # position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
        position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
        if token_type_ids is not None and labels is not None:
            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
            target_length = cache_position[-1] + 1
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                # unmask the prefill
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    token_type_ids[:, None, None, :] == 0, 0
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        if labels is not None:
            final_labels = torch.full(
                (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
            )
            final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
        else:
            causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
            causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
            final_labels = None
        return final_embedding, final_attention_mask_4d, final_labels, position_ids
        return final_embedding, causal_mask, final_labels, position_ids

    @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
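
For readers following the hunk above: the branch guarded by `token_type_ids is not None and labels is not None` builds a mixed attention mask for training, in which image and prefix tokens (token type 0) attend to each other bidirectionally while suffix tokens (token type 1) attend causally, and padded positions stay blocked. The standalone sketch below reproduces that semantics in simplified form. It is not the library helper: the function name is made up for illustration, and the padding handling is a simplified variant of the in-place logic in the diff.

import torch

def sketch_prefix_lm_mask(token_type_ids, attention_mask, dtype=torch.float32):
    """Additive mask of shape (batch, 1, seq, seq): 0 where attention is allowed,
    dtype-min where it is blocked (token type 0 = image/prefix, 1 = suffix)."""
    bsz, seq_len = token_type_ids.shape
    min_dtype = torch.finfo(dtype).min
    # Start from a standard causal mask: block the strict upper triangle.
    mask = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
    mask = mask[None, None, :, :].expand(bsz, 1, -1, -1).clone()
    # Unblock every key position that belongs to the image/prefix segment,
    # so prefix tokens are attended bidirectionally.
    mask = mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0)
    # Re-block padded key positions everywhere (simplified padding handling).
    mask = mask.masked_fill(attention_mask[:, None, None, :] == 0, min_dtype)
    return mask

# Toy example: two image/prefix tokens, two suffix tokens, one pad token.
token_type_ids = torch.tensor([[0, 0, 1, 1, 1]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
print(sketch_prefix_lm_mask(token_type_ids, attention_mask)[0, 0] == 0)

In the toy example the first two key columns are visible to every query, the suffix columns are visible only at or after their own position, and the final padded column is blocked for every query.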
@@ -333,6 +357,7 @@ def forward(
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
@@ -396,8 +421,10 @@ def forward(
                selected_image_feature = image_outputs.last_hidden_state
                image_features = self.multi_modal_projector(selected_image_feature)

                if cache_position is None:
                    cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                    image_features, inputs_embeds, input_ids, attention_mask, labels
                    image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
                )

            else:
@@ -486,6 +513,7 @@ def prepare_inputs_for_generation(
        cache_position=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        **kwargs,
    ):
        past_length = 0
@@ -544,6 +572,7 @@ def prepare_inputs_for_generation(
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs
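
Taken together, the processor-side changes listed in the commit message ("return token type ids when tokenizing labels", "add eos if suffixes are present") and the model-side changes in this file suggest a training call along the following lines. This is a hedged sketch rather than part of the diff: it assumes the processor accepts a `suffix` argument and returns `token_type_ids` and `labels` alongside the usual inputs, and the checkpoint name is only an example.

import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-pt-224"  # example checkpoint, not prescribed by this commit
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

image = Image.new("RGB", (224, 224))  # placeholder image
inputs = processor(
    text="caption en",       # prefix: attended bidirectionally (token type 0)
    suffix="a test image",   # suffix: tokenized as labels, attended causally (token type 1)
    images=image,
    return_tensors="pt",
)

# token_type_ids flow through forward() into _merge_input_ids_with_image_features,
# which builds the mixed full/causal mask shown in the hunks above.
outputs = model(**inputs)
print(outputs.loss)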