From 33b7dc8c613dacd4e9efead103ad7d93679ed7fb Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Sun, 3 Dec 2023 12:47:40 +0100
Subject: [PATCH] make correct dispatch operations

---
 .../models/llava/modeling_llava.py   | 27 +++++++++++++------
 .../models/llava/processing_llava.py | 15 ++++-------
 2 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 751f292f3ee8f5..6eb44b1858960a 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -283,16 +283,27 @@ def _merge_input_ids_with_image_features(self, image_features, input_embeds, inp
         # 1. Create a mask to know where image tokens are
         image_token_mask = (input_ids == self.config.image_token_index)
         num_image_tokens = torch.sum(image_token_mask, dim = -1)
-        max_embed_dim = num_image_tokens.max()
-
-        # 2. Create the full embedding
-        final_embedding = torch.zeros(input_ids.shape[0], self.config.text_config.max_position_embeddings, input_embeds.shape[-1])
-        nb_text_tokens_per_images = image_features.shape[1]-1
-
+        nb_text_tokens_per_images = image_features.shape[1]
+
+        # 2. Compute the positions where text should be written
+        text_to_overwrite = torch.cumsum(image_token_mask*nb_text_tokens_per_images + 1, -1)-1
+
+        # 3. Create the full embedding, already padded to the maximum position
+        max_embed_dim = text_to_overwrite.max()
+        final_embedding = torch.zeros(input_ids.shape[0], max_embed_dim+1, input_embeds.shape[-1])
+
         # 3. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
         # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding[:, torch.cumsum(image_token_mask*nb_text_tokens_per_images + 1, -1)] = input_embeds
-        final_embedding[:, torch.range(0, image_features.shape[1], dtype=torch.long) * num_image_tokens] = image_features.split(num_image_tokens)[0]
+        final_embedding.scatter_(-2, text_to_overwrite.unsqueeze(2).expand_as(input_embeds), input_embeds)
+
+        # equivalent to
+        # batch_indices = torch.arange(final_embedding.size(0)).view(-1, 1).expand_as(text_to_overwrite)
+        # final_embedding[batch_indices, text_to_overwrite] = input_embeds  # we also write on the start image token
+
+        # 4. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling (apart from the padding)
+        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
+        image_to_overwrite &= image_to_overwrite.cumsum(-1) <= (num_image_tokens * nb_text_tokens_per_images)[:, None]
+        final_embedding[image_to_overwrite] = image_features
         # We can have multiple images in a single batch, hence we use different
         # indexes for image and text.
         return input_embeds, attention_mask, position_ids

diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py
index 7833e898abd5d0..ddf31ade81cb43 100644
--- a/src/transformers/models/llava/processing_llava.py
+++ b/src/transformers/models/llava/processing_llava.py
@@ -50,7 +50,7 @@ def __init__(self, image_processor, tokenizer=None):
             raise ValueError("You need to specify an `image_processor`.")
         if tokenizer is None:
             raise ValueError("You need to specify a `tokenizer`.")
-        tokenizer.add_tokens(AddedToken("<image>", special = True, normalized = False))
+        tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False))
         super().__init__(image_processor, tokenizer)
         self.current_processor = self.image_processor
 
@@ -58,7 +58,10 @@ def __call__(
         self,
         text=None,
         images=None,
+        padding=None,
+        truncation=None,
         transform: Callable = None,
+        max_length=None,
         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
     ) -> BatchEncoding:
         """This method takes batched or non-batched text made of text and images and converts them into text that
@@ -85,14 +88,6 @@ def __call__(
                 A custom transform function that accepts a single image can be passed for training. For example,
                 `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
                 set of transforms will be applied to the images
-            add_eos_token (`bool`, *optional*, defaults to `False`):
-                Adds `eos_token` at the end of the final prompt if True`
-            add_end_of_utterance_token (`bool`, *optional*)
-                Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
-                image). If `None` the tokenizer will be checked instead and if this token is found in
-                `additional_special_tokens` then the value will be `True`.
-            debug (`bool`, *optional*, defaults to `False`):
-                `True` value will help debug prompt generation by dumping useful information
             return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
                 The type of tensors to return. Can be one of:
                     - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
@@ -115,7 +110,7 @@ def __call__(
             pixel_values = None
 
         # Attention mask have to be created later on? Or not?
-        text_inputs = self.tokenizer(text, return_tensors=return_tensors)
+        text_inputs = self.tokenizer(text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)
 
         return BatchFeature(data={**text_inputs,"pixel_values": pixel_values})
 
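Note on the modeling change: the patch replaces the fixed-size buffer with a cumulative-sum trick, where every `<image>` token expands to one slot per vision patch and every other token to one slot, text embeddings are scattered to their shifted positions, and the remaining all-zero slots receive the image features. The snippet below is a minimal, standalone sketch of that logic, not the actual modeling_llava.py code: the token ids, the hidden size, the 4-patch image (a real CLIP tower yields 576 patches), the `patches_per_image` name and the trailing `reshape(-1, embed_dim)` are illustrative assumptions.

import torch

image_token_index = 32000   # assumed id of the "<image>" placeholder token
embed_dim = 8               # toy hidden size
patches_per_image = 4       # stand-in for nb_text_tokens_per_images (576 in LLaVA)

# one sequence: ["hey", "<image>", "how", "are"]
input_ids = torch.tensor([[101, image_token_index, 102, 103]])
batch, seq_len = input_ids.shape
input_embeds = torch.randn(batch, seq_len, embed_dim)               # text embeddings
image_features = torch.randn(batch, patches_per_image, embed_dim)   # vision patches

# 1. locate the image tokens
image_token_mask = input_ids == image_token_index
num_image_tokens = image_token_mask.sum(dim=-1)

# 2. every image token expands to `patches_per_image` slots, every other token to 1,
#    so the cumulative sum gives each text token's position in the merged sequence
text_to_overwrite = torch.cumsum(image_token_mask * patches_per_image + 1, dim=-1) - 1

# 3. allocate the merged embedding, padded up to the largest target position
max_embed_dim = int(text_to_overwrite.max()) + 1
final_embedding = torch.zeros(batch, max_embed_dim, embed_dim)

# 4. scatter the text embeddings into their shifted positions
final_embedding.scatter_(-2, text_to_overwrite.unsqueeze(2).expand_as(input_embeds), input_embeds)

# 5. slots that are still all-zero (up to the number of patch positions) get the image features
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
image_to_overwrite &= image_to_overwrite.cumsum(-1) <= (num_image_tokens * patches_per_image)[:, None]
final_embedding[image_to_overwrite] = image_features.reshape(-1, embed_dim)

print(final_embedding.shape)    # torch.Size([1, 8, 8])
print(text_to_overwrite)        # tensor([[0, 5, 6, 7]]) -> patches land in slots 1..4

With 576 patches per image this reproduces the layout described in the patch comment: text is written at [0, 577, 578, 579] and the image features fill positions 1 to 576.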