Skip to content

Commit

Permalink
vectorize works for batch of images and text
Browse files: browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Dec 3, 2023
1 parent 33b7dc8 commit ebec096
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/transformers/models/llava/modeling_llava.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -302,10 +302,8 @@ def _merge_input_ids_with_image_features(self, image_features, input_embeds, inp

# 4. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling (apart from the padding)
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
image_to_overwrite &= image_to_overwrite.cumsum(-1) <= (num_image_tokens * nb_text_tokens_per_images)[:, None]
final_embedding[image_to_overwrite] = image_features
# We can have multiple images in a single batch, hence we use different
# indexes for image and text.
image_to_overwrite &= image_to_overwrite.cumsum(-1) <= (num_image_tokens * nb_text_tokens_per_images)[:, None]
final_embedding[image_to_overwrite] = image_features.reshape(-1, 4096)
return input_embeds, attention_mask, position_ids

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
Expand Down

0 comments on commit ebec096

Please sign in to comment.