From a2e5b22bc156ab873adf39d5239384c5a9082cc1 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Sun, 3 Dec 2023 16:58:13 +0100
Subject: [PATCH] last todos

---
 .../models/llava/modeling_llava.py | 47 ++++++++++---------
 1 file changed, 26 insertions(+), 21 deletions(-)

diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 791b64c8fe7c82..455e47c75cbb7b 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -217,7 +217,7 @@ def _init_weights(self, module):
             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-        input_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
             model's internal embedding lookup matrix.
@@ -279,7 +279,7 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
         self.config.text_config.vocab_size = new_num_tokens
         return model_embeds
 
-    def _merge_input_ids_with_image_features(self, image_features, input_embeds, input_ids, attention_mask, position_ids):
+    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, position_ids):
         # 1. Create a mask to know where image tokens are
         image_token_mask = (input_ids == self.config.image_token_index)
         num_image_tokens = torch.sum(image_token_mask, dim = -1)
@@ -290,21 +290,26 @@ def _merge_input_ids_with_image_features(self, image_features, input_embeds, inp
 
         # 3. Create the full embedding, already padded to the maximum position
         max_embed_dim = text_to_overwrite.max()
-        final_embedding = torch.zeros(input_ids.shape[0], max_embed_dim+1, input_embeds.shape[-1])
+        final_embedding = torch.zeros(input_ids.shape[0], max_embed_dim+1, inputs_embeds.shape[-1])
 
         # 3. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
         # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding.scatter_(-2, text_to_overwrite.unsqueeze(2).expand_as(input_embeds), input_embeds)
+        final_embedding.scatter_(-2, text_to_overwrite.unsqueeze(2).expand_as(inputs_embeds), inputs_embeds)
 
         # equivalent to
         # batch_indices = torch.arange(final_embedding.size(0)).view(-1, 1).expand_as(text_to_overwrite)
-        # final_embedding[batch_indices,text_to_overwrite] = input_embeds # we also right on the start image token
+        # final_embedding[batch_indices,text_to_overwrite] = inputs_embeds # we also write on the start image token
 
-        # 4. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling (apart from the padding)
+        # 4. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
         image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
+
+        # Make sure not to write on the padding
         image_to_overwrite &= image_to_overwrite.cumsum(-1) <= (num_image_tokens * nb_text_tokens_per_images)[:, None]
         final_embedding[image_to_overwrite] = image_features.reshape(-1, 4096)
-        return input_embeds, attention_mask, position_ids
+
+        # TODO: last thing is to update the position ids; you can just offset the indices to overwrite.
+        # TODO: update the attention mask correctly
+        return final_embedding, attention_mask, position_ids
 
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -315,7 +320,7 @@ def forward(
         self,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
-        input_embeds: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         vision_feature_layer: Optional[int] = None,
         vision_feature_select_strategy: Optional[str] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -364,9 +369,9 @@ def forward(
             else self.config.vision_feature_select_strategy
         )
 
-        if input_embeds is None:
+        if inputs_embeds is None:
             # 1. Extra the input embeddings
-            input_embeds = self.get_input_embeddings()(input_ids)
+            inputs_embeds = self.get_input_embeddings()(input_ids)
 
             # 2. Merge text and images
             if pixel_values is not None and input_ids.shape[1] != 1:
@@ -384,7 +389,7 @@ def forward(
                     )
                 image_features = self.multi_modal_projector(selected_image_feature)
 
-                input_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, input_embeds, input_ids, attention_mask, position_ids)
+                inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, position_ids)
                 if labels is None:
                     labels = torch.full_like(input_ids, self.config.ignore_index)
 
@@ -392,7 +397,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
-            input_embeds=input_embeds,
+            inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -429,14 +434,14 @@ def forward(
             attentions=outputs.attentions,
         )
 
-    def _optionally_pad_input_embeds(
-        self, input_embeds, attention_mask, position_ids, labels, max_seqlen, batch_size
+    def _optionally_pad_inputs_embeds(
+        self, inputs_embeds, attention_mask, position_ids, labels, max_seqlen, batch_size
     ):
         r"""
        Optionally pad the input embeddings by correctly setting the padding tokens on the correct places inside the
        attention mask, position ids and labels.
""" - padded_input_embeds = [] + padded_inputs_embeds = [] padded_labels = torch.full( (batch_size, max_seqlen), self.config.ignore_index, @@ -444,14 +449,14 @@ def _optionally_pad_input_embeds( device=labels[0].device, ) - for i, (current_embeds, cur_new_labels) in enumerate(zip(input_embeds, labels)): + for i, (current_embeds, cur_new_labels) in enumerate(zip(inputs_embeds, labels)): # Get the current sequence length and padding side # then optionally padd the input embeds current_seq_len = current_embeds.shape[0] padding_side = getattr(self.config, "tokenizer_padding_side", "right") padded_embedding = pad_sequence(current_embeds, max_seqlen, padding_side) - padded_input_embeds.append(padded_embedding) + padded_inputs_embeds.append(padded_embedding) if current_seq_len > 0: start_index = -current_seq_len if padding_side == "left" else 0 @@ -463,11 +468,11 @@ def _optionally_pad_input_embeds( 0, current_seq_len, dtype=position_ids.dtype, device=position_ids.device ) - input_embeds = torch.stack(padded_input_embeds, dim=0) - return input_embeds, attention_mask, position_ids, padded_labels + inputs_embeds = torch.stack(padded_inputs_embeds, dim=0) + return inputs_embeds, attention_mask, position_ids, padded_labels - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, input_embeds=None, pixel_values=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, **kwargs): # Call `prepare_inputs_for_generation` from the LM - model_input = self.language_model.prepare_inputs_for_generation(input_ids, past_key_values, input_embeds=input_embeds, **kwargs) + model_input = self.language_model.prepare_inputs_for_generation(input_ids, past_key_values, inputs_embeds=inputs_embeds, **kwargs) model_input.update({"pixel_values": pixel_values}) return model_input