Skip to content

Commit

Permalink
PaliGemma: fix device and dtype assignments (#31008)
Browse files Browse the repository at this point in the history
* fix devices and dtype assignments

* [run-slow]paligemma
  • Loading branch information
molbap authored May 24, 2024
1 parent deba765 commit bdb9106
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/transformers/models/paligemma/modeling_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,15 @@ def _merge_input_ids_with_image_features(
pad_mask = input_ids == self.pad_token_id

# expand masks to match embedding dimension
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
# insert padding and text token embeddings
final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
# insert image embeddings - the image mask is always less than or equal to the sentence in length
final_embedding = final_embedding.masked_scatter(
image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
)
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
if attention_mask is not None:
Expand All @@ -329,10 +330,12 @@ def _merge_input_ids_with_image_features(
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
# unmask the prefill
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
token_type_ids[:, None, None, :] == 0, 0
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
Expand Down Expand Up @@ -484,7 +487,7 @@ def forward(
# we use the input attention mask to shift the logits and labels, because it is 2D.
shift_attention_mask = input_attention_mask[..., 1:]
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
Expand Down

0 comments on commit bdb9106

Please sign in to comment.