Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VLMs: major clean up 🧼 #34502

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,17 @@ def prepare_inputs_for_generation(
):
# Overwritten -- extra custom processing

if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values is not None) or (
video_token_not_enough and pixel_values_videos is not None
)

model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
Expand All @@ -1122,7 +1133,7 @@ def prepare_inputs_for_generation(

# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if cache_position[0] == 0:
if legacy_processing or cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["pixel_values_videos"] = pixel_values_videos
model_inputs["image_sizes"] = image_sizes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,17 @@ def prepare_inputs_for_generation(
):
# Overwritten -- extra custom processing

if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values is not None) or (
video_token_not_enough and pixel_values_videos is not None
)

model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
Expand All @@ -635,7 +646,7 @@ def prepare_inputs_for_generation(

# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if cache_position[0] == 0:
if legacy_processing or cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["pixel_values_videos"] = pixel_values_videos
model_inputs["image_sizes"] = image_sizes
Expand Down
13 changes: 12 additions & 1 deletion src/transformers/models/video_llava/modeling_video_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,17 @@ def prepare_inputs_for_generation(
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values_images is not None) or (
video_token_not_enough and pixel_values_videos is not None
)

model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
Expand All @@ -730,7 +741,7 @@ def prepare_inputs_for_generation(
**kwargs,
)

if cache_position[0] == 0:
if legacy_processing or cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values_images"] = pixel_values_images
Expand Down
Loading