make correct dispatch operations
ArthurZucker committed Dec 3, 2023
1 parent 587b8e6 commit 33b7dc8
Showing 2 changed files with 24 additions and 18 deletions.
27 changes: 19 additions & 8 deletions src/transformers/models/llava/modeling_llava.py
@@ -283,16 +283,27 @@ def _merge_input_ids_with_image_features(self, image_features, input_embeds, inp
         # 1. Create a mask to know where image tokens are
         image_token_mask = (input_ids == self.config.image_token_index)
         num_image_tokens = torch.sum(image_token_mask, dim = -1)
-        max_embed_dim = num_image_tokens.max()
-
-        # 2. Create the full embedding
-        final_embedding = torch.zeros(input_ids.shape[0], self.config.text_config.max_position_embeddings, input_embeds.shape[-1])
-        nb_text_tokens_per_images = image_features.shape[1]-1
-
+        nb_text_tokens_per_images = image_features.shape[1]
+
+        # 2. Compute the positions where text should be written
+        text_to_overwrite = torch.cumsum(image_token_mask*nb_text_tokens_per_images + 1, -1)-1
+
+        # 3. Create the full embedding, already padded to the maximum position
+        max_embed_dim = text_to_overwrite.max()
+        final_embedding = torch.zeros(input_ids.shape[0], max_embed_dim+1, input_embeds.shape[-1])
+
         # 3. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
         # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding[:, torch.cumsum(image_token_mask*nb_text_tokens_per_images + 1, -1)] = input_embeds
-        final_embedding[:, torch.range(0, image_features.shape[1], dtype=torch.long) * num_image_tokens] = image_features.split(num_image_tokens)[0]
+        final_embedding.scatter_(-2, text_to_overwrite.unsqueeze(2).expand_as(input_embeds), input_embeds)
+
+        # equivalent to
+        # batch_indices = torch.arange(final_embedding.size(0)).view(-1, 1).expand_as(text_to_overwrite)
+        # final_embedding[batch_indices,text_to_overwrite] = input_embeds # we also write on the start image token
+
+        # 4. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling (apart from the padding)
+        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
+        image_to_overwrite &= image_to_overwrite.cumsum(-1) <= (num_image_tokens * nb_text_tokens_per_images)[:, None]
+        final_embedding[image_to_overwrite] = image_features
         # We can have multiple images in a single batch, hence we use different
         # indexes for image and text.
         return input_embeds, attention_mask, position_ids
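The new hunk sizes `final_embedding` from the computed write positions instead of `max_position_embeddings`, scatters the text embeddings into those positions, and then drops the image features into whatever slots are still empty. Below is a minimal, self-contained sketch of that placement logic; the toy sizes (4 tokens per sequence, 3 image "patches", hidden size 5) and the placeholder id 32000 are illustrative assumptions rather than values taken from the commit, and the final boolean-mask assignment flattens `image_features` explicitly for clarity.

# Toy reproduction of the merge logic above (illustrative sizes only).
import torch

hidden_size = 5
nb_text_tokens_per_images = 3          # stands in for image_features.shape[1]
image_token_index = 32000              # hypothetical <image> placeholder id

input_ids = torch.tensor([[1, 32000, 7, 8],
                          [1, 9, 32000, 10]])
input_embeds = torch.randn(2, 4, hidden_size)                             # one text embedding per token
image_features = torch.randn(2, nb_text_tokens_per_images, hidden_size)   # one row of patches per image

# 1. Mask of the <image> placeholder tokens
image_token_mask = input_ids == image_token_index
num_image_tokens = torch.sum(image_token_mask, dim=-1)

# 2. Every <image> token reserves nb_text_tokens_per_images extra slots,
#    all other tokens take exactly one slot
text_to_overwrite = torch.cumsum(image_token_mask * nb_text_tokens_per_images + 1, -1) - 1
# -> tensor([[0, 4, 5, 6], [0, 1, 5, 6]])

# 3. Allocate the merged sequence and scatter the text embeddings to their new positions
max_embed_dim = int(text_to_overwrite.max())   # cast to a plain int for clarity
final_embedding = torch.zeros(input_ids.shape[0], max_embed_dim + 1, hidden_size)
final_embedding.scatter_(-2, text_to_overwrite.unsqueeze(2).expand_as(input_embeds), input_embeds)

# 4. Slots that are still all-zero (up to the number of image positions) receive the image features
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
image_to_overwrite &= image_to_overwrite.cumsum(-1) <= (num_image_tokens * nb_text_tokens_per_images)[:, None]
final_embedding[image_to_overwrite] = image_features.reshape(-1, hidden_size)

print(final_embedding.shape)  # torch.Size([2, 7, 5])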
15 changes: 5 additions & 10 deletions src/transformers/models/llava/processing_llava.py
@@ -50,15 +50,18 @@ def __init__(self, image_processor, tokenizer=None):
             raise ValueError("You need to specify an `image_processor`.")
         if tokenizer is None:
             raise ValueError("You need to specify a `tokenizer`.")
-        tokenizer.add_tokens(AddedToken("<image>", special = True, normalized = False))
+        tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False))
         super().__init__(image_processor, tokenizer)
         self.current_processor = self.image_processor

     def __call__(
         self,
         text=None,
         images=None,
+        padding=None,
+        truncation=None,
         transform: Callable = None,
+        max_length=None,
         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
     ) -> BatchEncoding:
         """This method takes batched or non-batched text made of text and images and converts them into text that
@@ -85,14 +88,6 @@ def __call__(
                 A custom transform function that accepts a single image can be passed for training. For example,
                 `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
                 set of transforms will be applied to the images
-            add_eos_token (`bool`, *optional*, defaults to `False`):
-                Adds `eos_token` at the end of the final prompt if True`
-            add_end_of_utterance_token (`bool`, *optional*)
-                Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
-                image). If `None` the tokenizer will be checked instead and if this token is found in
-                `additional_special_tokens` then the value will be `True`.
-            debug (`bool`, *optional*, defaults to `False`):
-                `True` value will help debug prompt generation by dumping useful information
             return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
                 The type of tensors to return. Can be one of:
                     - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
@@ -115,7 +110,7 @@ def __call__(
            pixel_values = None

         # Attention mask have to be created later on? Or not?
-        text_inputs = self.tokenizer(text, return_tensors=return_tensors)
+        text_inputs = self.tokenizer(text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)

         return BatchFeature(data={**text_inputs,"pixel_values": pixel_values})
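With `padding`, `truncation`, and `max_length` now forwarded to the tokenizer, a batch of prompts of different lengths can be handled in one processor call. A hedged usage sketch; the checkpoint name, image URL, and prompt format are placeholders rather than values from this commit, and padding assumes the tokenizer defines a pad token.

import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # placeholder checkpoint

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompts = [
    "USER: <image>\nWhat is shown in this image? ASSISTANT:",
    "USER: <image>\nDescribe the two animals. ASSISTANT:",
]
inputs = processor(
    text=prompts,
    images=[image, image],
    padding=True,        # forwarded to the tokenizer by the updated __call__
    truncation=True,
    max_length=512,
    return_tensors="pt",
)
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)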
