
Add video modality for InstructBLIP #30182

Merged: 19 commits, Jun 25, 2024
Conversation

@zucchini-nlp (Member):

What does this PR do?

I made these changes a month ago and forgot to contribute them. This PR adds video processing capabilities to InstructBLIP models. The paper states that InstructBLIP was trained and evaluated on video as well as images, and the original repo contains some code showing how video inference works.

This feature seems to have some interest from the community (see here), so I believe we can add it.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


model_input_names = ["pixel_values"]

def __init__(
Contributor:

This can perhaps also be a "Copied from".

Member Author:

Okay, let me try. Then I'll probably have to add "Ignore copy" on preprocess.

Contributor:

No, that's only if you add Copied from to the class. In that case you can add "Ignore copy" above the methods that you don't want to copy.

@NielsRogge (Contributor) left a comment:

Wow that's awesome, thanks for working on that!

I'm just concerned about 2 things:

  • making sure that we have a robust API for multimodal processors that is consistent
  • the current InstructBLIP models on the hub all use BlipImageProcessor. This PR would introduce a new image processor, I guess we would then need to update the auto mapping to make sure AutoImageProcessor still works as expected.

@@ -57,6 +57,7 @@ def __init__(self, image_processor, tokenizer, qformer_tokenizer):
def __call__(
self,
images: ImageInput = None,
videos: ImageInput = None,
@NielsRogge (Contributor), Apr 11, 2024:

cc @molbap since we'd like to standardize multimodal processors, and this one isn't making it easier 😅

At some point we will have a VideoTextToTextPipeline, and we'll need to make sure they all have the same API.

See also the ImageTextToTextPipeline being worked on in #29572. Technically it could work if we just expect the following to work for any video-text-to-text model:

inputs = processor(videos=..., text=..., return_tensors="pt")

generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

Contributor:

Sure! Regarding normalizing processors: for models taking video inputs, vivit, videomae, and tvlt have videos: ImageInput, but tvp has videos: Union[ImageInput, List[ImageInput], List[List[ImageInput]]]. However, x_clip reuses videomae's processor.

Overall, ImageInput is defined as a type union of an image or a list of images. It looks like in the future we might prefer supporting at least lists of lists of images, so a VideoInput defined that way could make sense, or a union of types as done in x_clip.
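
For illustration, a minimal sketch of what such a VideoInput union could look like (the names Frame and VideoInput and the exact members are assumptions, not the eventual transformers definition):

from typing import List, Union

import numpy as np

# Hypothetical sketch: a frame could be anything ImageInput already allows
# (PIL image, numpy array, tensor); a numpy array stands in for a frame here.
Frame = np.ndarray

VideoInput = Union[
    List[Frame],        # one video as a list of frames
    List[List[Frame]],  # a batch of videos
    np.ndarray,         # one video as an array of shape (num_frames, height, width, channels)
    List[np.ndarray],   # a batch of videos as arrays
]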

@zucchini-nlp (Member Author):

@NielsRogge

  1. making sure that we have a robust API for multimodal processors that is consistent

Yeah, that needs to be reworked. Right now Blip is the only model that supports videos, and there will be VideoLlava. Unfortunately, VideoLlava processing is going a different way so that it can easily interleave modalities. I guess that is part of what is being discussed internally on Slack.

  2. the current InstructBLIP models on the hub all use BlipImageProcessor. This PR would introduce a new image processor, I guess we would then need to update the [auto mapping](https://github.com/huggingface/transformers/blob/e516d1b19d035469b4852e34ba0356587e6f8ade/src/transformers/models/auto/image_processing_auto.py#L75) to make sure AutoImageProcessor still works as expected.

Yep, it seems like right now it works only when calling "InstructBlipImageProcessor" specifically.
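
For reference, a minimal sketch of what registering the class in the auto mapping might look like (the variable name and entry format in image_processing_auto.py are assumptions and may differ):

from collections import OrderedDict

# Hypothetical sketch: the auto mapping is roughly an ordered dict from
# model_type to the image processor class name, so the fix would be an entry
# along these lines.
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
    [
        # ... existing entries ...
        ("instructblip", "InstructBlipImageProcessor"),
    ]
)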

@amyeroberts (Collaborator) left a comment:

Thanks for working on adding this capability!

Two general comments:

  • We have to be careful here with the mappings w.r.t. backwards compatibility and expected behaviour. As a user, I should be able to do:
image_processor = InstructBlipImageProcessor()
images = image_processor(images, return_tensors="pt")

and get exactly the same output as I was getting before with the BLIP image processor.

  • We should avoid adding lots of if-statements in existing modeling code, and instead add a new model.

@@ -1368,11 +1368,46 @@ def forward(
>>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
>>> print(generated_text)
The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation.

# To generate from video input
Collaborator:

The addition of all the if statements here indicates this should really be split into a different model. For things like torch.compile we really want to avoid models whose inputs can have a varying number of dimensions, e.g. by adding InstructBLIPForVideoQuestionAnswering instead.

Comment on lines 130 to 131
if images is not None or videos is not None:
image_encoding = self.image_processor(images=images, videos=videos, return_tensors=return_tensors)
Collaborator:

This is going to break things. Either:

  • We add the new image processor to the auto map for existing checkpoints. This might lead to differences in output, as the image processor used is different. AFAICT the processing steps are the same, but the output shape isn't.
  • We load the old image processors, which will break or emit warnings with the videos input.

size = get_size_dict(size, default_to_square=False)

if (images is None) ^ (videos is not None):
raise ValueError("InstructBLIP currently does not support both images and videos as input")
Collaborator:

Is it ever going to support both? Otherwise this message can be misleading

# Ignore copy
def preprocess(
self,
images: ImageInput = None,
Collaborator:

If we're adding an image processor for InstructBlip, and it can process images, then it should process images consistently with how they were processed by the previous image processor (BLIP).

Whilst the processing steps are the same, i.e. the processed image is consistent, the output shape won't be, because this adds an extra axis to the output images, i.e. they become (batch_size, num_frames, num_channels, height, width). Instead, we should keep the same output shape for images; this will allow us to add this to the image_processing_auto.py mapping for InstructBlip.

@zucchini-nlp (Member Author):

@amyeroberts
I see, but I am not sure how to make a new model while keeping it easy for users to load and for us to maintain. For example: we could make separate processing for the video modality, which could be further enhanced with new features. Call it InstructBlipVideoProcessing, and another class for video-based generation, InstructBlipForConditionalVideoGeneration? For the case of BLIP, which will never work with interleaving the two vision modalities at the same time, this might work. But I also want to make sure there will be consistency in how video+image multimodal LLMs are handled in transformers.

What if:

  1. The image processor returns one of two possible tensors, "pixel_values_image" (a 4-dim tensor) or "pixel_values_video" (a 5-dim tensor). This is the way I did it for VideoLlava, which can actually generate from both modalities interleaved at the same time.
    1.1 Or maybe make the image and video processors separate classes, and call the appropriate one while processing.
  2. The modeling code stays the same, but we add one line that expands the image dimensionality to (batch_size, 1, num_channels, height, width) and keep the rest of the code as if a 1-frame video had been passed. All the conditions can then be removed if we treat vision inputs as video-like: 1 frame for an image and 4 frames for a video clip (a sketch of this trick follows below).
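
To make option 2 concrete, a minimal sketch of the dimension-expansion trick (shapes and names are illustrative, not the actual modeling code):

import torch

def as_frames(pixel_values: torch.Tensor) -> torch.Tensor:
    # Treat every vision input as (batch, frames, channels, height, width):
    # images arrive as (b, c, h, w) and become a 1-frame "video"; real videos
    # already carry a frame axis, e.g. (b, 4, c, h, w).
    if pixel_values.ndim == 4:
        pixel_values = pixel_values.unsqueeze(1)  # (b, 1, c, h, w)
    return pixel_values

# The vision encoder only sees individual frames, so the frame axis is folded into the batch.
pixel_values = as_frames(torch.randn(2, 3, 224, 224))  # an image batch
b, f, c, h, w = pixel_values.shape
frames = pixel_values.reshape(b * f, c, h, w)           # (b*f, c, h, w) for the vision tower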

@zucchini-nlp (Member Author):

@amyeroberts ping

@molbap (Contributor) commented May 13, 2024:

@zucchini-nlp one related thing I'll merge this week regarding processors: #30511. I added a VideosKwargs but no VideosInput yet. Personal opinion, from having spent too much time around processors: I think a separate model would actually be easier to maintain than patching a previous one, because it's easier to understand what the model does to which modality than to reason about mixed modalities. I don't have the final say on this, so just a comment :)

@amyeroberts (Collaborator):

@zucchini-nlp Sorry for the late reply here.

Modeling stays the same, but we add one line which expands image dimensionality so that it is (batch_size, 1, num_channels, height, width) and keep the rest of code as if it were a 1-frame video passed. All the conditions then can be removed, if we treat vision inputs as video-like but 1-frame for image and 4-frame for video clip.

This is a neat solution. My main question would be, what does this do to the shapes of the returned tensors e.g. hidden_states? If they remain the same, then this is a nice easy way to enable this within the modeling file.

If they're not the same, then we'll need to add a new class in the modeling file, e.g. InstructBlipForVideoConditionalGeneration, or perhaps a new modeling file src/transformers/models/instructblip/modeling_instructblip_video.py. The latter would add a new model type, e.g. VideoInstructBlip / InstructBlipVideo, which would enable the correct auto mapping in image_processing_auto.py. It's not unheard of for us to add new models that can load existing checkpoints under a new architecture name.

In terms of the image processor:

  • If we can use the same class and just expand within the forward pass, then the trick is to batch the inputs correctly so that for a batch of images the output pixel values keep the shape (batch_size, num_channels, height, width), and for videos it's (batch_size, num_frames, num_channels, height, width). Both should be output as pixel_values. That is, the output of the image processor for images should remain unchanged from the previous behaviour: same shape and same key names in the BatchFeature output (see the sketch after this comment).
  • If we can't use the same class, but can add a single module in modeling_instructblip, you can do as above.
  • If we can't use the same class, but add a new modeling file, then we can add a new, separate image processor.

My vote would be for a new model. Having to handle video and images within the same image processor is a good indication they should be separated.
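
A minimal sketch of that shape convention on the image-processor side (assuming frames are already resized and normalized to arrays; the function and variable names are illustrative):

import numpy as np

def batch_pixel_values(inputs, is_video: bool) -> np.ndarray:
    # Images: a list of (c, h, w) arrays  -> (batch_size, c, h, w), unchanged from BLIP.
    # Videos: a list of videos, each a list of (c, h, w) frames
    #         -> (batch_size, num_frames, c, h, w).
    if is_video:
        return np.stack([np.stack(frames) for frames in inputs])
    return np.stack(inputs)

images = [np.zeros((3, 224, 224)) for _ in range(2)]
videos = [[np.zeros((3, 224, 224)) for _ in range(4)] for _ in range(2)]
print(batch_pixel_values(images, is_video=False).shape)  # (2, 3, 224, 224)
print(batch_pixel_values(videos, is_video=True).shape)   # (2, 4, 3, 224, 224)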

@zucchini-nlp (Member Author):

Thanks for the detailed explanations! At this point it should be possible to go with the first option; I will have to go back and check. If not, making a separate model sounds good, since BLIP will never work with both modalities at the same time.

@molbap I see, having separate VideoProcessor files can probably be a better solution for multimodal models, instead of packing it all into the ImageProcessor.

@zucchini-nlp (Member Author):

I went for the "keep one processing file and one modeling file" approach. Currently the following changes are applied:

  1. The processor can accept either images or videos as an argument.
  2. The image processor returns pixel values of shape (b, c, h, w) for images, by squeezing the extra dim, and adds an extra frame dimension for videos.
  3. The modeling file unsqueezes the frame dimension back for images and continues running as if all inputs were videos. Finally, the frame dimension is merged back into the embeddings' sequence length (a sketch of this merge follows below).

Slow tests are passing locally, and I added a few tests for video generation.
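
For illustration, a minimal sketch of merging the frame dimension back into the sequence length (hypothetical shapes and names, not the actual modeling code):

import torch

batch_size, num_frames, seq_len, hidden = 2, 4, 32, 768

# Query embeddings come out of the Q-Former per frame: (b * frames, seq, hidden).
per_frame_embeds = torch.randn(batch_size * num_frames, seq_len, hidden)

# Merge the frame axis into the sequence length before the language model, so a
# 4-frame video contributes 4x as many visual tokens as a single image.
language_model_inputs = per_frame_embeds.reshape(batch_size, num_frames * seq_len, hidden)
print(language_model_inputs.shape)  # torch.Size([2, 128, 768])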

@zucchini-nlp zucchini-nlp requested a review from amyeroberts May 20, 2024 07:49
@amyeroberts (Collaborator):

Thanks for the continued work on this, and apologies for the delay in my review. Skimming over the PR and thinking about this more, I think we should split this up so that we have a video image processor and a video model. We want to avoid conditional outputs from our base processing objects, as well as conditional logic within our models' forward passes. Since we won't interleave images and videos, i.e. we don't expect a user to be using video and images at the same time, we don't need to handle them all at once.

@zucchini-nlp (Member Author):

@amyeroberts okay, I made video InstructBlip its own model with its own modeling and image processing files. Added some tests and indicated in the model doc that it's the same model as InstructBlip except for the video processing capability. Ready for review.

@ArthurZucker I made more changes to the diff converter in this PR, as it didn't work in some cases. Specifically:

  • Some models use CamelCase without an underscore splitting subwords, like InstructBlip. The diff converter cannot infer the correct model name in this case, so I added the possibility to indicate model names by passing them as args to the converter.
  • If a config inherits from another config in the library, the auto-generated config doesn't get all the imports from its parent class, causing errors. I added a special visited_module_config to store config-specific classes. I could have used the already-defined visited_module, but that would result in configs being imported into the config file.
  • Sometimes configs need globally defined vars, e.g. the logger. So there's a new line that ports "SimpleStatements" to configs. If they are not used, everything will be cleaned up by ruff anyway.
  • Properties and their setters weren't being ported because they have identical names; the fix is to iterate node.body directly, without saving it in a dict.

Plus, these are the changes from the LLaVa-NeXT-Video PR, duplicated here to have everything written down somewhere :)

  • Sometimes we want to add new methods to a class and still retain all methods from the parent. This case wasn't covered; I added a few lines to fix it.
  • Inferring the model name from the file name didn't consider long model names with an underscore; fixed by modifying the regex.

I still have to fix the case where only the init docstring is changed; I couldn't make that work yet.

@amyeroberts (Collaborator) left a comment:

Looks great - thanks for splitting this up and adding this model!

Just a few comments and some questions about the changes to the diff_converter. The main comment is about making sure generate is properly tested.

class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
):
all_model_classes = (InstructBlipVideoForConditionalGeneration,) if is_torch_available() else ()
Collaborator:

don't we need to define all_generative_models here too to properly test the generation mixin?

@zucchini-nlp (Member Author), Jun 13, 2024:

GenerationMixin still can't work for VLMs, and I am planning to add it properly after some unification of the VLM processors. Otherwise we'll end up with too many conditional checks inside the tests.

AFAIK all VLMs are currently tested for that in IntegrationTests.

# fmt: off
self.python_module = python_module # we store the original module to use `code_for_node`
self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
self.visited_module_config = {} # modules visited like "transformers.models.llama.modeling_llama" in config file, needed to not mix config vs modeling imports
Collaborator:

This comment isn't completely clear to me, as transformers.models.llama.modeling_llama is a modeling import, not a config import, and the InstructBlip configuration file doesn't import transformers.models.llama.modeling_llama.

Member Author:

Sorry, maybe it needs another example. This was needed because the auto-generated config wasn't getting imports from its "super config". In other words, it was only copying imports from the diff files, and unused ones are removed by ruff.

One solution might be to indicate all imports in the diff, but if they aren't used, ruff eventually removes them. In my case, PretrainedConfig wasn't being imported, for example.

Collaborator:

Yep, this is also something I noticed; the mixing of imports was not handled super well.

@@ -457,13 +463,18 @@ def leave_ClassDef(self, original_node, updated_node):
f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
)

if super_file_name not in self.visited_module: # only extract classes once
visited_module = self.visited_module_config if "Config" in class_name else self.visited_module
Collaborator:

How well does this extend to other files we might have under the model folder? E.g. do we need flags for visited_module_processor etc.?

Member Author:

Not sure if processors work with diff. I actually tried to add the image processor to the diff, but it got messed up, so I believe that isn't supported yet.

Collaborator:

not yet supported but planned for sure.

Comment on lines +568 to +577
parser.add_argument(
"--old_model_name",
required=False,
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from diff-file",
)
parser.add_argument(
"--new_model_name",
required=False,
help="The name of the new model being added in CamelCase. If not provided is inferred from diff-file",
)
Collaborator:

Is this for models with composite config files e.g. CLIPVisionConfig which we don't want to infer as CLIPVision?

@zucchini-nlp (Member Author), Jun 13, 2024:

This is for models that use CamelCase without underscores, like InstructBlip (instructblip) vs LlavaNext (llava_next). In the second case we can infer where to capitalize, while in the former it's impossible, so I decided to give users the freedom to pass model names.
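
To illustrate (a standalone sketch, not the converter's actual code): the underscore is what makes the CamelCase reconstruction possible.

def camel_case_from_snake(name: str) -> str:
    # Each underscore marks a subword boundary, so capitalization can be inferred.
    return "".join(part.capitalize() for part in name.split("_"))

print(camel_case_from_snake("llava_next"))    # "LlavaNext"    -> correct
print(camel_case_from_snake("instructblip"))  # "Instructblip" -> should be "InstructBlip";
                                              # the boundary is unrecoverable without a hint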

Collaborator:

That is a great addition indeed! Would even add this comment in the help 😉

@@ -474,7 +485,7 @@ def leave_ClassDef(self, original_node, updated_node):
start_insert_idx = self.global_scope_index
for dependency, _ in list_dependencies:
node = class_finder.global_nodes.get(dependency, None)
if node is not None:
if node is not None and "Config" not in class_name:
Collaborator:

Same q here - do we need to account for all other classes e.g. "Processor" not in class_name?

Member Author:

This was added because the diff was importing configs into the configuration files, even though they are defined as classes a few lines below.

Collaborator:

It would be good to get a second review of the changes in this file from @ArthurZucker.

Collaborator:

Overall good to me!
Separating config imports is the way to go, and further separating processor imports later on will be needed as well.

@ArthurZucker (Collaborator) left a comment:

The diff converter changes look nice IMO, but we should not need to import all the classes. The new diff converter is able to parse dependencies, so tell me if this is not the case!

Collaborator:

I am late to the party here, but normally you should not have to import all these classes. The diff converter will automatically detect dependencies and copy the classes that are required!
Unless this is an edge case?

Member Author:

We discussed this on Slack and decided we shouldn't have separate imports for each file, and should let ruff clean out the unnecessary ones. So I'm manually filtering the issue with configs and adding all imports. That worked for InstructBlip.

@zucchini-nlp (Member Author):

@amyeroberts this one is ready for the final review I guess

@amyeroberts (Collaborator) left a comment:

Another great addition - thanks!

As with Llava Next Video, we just need a final run of the slow tests and we're good to merge ❤️

@zucchini-nlp (Member Author):

Got the CI green, including slow tests. Will merge the PR

@zucchini-nlp zucchini-nlp merged commit fc689d7 into huggingface:main Jun 25, 2024
26 checks passed