Skip to content

Commit

Permalink
Put the legacy processing code back
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Dec 23, 2024
1 parent 0d29bc3 commit 99ea497
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/transformers/models/llava/configuration_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class LlavaConfig(PretrainedConfig):
Can be one of `"default"` or `"full"`.
vision_feature_layer (`int`, *optional*, defaults to -2):
The index of the layer to select the vision feature.
image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Expand Down Expand Up @@ -84,12 +86,14 @@ def __init__(
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
vision_feature_layer=-2,
image_seq_length=576,
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length

if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(
Expand Down
56 changes: 55 additions & 1 deletion src/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,17 @@ def forward(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

legacy_processing = False
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

# if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
# In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
legacy_processing = (
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
) or (input_ids.shape[-1] == 1 and pixel_values is not None)

image_features = None
if pixel_values is not None:
image_features = self.get_image_features(
Expand All @@ -475,7 +483,53 @@ def forward(
vision_feature_select_strategy=vision_feature_select_strategy,
)

if image_features is not None:
if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
# prefill stage vs decoding stage (legacy behavior copied)
if input_ids.shape[1] != 1:
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

# TODO: @raushan retain only the new behavior after v4.47
elif image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]

Expand Down

0 comments on commit 99ea497

Please sign in to comment.