last todos
ArthurZucker committed Dec 3, 2023
1 parent ebec096 commit a2e5b22
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions src/transformers/models/llava/modeling_llava.py
@@ -217,7 +217,7 @@ def _init_weights(self, module):
             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-        input_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
             model's internal embedding lookup matrix.
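(As a usage note, the `inputs_embeds` path documented above can be exercised as below; a minimal sketch, assuming a hosted LLaVA checkpoint, whose id is an assumption and not part of this diff.)

    import torch
    from transformers import AutoTokenizer, LlavaForConditionalGeneration

    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")

    input_ids = tokenizer("USER: hello ASSISTANT:", return_tensors="pt").input_ids
    # Look up the embeddings yourself instead of letting the model do it ...
    inputs_embeds = model.get_input_embeddings()(input_ids)
    # ... and hand them to forward in place of input_ids.
    outputs = model(inputs_embeds=inputs_embeds, attention_mask=torch.ones_like(input_ids))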
@@ -279,7 +279,7 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
         self.config.text_config.vocab_size = new_num_tokens
         return model_embeds
 
-    def _merge_input_ids_with_image_features(self, image_features, input_embeds, input_ids, attention_mask, position_ids):
+    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, position_ids):
         # 1. Create a mask to know where image tokens are
         image_token_mask = (input_ids == self.config.image_token_index)
         num_image_tokens = torch.sum(image_token_mask, dim=-1)
@@ -290,21 +290,26 @@ def _merge_input_ids_with_image_features(self, image_features, input_embeds, inp

         # 3. Create the full embedding, already padded to the maximum position
         max_embed_dim = text_to_overwrite.max()
-        final_embedding = torch.zeros(input_ids.shape[0], max_embed_dim + 1, input_embeds.shape[-1])
+        final_embedding = torch.zeros(input_ids.shape[0], max_embed_dim + 1, inputs_embeds.shape[-1])
 
         # 3. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
         # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding.scatter_(-2, text_to_overwrite.unsqueeze(2).expand_as(input_embeds), input_embeds)
+        final_embedding.scatter_(-2, text_to_overwrite.unsqueeze(2).expand_as(inputs_embeds), inputs_embeds)
 
         # equivalent to
         # batch_indices = torch.arange(final_embedding.size(0)).view(-1, 1).expand_as(text_to_overwrite)
-        # final_embedding[batch_indices, text_to_overwrite] = input_embeds  # we also write on the start image token
+        # final_embedding[batch_indices, text_to_overwrite] = inputs_embeds  # we also write on the start image token
 
-        # 4. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling (apart from the padding)
+        # 4. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
         image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
 
         # Make sure not to write on the padding
         image_to_overwrite &= image_to_overwrite.cumsum(-1) <= (num_image_tokens * nb_text_tokens_per_images)[:, None]
         final_embedding[image_to_overwrite] = image_features.reshape(-1, 4096)
-        return input_embeds, attention_mask, position_ids
 
+        # TODO: last thing is to update the position ids; easy, you can just offset the indices to overwrite.
+        # TODO: update the attention mask correctly
+        return final_embedding, attention_mask, position_ids
 
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
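(To make the merge above concrete, here is a toy version with four text tokens, one image, three image patches, and hidden_size 1. The `text_to_overwrite` computation sits in lines collapsed out of this diff, so the cumsum formula below is an assumption; the scatter and overwrite steps mirror the hunk above.)

    import torch

    image_token_index = 32000                             # assumed <image> token id
    input_ids = torch.tensor([[101, 32000, 102, 103]])    # ["hey", "<image>", "how", "are"]
    inputs_embeds = torch.arange(1.0, 5.0).view(1, 4, 1)  # token embeddings 1..4
    image_features = torch.full((1, 3, 1), 9.0)           # 3 projected image patches

    mask = input_ids == image_token_index
    # Each <image> token reserves `num_patches` extra slots ahead of itself.
    text_to_overwrite = torch.cumsum(mask * image_features.shape[1] + 1, -1) - 1
    # -> tensor([[0, 4, 5, 6]]): text lands at 0, 4, 5, 6; patches fill 1..3

    final_embedding = torch.zeros(1, int(text_to_overwrite.max()) + 1, 1)
    final_embedding.scatter_(-2, text_to_overwrite.unsqueeze(2).expand_as(inputs_embeds), inputs_embeds)
    image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    final_embedding[image_to_overwrite] = image_features.reshape(-1, 1)
    print(final_embedding.squeeze(-1))  # tensor([[1., 9., 9., 9., 2., 3., 4.]])

(Note that the <image> token's own embedding, 2., survives at index 4; that is what the "we also write on the start image token" comment refers to.)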
@@ -315,7 +320,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
-        input_embeds: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         vision_feature_layer: Optional[int] = None,
         vision_feature_select_strategy: Optional[str] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -364,9 +369,9 @@ def forward(
             else self.config.vision_feature_select_strategy
         )
 
-        if input_embeds is None:
+        if inputs_embeds is None:
             # 1. Extract the input embeddings
-            input_embeds = self.get_input_embeddings()(input_ids)
+            inputs_embeds = self.get_input_embeddings()(input_ids)
 
             # 2. Merge text and images
             if pixel_values is not None and input_ids.shape[1] != 1:
@@ -384,15 +389,15 @@
                     )
 
                 image_features = self.multi_modal_projector(selected_image_feature)
-                input_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, input_embeds, input_ids, attention_mask, position_ids)
+                inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, position_ids)
                 if labels is None:
                     labels = torch.full_like(input_ids, self.config.ignore_index)
 
         outputs = self.language_model(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
-            input_embeds=input_embeds,
+            inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
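(For context, the text-plus-image path above runs end to end roughly as follows; a hedged sketch, assuming the public AutoProcessor class and a hosted checkpoint, neither of which appears in this diff.)

    import requests
    from PIL import Image
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = processor(text="USER: <image>\nWhat is shown here? ASSISTANT:", images=image, return_tensors="pt")

    # forward() embeds the text, projects the vision features, then splices them
    # over the <image> placeholder via _merge_input_ids_with_image_features.
    output_ids = model.generate(**inputs, max_new_tokens=20)
    print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])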
@@ -429,29 +434,29 @@ def forward(
             attentions=outputs.attentions,
         )
 
-    def _optionally_pad_input_embeds(
-        self, input_embeds, attention_mask, position_ids, labels, max_seqlen, batch_size
+    def _optionally_pad_inputs_embeds(
+        self, inputs_embeds, attention_mask, position_ids, labels, max_seqlen, batch_size
     ):
         r"""
         Optionally pad the input embeddings by correctly setting the padding tokens
         in the correct places inside the attention mask, position ids and labels.
         """
-        padded_input_embeds = []
+        padded_inputs_embeds = []
         padded_labels = torch.full(
             (batch_size, max_seqlen),
             self.config.ignore_index,
             dtype=labels[0].dtype,
             device=labels[0].device,
         )
 
-        for i, (current_embeds, cur_new_labels) in enumerate(zip(input_embeds, labels)):
+        for i, (current_embeds, cur_new_labels) in enumerate(zip(inputs_embeds, labels)):
             # Get the current sequence length and padding side,
             # then optionally pad the input embeds
             current_seq_len = current_embeds.shape[0]
             padding_side = getattr(self.config, "tokenizer_padding_side", "right")
             padded_embedding = pad_sequence(current_embeds, max_seqlen, padding_side)
 
-            padded_input_embeds.append(padded_embedding)
+            padded_inputs_embeds.append(padded_embedding)
 
             if current_seq_len > 0:
                 start_index = -current_seq_len if padding_side == "left" else 0
@@ -463,11 +468,11 @@ def _optionally_pad_input_embeds(
                 0, current_seq_len, dtype=position_ids.dtype, device=position_ids.device
             )
 
-        input_embeds = torch.stack(padded_input_embeds, dim=0)
-        return input_embeds, attention_mask, position_ids, padded_labels
+        inputs_embeds = torch.stack(padded_inputs_embeds, dim=0)
+        return inputs_embeds, attention_mask, position_ids, padded_labels
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, input_embeds=None, pixel_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, **kwargs):
         # Call `prepare_inputs_for_generation` from the LM
-        model_input = self.language_model.prepare_inputs_for_generation(input_ids, past_key_values, input_embeds=input_embeds, **kwargs)
+        model_input = self.language_model.prepare_inputs_for_generation(input_ids, past_key_values, inputs_embeds=inputs_embeds, **kwargs)
         model_input.update({"pixel_values": pixel_values})
         return model_input
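(The padding loop above calls pad_sequence(current_embeds, max_seqlen, padding_side), a helper defined outside this diff; note it is not torch.nn.utils.rnn.pad_sequence, which has a different signature. A plausible sketch of such a helper, under that assumption:)

    import torch

    def pad_sequence(embeds: torch.Tensor, max_seqlen: int, padding_side: str = "right") -> torch.Tensor:
        """Zero-pad a (seq_len, hidden_size) tensor up to max_seqlen on the given side."""
        pad_len = max_seqlen - embeds.shape[0]
        if pad_len <= 0:
            return embeds
        pad = torch.zeros(pad_len, embeds.shape[-1], dtype=embeds.dtype, device=embeds.device)
        chunks = (pad, embeds) if padding_side == "left" else (embeds, pad)
        return torch.cat(chunks, dim=0)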
