From f33a0cebb37454a25af3d0be44832ea53c39733d Mon Sep 17 00:00:00 2001 From: Tony Wu <28306721+tonywu71@users.noreply.github.com> Date: Tue, 17 Dec 2024 11:26:43 +0100 Subject: [PATCH] =?UTF-8?q?Add=20ColPali=20to=20=F0=9F=A4=97=20transformer?= =?UTF-8?q?s=20(#33736)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: run `add-new-model-like` * feat: add paligemma code with "copied from" * feat: add ColPaliProcessor * feat: add ColPaliModel * feat: add ColPaliConfig * feat: rename `ColPaliForConditionalGeneration` to `ColPaliModel` * fixup modeling colpali * fix: fix root import shortcuts * fix: fix `modeling_auto` dict * feat: comment out ColPali test file * fix: fix typos from `add-new-model-like` * feat: explicit the forward input args * feat: move everything to `modular_colpali.py` * fix: put back ColPaliProcesor * feat: add auto-generated files * fix: run `fix-copies` * fix: remove DOCStRING constants to make modular converter work * fix: fix typo + modular converter * fix: add missing imports * feat: no more errors when loading ColPaliModel * fix: remove unused args in forward + tweak doc * feat: rename `ColPaliModel` to `ColPaliForRetrieval` * fix: apply `fix-copies` * feat: add ColPaliProcessor to `modular_colpali` * fix: run make quality + make style * fix: remove duplicate line in configuration_auto * feat: make ColPaliModel inehrit from PaliGemmaForConditionalGeneration * fix: tweak and use ColPaliConfig * feat: rename `score` to `post_process_retrieval` * build: run modular formatter + make style * feat: convert colpali weights + fixes * feat: remove old weight converter file * feat: add and validate tests * feat: replace harcoded path to "vidore/colpali-v1.2-hf" in tests * fix: add bfloat16 conversion in weight converter * feat: replace pytest with unittest in modeling colpali test * feat: add sanity check for weight conversion (doesn't work yet) * feat: add shape sanity check in weigth converter * feat: make ColPaliProcessor args explicit * doc: add doc for ColPali * fix: trying to fix output mismatch * feat: tweaks * fix: ColPaliModelOutput inherits from ModelOutput instead of PaliGemmaCausalLMOutputWithPast * fix: address comments on PR * fix: adapt tests to the Hf norm * wip: try things * feat: add `__call__` method to `ColPaliProcessor` * feat: remove need for dummy image in `process_queries` * build: run new modular converter * fix: fix incorrect method override * Fix tests, processing, modular, convert * fix tokenization auto * hotfix: manually fix processor -> fixme once convert modular is fixed * fix: convert weights working * feat: rename and improve convert weight script * feat: tweaks * fest: remove `device` input for `post_process_retrieval` * refactor: remove unused `get_torch_device` * Fix all tests * docs: update ColPali model doc * wip: fix convert weights to hf * fix logging modular * docs: add acknowledgements in model doc * docs: add missing docstring to ColPaliProcessor * docs: tweak * docs: add doc for `ColPaliForRetrievalOutput.forward` * feat: add modifications from colpali-engine v0.3.2 in ColPaliProcessor * fix: fix and upload colapli hf weights * refactor: rename `post_process_retrieval` to `score_retrieval` * fix: fix wrong typing for `score_retrieval` * test: add integration test for ColPali * chore: rerun convert modular * build: fix root imports * Update docs/source/en/index.md Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> * fix: address PR comments * wip: reduce the 
prediction gap in weight conversion * docs: add comment in weight conversion script * docs: add example for `ColPaliForRetrieval.forward` * tests: change dataset path to the new one in hf-internal * fix: colpali weight conversion works * test: add fine-grained check for ColPali integration test * fix: fix typos in convert weight script * docs: move input docstring in a variable * fix: remove hardcoded torch device in test * fix: run the new modular refactor * docs: fix python example for ColPali * feat: add option to choose `score_retrieval`'s output dtype and device * docs: update doc for `score_retrieval` * feat: add `patch_size` property in ColPali model * chore: run `make fix-copies` * docs: update description for ColPali cookbooks * fix: remove `ignore_index` methods * feat: remove non-transformers specific methods * feat: update `__init__.py` to new hf format * fix: fix root imports in transformers * feat: remove ColPali's inheritance from PaliGemma * Fix CI issues * nit remove prints * feat: remove ColPali config and model from `modular_colpali.py` * feat: add `ColPaliPreTrainedModel` and update modeling and configuration code * fix: fix auto-removed imports in root `__init__.py` * fix: various fixes * fix: fix `_init_weight` * temp: comment `AutoModel.from_config` for experiments * fix: add missing `output_attentions` arg in ColPali's forward * fix: fix `resize_token_embeddings` * fix: make `input_ids` optional in forward * feat: rename `projection_layer` to `embedding_proj_layer` * wip: fix convert colpali weight script * fix tests and convert weights from original repo * fix unprotected import * fix unprotected torch import * fix style * change vlm_backbone_config to vlm_config * fix unprotected import in modular this time * fix: load config from Hub + tweaks in convert weight script * docs: move example usage from model docstring to model markdown * docs: fix input docstring for ColPali's forward method * fix: use `sub_configs` for ColPaliConfig * fix: remove non-needed sanity checks in weight conversion script + tweaks * fix: fix issue with `replace_return_docstrings` in ColPali's `forward` * docs: update docstring for `ColPaliConfig` * test: change model path in ColPali test * fix: fix ColPaliConfig * fix: fix weight conversion script * test: fix expected weights for ColPali model * docs: update ColPali markdown * docs: fix minor typo in ColPaliProcessor * Fix tests and add _no_split_modules * add text_config to colpali config * [run slow] colpali * move inputs to torch_device in integration test * skip test_model_parallelism * docs: clarify quickstart snippet in ColPali's model card * docs: update ColPali's model card --------- Co-authored-by: yonigozlan Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/colpali.md | 95 ++++ src/transformers/__init__.py | 20 + src/transformers/models/__init__.py | 1 + src/transformers/models/auto/__init__.py | 2 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 8 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/colpali/__init__.py | 28 ++ .../models/colpali/configuration_colpali.py | 106 +++++ .../colpali/convert_colpali_weights_to_hf.py | 207 ++++++++ .../models/colpali/modeling_colpali.py | 299 ++++++++++++ .../models/colpali/modular_colpali.py | 354 ++++++++++++++ .../models/colpali/processing_colpali.py | 443 
++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 17 + tests/models/colpali/__init__.py | 0 tests/models/colpali/test_modeling_colpali.py | 368 +++++++++++++++ .../models/colpali/test_processing_colpali.py | 247 ++++++++++ utils/check_table.py | 2 +- utils/update_metadata.py | 2 +- 22 files changed, 2204 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/model_doc/colpali.md create mode 100644 src/transformers/models/colpali/__init__.py create mode 100644 src/transformers/models/colpali/configuration_colpali.py create mode 100644 src/transformers/models/colpali/convert_colpali_weights_to_hf.py create mode 100644 src/transformers/models/colpali/modeling_colpali.py create mode 100644 src/transformers/models/colpali/modular_colpali.py create mode 100644 src/transformers/models/colpali/processing_colpali.py create mode 100644 tests/models/colpali/__init__.py create mode 100644 tests/models/colpali/test_modeling_colpali.py create mode 100644 tests/models/colpali/test_processing_colpali.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c4707d5f20a027..d87906159ce34f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -834,6 +834,8 @@ title: CLIPSeg - local: model_doc/clvp title: CLVP + - local: model_doc/colpali + title: ColPali - local: model_doc/data2vec title: Data2Vec - local: model_doc/deplot diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 49c44874e320ef..a40bb825463495 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -100,6 +100,7 @@ Flax), PyTorch, and/or TensorFlow. | [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ✅ | | [Cohere](model_doc/cohere) | ✅ | ❌ | ❌ | | [Cohere2](model_doc/cohere2) | ✅ | ❌ | ❌ | +| [ColPali](model_doc/colpali) | ✅ | ❌ | ❌ | | [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ | | [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ | | [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/colpali.md b/docs/source/en/model_doc/colpali.md new file mode 100644 index 00000000000000..d47f0aa072262c --- /dev/null +++ b/docs/source/en/model_doc/colpali.md @@ -0,0 +1,95 @@ + + +# ColPali + +## Overview + +The ColPali model was proposed in [ColPali: Efficient Document Retrieval with Vision Language Models](https://doi.org/10.48550/arXiv.2407.01449) by **Manuel Faysse***, **Hugues Sibille***, **Tony Wu***, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo (* denotes equal contribution). + +With our new model *ColPali*, we propose to leverage VLMs to construct efficient multi-vector embeddings in the visual space for document retrieval. By feeding the ViT output patches from PaliGemma-3B to a linear projection, we create a multi-vector representation of documents. We train the model to maximize the similarity between these document embeddings and the query embeddings, following the ColBERT method. + +Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a single model that can take into account both the textual and visual content (layout, charts, ...) of a document. ColPali is also highly interpretable: similarity maps can be obtained between patches and query tokens. These maps highlight ColPali’s strong OCR capabilities and chart understanding. + +**Paper abstract:** + +> Documents are visually rich structures that convey information through text, but also figures, page layouts, tables, or even fonts. 
Since modern retrieval systems mainly rely on the textual information they extract from document pages to index documents -often through lengthy and brittle processes-, they struggle to exploit key visual cues efficiently. This limits their capabilities in many practical document retrieval applications such as Retrieval Augmented Generation (RAG). To benchmark current systems on visually rich document retrieval, we introduce the Visual Document Retrieval Benchmark *ViDoRe*, composed of various page-level retrieval tasks spanning multiple domains, languages, and practical settings. The inherent complexity and performance shortcomings of modern systems motivate a new concept; doing document retrieval by directly embedding the images of the document pages. We release *ColPali*, a Vision Language Model trained to produce high-quality multi-vector embeddings from images of document pages. Combined with a late interaction matching mechanism, *ColPali* largely outperforms modern document retrieval pipelines while being drastically simpler, faster and end-to-end trainable. +> +> We release models, data, code and benchmarks under open licenses at [https://huggingface.co/vidore](https://huggingface.co/vidore). + +## Resources + +- The official blog post detailing ColPali can be found [here](https://huggingface.co/blog/manu/colpali). 📝 +- The original model implementation code for the ColPali model and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎 +- Cookbooks for learning to use the transformers-native version of ColPali, fine-tuning, and similarity maps generation can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚 + +This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) and [@yonigozlan](https://huggingface.co/yonigozlan). + +## Usage + +This example demonstrates how to use ColPali to embed both queries and images, calculate their similarity scores, and identify the most relevant matches. For a specific query, you can retrieve the top-k most similar images by selecting the ones with the highest similarity scores. 
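+Under the hood, `score_retrieval` computes a ColBERT-style late-interaction (MaxSim) score: each query token embedding is compared against every image-patch embedding (a dot product, which amounts to a cosine similarity since the model L2-normalizes its embeddings), the best match per query token is kept, and these maxima are summed into a single query-page score. The snippet below is a minimal sketch of that computation on dummy tensors; the shapes are illustrative only, and the end-to-end quickstart that follows shows how to obtain real embeddings with the model and processor.
+
+```python
+import torch
+
+# Dummy multi-vector embeddings with shape (batch, num_tokens, embedding_dim)
+query_embeddings = torch.randn(2, 16, 128)    # 2 queries, 16 token embeddings each
+image_embeddings = torch.randn(3, 1024, 128)  # 3 pages, 1024 patch embeddings each
+
+# Late interaction (MaxSim): similarity between every query token and every patch,
+# max over patches, then sum over query tokens -> one score per (query, page) pair.
+similarities = torch.einsum("bnd,csd->bcns", query_embeddings, image_embeddings)
+scores = similarities.max(dim=3).values.sum(dim=2)  # shape: (2, 3)
+
+# Rank the pages for each query, e.g. keep the two most relevant ones
+top_scores, top_indices = scores.topk(k=2, dim=1)
+```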
+ +```python +import torch +from PIL import Image + +from transformers import ColPaliForRetrieval, ColPaliProcessor + +model_name = "vidore/colpali-v1.2-hf" + +model = ColPaliForRetrieval.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map="cuda:0", # or "mps" if on Apple Silicon +).eval() + +processor = ColPaliProcessor.from_pretrained(model_name) + +# Your inputs (replace dummy images with screenshots of your documents) +images = [ + Image.new("RGB", (32, 32), color="white"), + Image.new("RGB", (16, 16), color="black"), +] +queries = [ + "What is the organizational structure for our R&D department?", + "Can you provide a breakdown of last year’s financial performance?", +] + +# Process the inputs +batch_images = processor(images=images).to(model.device) +batch_queries = processor(text=queries).to(model.device) + +# Forward pass +with torch.no_grad(): + image_embeddings = model(**batch_images) + query_embeddings = model(**batch_queries) + +# Score the queries against the images +scores = processor.score_retrieval(query_embeddings, image_embeddings) +``` + +## ColPaliConfig + +[[autodoc]] ColPaliConfig + +## ColPaliProcessor + +[[autodoc]] ColPaliProcessor + +## ColPaliForRetrieval + +[[autodoc]] ColPaliForRetrieval + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1eb34b48fda856..920dc334dbb2a4 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -306,6 +306,10 @@ ], "models.cohere": ["CohereConfig"], "models.cohere2": ["Cohere2Config"], + "models.colpali": [ + "ColPaliConfig", + "ColPaliProcessor", + ], "models.conditional_detr": ["ConditionalDetrConfig"], "models.convbert": [ "ConvBertConfig", @@ -1468,6 +1472,7 @@ "MODEL_FOR_OBJECT_DETECTION_MAPPING", "MODEL_FOR_PRETRAINING_MAPPING", "MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_RETRIEVAL_MAPPING", "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", @@ -1789,6 +1794,12 @@ ) _import_structure["models.cohere"].extend(["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"]) _import_structure["models.cohere2"].extend(["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]) + _import_structure["models.colpali"].extend( + [ + "ColPaliForRetrieval", + "ColPaliPreTrainedModel", + ] + ) _import_structure["models.conditional_detr"].extend( [ "ConditionalDetrForObjectDetection", @@ -5207,6 +5218,10 @@ ) from .models.cohere import CohereConfig from .models.cohere2 import Cohere2Config + from .models.colpali import ( + ColPaliConfig, + ColPaliProcessor, + ) from .models.conditional_detr import ( ConditionalDetrConfig, ) @@ -6413,6 +6428,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING, MODEL_FOR_PRETRAINING_MAPPING, MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_RETRIEVAL_MAPPING, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, @@ -6689,6 +6705,10 @@ Cohere2Model, Cohere2PreTrainedModel, ) + from .models.colpali import ( + ColPaliForRetrieval, + ColPaliPreTrainedModel, + ) from .models.conditional_detr import ( ConditionalDetrForObjectDetection, ConditionalDetrForSegmentation, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 2e3b48da96e966..5eb74fab5abe71 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -53,6 +53,7 @@ codegen, cohere, cohere2, + colpali, conditional_detr, convbert, convnext, diff --git 
a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 2ee0541a1a71b8..1f626d8c24f42a 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -74,6 +74,7 @@ "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", "MODEL_FOR_VISION_2_SEQ_MAPPING", + "MODEL_FOR_RETRIEVAL_MAPPING", "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING", "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", "MODEL_MAPPING", @@ -252,6 +253,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING, MODEL_FOR_PRETRAINING_MAPPING, MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_RETRIEVAL_MAPPING, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 1d9db837e8d27c..1fb7464f41116a 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -70,6 +70,7 @@ ("codegen", "CodeGenConfig"), ("cohere", "CohereConfig"), ("cohere2", "Cohere2Config"), + ("colpali", "ColPaliConfig"), ("conditional_detr", "ConditionalDetrConfig"), ("convbert", "ConvBertConfig"), ("convnext", "ConvNextConfig"), @@ -373,6 +374,7 @@ ("codegen", "CodeGen"), ("cohere", "Cohere"), ("cohere2", "Cohere2"), + ("colpali", "ColPali"), ("conditional_detr", "Conditional DETR"), ("convbert", "ConvBERT"), ("convnext", "ConvNeXT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index bec72a4e7b84ec..5d41ad42beea7e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -306,6 +306,7 @@ ("big_bird", "BigBirdForPreTraining"), ("bloom", "BloomForCausalLM"), ("camembert", "CamembertForMaskedLM"), + ("colpali", "ColPaliForRetrieval"), ("ctrl", "CTRLLMHeadModel"), ("data2vec-text", "Data2VecTextForMaskedLM"), ("deberta", "DebertaForMaskedLM"), @@ -775,6 +776,12 @@ ] ) +MODEL_FOR_RETRIEVAL_MAPPING_NAMES = OrderedDict( + [ + ("colpali", "ColPaliForRetrieval"), + ] +) + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( [ ("aria", "AriaForConditionalGeneration"), @@ -1473,6 +1480,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES ) +MODEL_FOR_RETRIEVAL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_RETRIEVAL_MAPPING_NAMES) MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 3e475b1be211fa..815e2ca755bee3 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -58,6 +58,7 @@ ("clip", "CLIPProcessor"), ("clipseg", "CLIPSegProcessor"), ("clvp", "ClvpProcessor"), + ("colpali", "ColPaliProcessor"), ("flava", "FlavaProcessor"), ("fuyu", "FuyuProcessor"), ("git", "GitProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 386ca11abedcf4..1cdebde8cd904f 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -148,6 +148,7 @@ ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)), ("cohere", (None, 
"CohereTokenizerFast" if is_tokenizers_available() else None)), ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), + ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), ( "cpm", diff --git a/src/transformers/models/colpali/__init__.py b/src/transformers/models/colpali/__init__.py new file mode 100644 index 00000000000000..fa1b63fd009803 --- /dev/null +++ b/src/transformers/models/colpali/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_colpali import * + from .modeling_colpali import * + from .processing_colpali import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/colpali/configuration_colpali.py b/src/transformers/models/colpali/configuration_colpali.py new file mode 100644 index 00000000000000..045462adca4e2c --- /dev/null +++ b/src/transformers/models/colpali/configuration_colpali.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ColPali model configuration""" + +import logging +from copy import deepcopy + +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.getLogger(__name__) + + +class ColPaliConfig(PretrainedConfig): + r""" + Configuration class to store the configuration of a [`ColPaliForRetrieval`]. It is used to instantiate an instance + of `ColPaliForRetrieval` according to the specified arguments, defining the model architecture following the methodology + from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper. + + Creating a configuration with the default settings will result in a configuration where the VLM backbone is set to the + default PaliGemma configuration, i.e the one from [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2). + + The ColPali config is very similar to [`PaligemmaConfig`], but with an extra attribute defining the embedding dimension. 
+ + Note that contrarily to what the class name suggests (actually the name refers to the ColPali **methodology**), you can + use a different VLM backbone model than PaliGemma by passing the corresponding VLM configuration to the class constructor. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vlm_config (`PretrainedConfig`, *optional*): + Configuration of the VLM backbone model. + text_config (`PretrainedConfig`, *optional*): + Configuration of the text backbone model. Overrides the `text_config` attribute of the `vlm_config` if provided. + embedding_dim (`int`, *optional*, defaults to 128): + Dimension of the multi-vector embeddings produced by the model. + + Example: + + ```python + from transformers.models.colpali import ColPaliConfig, ColPaliForRetrieval + + config = ColPaliConfig() + model = ColPaliForRetrieval(config) + ``` + """ + + model_type = "colpali" + sub_configs = {"vlm_config": PretrainedConfig, "text_config": AutoConfig} + + def __init__( + self, + vlm_config=None, + text_config=None, + embedding_dim: int = 128, + **kwargs, + ): + if vlm_config is None: + vlm_config = CONFIG_MAPPING["paligemma"]() + logger.info( + "`vlm_config` is `None`. Initializing `vlm_config` with the `PaliGemmaConfig` with default values." + ) + elif isinstance(vlm_config, dict): + vlm_config = deepcopy(vlm_config) + if "model_type" not in vlm_config: + raise KeyError( + "The `model_type` key is missing in the `vlm_config` dictionary. Please provide the model type." + ) + elif vlm_config["model_type"] not in CONFIG_MAPPING: + raise ValueError( + f"The model type `{vlm_config['model_type']}` is not supported. Please provide a valid model type." + ) + vlm_config = CONFIG_MAPPING[vlm_config["model_type"]](**vlm_config) + elif isinstance(vlm_config, PretrainedConfig): + vlm_config = vlm_config + else: + raise TypeError( + f"Invalid type for `vlm_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(vlm_config)}." + ) + + self.vlm_config = vlm_config + self.text_config = text_config = text_config if text_config is not None else vlm_config.text_config + if isinstance(self.text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma" + self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + + self.embedding_dim = embedding_dim + + super().__init__(**kwargs) + + +__all__ = ["ColPaliConfig"] diff --git a/src/transformers/models/colpali/convert_colpali_weights_to_hf.py b/src/transformers/models/colpali/convert_colpali_weights_to_hf.py new file mode 100644 index 00000000000000..595974e0da1c3f --- /dev/null +++ b/src/transformers/models/colpali/convert_colpali_weights_to_hf.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Convert ColPali weights from the original repository to the HF model format. + +Original repository: https://github.com/illuin-tech/colpali. + +NOTE: This script was originally run using `torch==2.5.1` and with: + +```bash +python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \ + --model_id vidore/colpali-v1.2-merged \ + --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \ + --original_vlm_name_or_path google/paligemma-3b-mix-448 \ + --output_dir vidore/colpali-v1.2-hf-internal \ + --push_to_hub +``` +""" + +import argparse +import glob +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from huggingface_hub import snapshot_download +from safetensors import safe_open + +from transformers import AutoConfig +from transformers.models.colpali import ColPaliForRetrieval +from transformers.models.colpali.configuration_colpali import ColPaliConfig +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +ORIGINAL_DTYPE = torch.bfloat16 + + +def rename_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]: + new_state_dict = {} + for key, value in state_dict.items(): + new_key = key + if key.startswith("custom_text_proj"): + new_key = key.replace("custom_text_proj", "embedding_proj_layer") + if key.startswith("model."): + new_key = key.replace("model.", "vlm.", 1) + new_state_dict[new_key] = value + return new_state_dict + + +def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> Dict[str, torch.Tensor]: + directory_path = snapshot_download( + repo_id=model_id, + revision=revision, + allow_patterns=["*.safetensors"], + ) + + original_state_dict = {} + for path in glob.glob(f"{directory_path}/*"): + if path.endswith(".safetensors"): + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + original_state_dict[key] = f.get_tensor(key) + + # Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict. + if "lm_head.weight" not in original_state_dict: + original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[ + "model.language_model.model.embed_tokens.weight" + ].clone() + + return original_state_dict + + +@torch.no_grad() +def convert_colpali_weights_to_hf( + model_id: str, + output_dir: str, + push_to_hub: bool, + revision: Optional[str] = None, + original_vlm_name_or_path: Optional[str] = None, +): + # Load the original model data + original_config = AutoConfig.from_pretrained( + model_id, + revision=revision, + ) + if original_vlm_name_or_path is not None: + original_config._name_or_path = original_vlm_name_or_path + if hasattr(original_config, "architectures"): + delattr(original_config, "architectures") + + original_state_dict = load_original_state_dict(model_id, revision=revision) + + # Format the state_dict keys + original_state_dict = rename_state_dict_keys(original_state_dict) + + # Create the new config + config = ColPaliConfig( + vlm_config=original_config, + embedding_dim=128, # hardcoded in the original model + ) + config.model_type = "colpali" + config.is_composition = False + + # Load the untrained model + model = ColPaliForRetrieval(config=config).to("cpu").eval() + print("Created model with new config and randomly initialized weights") + + # NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision. 
+ # There are two ways to set the model's dtype: + # - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision. + # - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision. + # The following snippet allows a fine-grained control over the model's dtype, making sure that all + # the new weights' dtypes match the original model. + for param in model.parameters(): + param.data = param.data.to(ORIGINAL_DTYPE) + print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`") + + # Load the original weights + model.load_state_dict(original_state_dict) + print("Loaded original model weights") + + # Tie the weights (following ColPali's `__init__`` step) + if model.vlm.language_model._tied_weights_keys is not None: + model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys] + + # Sanity check: ensure all keys are the same + state_dict_keys_old = set(original_state_dict.keys()) + state_dict_keys_new = set(model.state_dict().keys()) + disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new) + if disjoint_keys: + raise ValueError(f"Incompatible keys: {disjoint_keys}") + + # Save the model + if push_to_hub: + model.push_to_hub(output_dir, private=True) + print(f"Model pushed to the hub at `{output_dir}`") + else: + Path(output_dir).mkdir(exist_ok=True, parents=True) + model.save_pretrained(output_dir) + print(f"Model saved to `{output_dir}`") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=""" + This script converts the original ColPali model to the HF model format. + + Example usage: + ```bash + python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \ + --model_id vidore/colpali-v1.2-merged \ + --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \ + --original_vlm_name_or_path google/paligemma-3b-mix-448 \ + --output_dir vidore/colpali-v1.2-hf \ + --push_to_hub + ``` + """ + ) + parser.add_argument( + "--model_id", + help="Model ID of the original model to convert", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--push_to_hub", + help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally", + action="store_true", + default=False, + ) + parser.add_argument( + "--revision", + help="Revision of the model to download", + default=None, + ) + parser.add_argument( + "--original_vlm_name_or_path", + help="Name or path of the original VLM backbone model", + default=None, + ) + args = parser.parse_args() + + convert_colpali_weights_to_hf( + model_id=args.model_id, + output_dir=args.output_dir, + push_to_hub=args.push_to_hub, + revision=args.revision, + original_vlm_name_or_path=args.original_vlm_name_or_path, + ) diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py new file mode 100644 index 00000000000000..8bfff814c83756 --- /dev/null +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -0,0 +1,299 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ColPali model""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from transformers import AutoModelForImageTextToText + +from ...cache_utils import Cache +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from .configuration_colpali import ColPaliConfig + + +_CONFIG_FOR_DOC = "ColPaliConfig" + +COLPALI_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ColPaliConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare ColPali model outputting raw hidden-states without any specific head on top.", + COLPALI_START_DOCSTRING, +) +class ColPaliPreTrainedModel(PreTrainedModel): + config_class = ColPaliConfig + base_model_prefix = "model" + _no_split_modules = [] + + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.vlm_config.text_config.initializer_range + ) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +@dataclass +class ColPaliForRetrievalOutput(ModelOutput): + """ + Base class for ColPali embeddings output. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The embeddings of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + embeddings: torch.Tensor = None + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +COLPALI_FOR_RETRIEVAL_INPUT_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`PaliGemmaProcessor`] uses + [`SiglipImageProcessor`] for processing images). If none, ColPali will only process text (query embeddings). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the vlm backbone model. +""" + + +@add_start_docstrings( + """ + ColPali leverages Vision Language Models (VLMs) to construct efficient multi-vector embeddings in the visual space for document retrieval. + By feeding the ViT output patches from PaliGemma-3B to a linear projection, we create a multi-vector representation of documents. The model + is trained to maximize the similarity between these document embeddings and the query embeddings, following the ColBERT method. + + Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a single model that can take into account + both the textual and visual content (layout, charts, ...) of a document. + + ColPali was introduced in the following paper: [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449). + + Resources: + - A blog post detailing ColPali, a vision retrieval model, can be found [here](https://huggingface.co/blog/manu/colpali). 📝 + - The code for using and training the original ColPali model and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎 + - Cookbooks for learning to use the Hf version of ColPali, fine-tuning, and similarity maps generation can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚 + """ +) +class ColPaliForRetrieval(ColPaliPreTrainedModel): + def __init__(self, config: ColPaliConfig): + super().__init__(config) + self.config = config + self.vocab_size = config.vlm_config.text_config.vocab_size + + vlm = AutoModelForImageTextToText.from_config(config.vlm_config) + if vlm.language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"vlm.language_model.{k}" for k in vlm.language_model._tied_weights_keys] + self.vlm = vlm + + self.embedding_dim = self.config.embedding_dim + self.embedding_proj_layer = nn.Linear( + self.config.vlm_config.text_config.hidden_size, + self.embedding_dim, + ) + + self.post_init() + + @add_start_docstrings_to_model_forward(COLPALI_FOR_RETRIEVAL_INPUT_DOCSTRING) + @replace_return_docstrings(output_type=ColPaliForRetrievalOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, ColPaliForRetrievalOutput]: + r""" + Returns: + """ + if "pixel_values" in kwargs: + kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vlm( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + output_hidden_states=True, + return_dict=return_dict, + output_attentions=output_attentions, + **kwargs, + ) + + last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size) + embeddings = self.embedding_proj_layer(last_hidden_states) # 
(batch_size, sequence_length, dim) + + # L2 normalization + embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) + + embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + + loss = None + if not return_dict: + output = (embeddings,) + outputs[2:] + output[2] = output[2] if output_hidden_states is not None else None + output[-1] = (outputs.image_hidden_states if pixel_values is not None else None,) + return (loss,) + output if loss is not None else output + + return ColPaliForRetrievalOutput( + loss=loss, + embeddings=embeddings, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states if pixel_values is not None else None, + ) + + def get_input_embeddings(self): + return self.vlm.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.vlm.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.vlm.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.vlm.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.vlm.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.vlm.language_model.get_decoder() + + def tie_weights(self): + return self.vlm.language_model.tie_weights() + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + model_embeds = self.vlm.language_model.resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + + self.config.vlm_config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vlm_config.vocab_size = model_embeds.num_embeddings + self.vlm.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + + return model_embeds + + +__all__ = [ + "ColPaliForRetrieval", + "ColPaliForRetrievalOutput", + "ColPaliPreTrainedModel", +] diff --git a/src/transformers/models/colpali/modular_colpali.py b/src/transformers/models/colpali/modular_colpali.py new file mode 100644 index 00000000000000..ceb43e2d66f335 --- /dev/null +++ b/src/transformers/models/colpali/modular_colpali.py @@ -0,0 +1,354 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+from typing import ClassVar, List, Optional, Union
+
+from transformers.models.paligemma.processing_paligemma import (
+    IMAGE_TOKEN,
+    PaliGemmaProcessor,
+    build_string_from_input,
+    make_batched_images,
+)
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, is_valid_image
+from ...processing_utils import (
+    ProcessingKwargs,
+    Unpack,
+)
+from ...tokenization_utils_base import (
+    PreTokenizedInput,
+    TextInput,
+)
+from ...utils import (
+    is_torch_available,
+    logging,
+)
+
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": "longest",
+        },
+        "images_kwargs": {
+            "data_format": "channels_first",
+            "do_convert_rgb": True,
+        },
+        "common_kwargs": {"return_tensors": "pt"},
+    }
+
+
+class ColPaliProcessor(PaliGemmaProcessor):
+    r"""
+    Constructs a ColPali processor which wraps a PaliGemmaProcessor and provides special methods to process images and queries, as
+    well as to compute the late-interaction retrieval score.
+
+    [`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
+    for more information.
+
+    Args:
+        image_processor ([`SiglipImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`LlamaTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    visual_prompt_prefix: ClassVar[str] = "Describe the image."
+    query_prefix: ClassVar[str] = "Question: "
+
+    @property
+    def query_augmentation_token(self) -> str:
+        """
+        Return the query augmentation token.
+
+        Query augmentation buffers are used as reasoning buffers during inference.
+        """
+        return self.tokenizer.pad_token
+
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[ColPaliProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare for the model either (1) one or several texts, or (2) one or several image(s). This method is a custom
+        wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
+        both text and images at the same time.
+
+        When preparing the text(s), this method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's
+        [`~LlamaTokenizerFast.__call__`].
+        When preparing the image(s), this method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's
+        [`~SiglipImageProcessor.__call__`].
+        Please refer to the docstring of the above two methods for more information.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+                number of channels, H and W are image height and width.
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string).
If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + ColPaliProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + suffix = output_kwargs["text_kwargs"].pop("suffix", None) + + return_token_type_ids = True if suffix is not None else False + + if text is None and images is None: + raise ValueError("Either text or images must be provided") + if text is not None and images is not None: + raise ValueError("Only one of text or images can be processed at a time") + + if images is not None: + if is_valid_image(images): + images = [images] + elif isinstance(images, list) and is_valid_image(images[0]): + pass + elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])): + raise ValueError("images must be an image, list of images or list of list of images") + + texts_doc = [self.visual_prompt_prefix] * len(images) + images = [image.convert("RGB") for image in images] + + input_strings = [ + build_string_from_input( + prompt=prompt, + bos_token=self.tokenizer.bos_token, + image_seq_len=self.image_seq_length, + image_token=IMAGE_TOKEN, + num_images=len(image_list) if isinstance(image_list, list) else 1, + ) + for prompt, image_list in zip(texts_doc, images) + ] + images = make_batched_images(images) + pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] + + # max_length has to account for the image tokens + if output_kwargs["text_kwargs"].get("max_length", None) is not None: + output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length + + inputs = self.tokenizer( + input_strings, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return_data = {**inputs, "pixel_values": pixel_values} + + if return_token_type_ids: + labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100) + return_data.update({"labels": labels}) + + return BatchFeature(data=return_data) + + elif text is not None: + if isinstance(text, str): + text = [text] + elif not (isinstance(text, list) and isinstance(text[0], str)): + raise ValueError("Text must be a string or a list of strings") + + if suffix is None: + suffix = self.query_augmentation_token * 10 + texts_query: List[str] = [] + + for query in text: + query = self.tokenizer.bos_token + self.query_prefix + query + query += suffix # add suffix (pad tokens) + query += "\n" # make input ISO to PaliGemma's processor + texts_query.append(query) + + output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50) + + batch_query = self.tokenizer( + 
texts_query, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return batch_query + + def process_images( + self, + images: ImageInput = None, + **kwargs: Unpack[ColPaliProcessorKwargs], + ) -> BatchFeature: + """ + Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's + [`ColPaliProcessor.__call__`]. + + This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`]. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + return self.__call__(images=images, **kwargs) + + def process_queries( + self, + text: Union[TextInput, List[TextInput]], + **kwargs: Unpack[ColPaliProcessorKwargs], + ) -> BatchFeature: + """ + Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's + [`ColPaliProcessor.__call__`]. + + This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`]. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). 
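+
+        Example (a minimal usage sketch; `vidore/colpali-v1.2-hf` is the HF-format checkpoint referenced in this PR's model card):
+
+        ```python
+        from transformers import ColPaliProcessor
+
+        processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2-hf")
+        batch_queries = processor.process_queries(["What is shown in the chart on page 3?"])
+        # `batch_queries` holds the "input_ids" and "attention_mask" tensors expected by `ColPaliForRetrieval`.
+        ```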
+ """ + return self.__call__(text=text, **kwargs) + + def score_retrieval( + self, + query_embeddings: Union["torch.Tensor", List["torch.Tensor"]], + passage_embeddings: Union["torch.Tensor", List["torch.Tensor"]], + batch_size: int = 128, + output_dtype: Optional["torch.dtype"] = None, + output_device: Union["torch.device", str] = "cpu", + ) -> "torch.Tensor": + """ + Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector + query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the + image of a document page. + + Because the embedding tensors are multi-vector and can thus have different shapes, they + should be fed as: + (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim) + (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually + obtained by padding the list of tensors. + + Args: + query_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings. + passage_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings. + batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores. + output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor. + If `None`, the dtype of the input embeddings is used. + output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor. + + Returns: + `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score + tensor is saved on the "cpu" device. + """ + + if len(query_embeddings) == 0: + raise ValueError("No queries provided") + if len(passage_embeddings) == 0: + raise ValueError("No passages provided") + + if query_embeddings[0].device != passage_embeddings[0].device: + raise ValueError("Queries and passages must be on the same device") + + if query_embeddings[0].dtype != passage_embeddings[0].dtype: + raise ValueError("Queries and passages must have the same dtype") + + if output_dtype is None: + output_dtype = query_embeddings[0].dtype + + scores: List[torch.Tensor] = [] + + for i in range(0, len(query_embeddings), batch_size): + batch_scores: List[torch.Tensor] = [] + batch_queries = torch.nn.utils.rnn.pad_sequence( + query_embeddings[i : i + batch_size], batch_first=True, padding_value=0 + ) + for j in range(0, len(passage_embeddings), batch_size): + batch_passages = torch.nn.utils.rnn.pad_sequence( + passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0 + ) + batch_scores.append( + torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2) + ) + scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device)) + + return torch.cat(scores, dim=0) + + +__all__ = [ + "ColPaliProcessor", +] diff --git a/src/transformers/models/colpali/processing_colpali.py b/src/transformers/models/colpali/processing_colpali.py new file mode 100644 index 00000000000000..f8d68675798bc4 --- /dev/null +++ b/src/transformers/models/colpali/processing_colpali.py @@ -0,0 +1,443 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/colpali/modular_colpali.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_colpali.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import ClassVar, List, Optional, Union
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, is_valid_image
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
+from ...utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+
+class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": "longest",
+        },
+        "images_kwargs": {
+            "data_format": "channels_first",
+            "do_convert_rgb": True,
+        },
+        "common_kwargs": {"return_tensors": "pt"},
+    }
+
+
+IMAGE_TOKEN = "<image>"
+EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]
+
+
+def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
+    """
+    Builds a string from the input prompt and image tokens.
+    For example, for the call:
+    build_string_from_input(
+        prompt="Prefix str",
+        bos_token="<s>",
+        image_seq_len=3,
+        image_token="<im>",
+        num_images=1,
+    )
+    The output will be:
+    "<im><im><im><s>Prefix str\n"
+    Args:
+        prompt (`List[Union[str, ImageInput]]`): The input prompt.
+        bos_token (`str`): The beginning of sentence token.
+        image_seq_len (`int`): The length of the image sequence.
+        image_token (`str`): The image token.
+        num_images (`int`): Number of images in the prompt.
+    """
+    return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
+
+
+def make_batched_images(images) -> List[List[ImageInput]]:
+    """
+    Accepts images in list or nested list format, and makes a list of images for preprocessing.
+
+    Args:
+        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
+            The input image.
+
+    Returns:
+        list: A list of images.
+    """
+    if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
+        return [img for img_list in images for img in img_list]
+
+    elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
+        return images
+
+    elif is_valid_image(images):
+        return [images]
+
+    raise ValueError(f"Could not make batched images from {images}")
+
+
+class ColPaliProcessor(ProcessorMixin):
+    r"""
+    Constructs a ColPali processor which wraps a PaliGemmaProcessor and adds special methods to process images and
+    queries, as well as to compute the late-interaction retrieval score.
+
+    [`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
+    for more information.
+
+    Args:
+        image_processor ([`SiglipImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`LlamaTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
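+
+    Example (a minimal usage sketch; the checkpoint is the one referenced in this PR's tests, and the blank
+    image below is only a stand-in for a real document page):
+
+    ```python
+    >>> from PIL import Image
+    >>> from transformers import ColPaliProcessor
+
+    >>> processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2-hf")
+
+    >>> # Documents (images) and queries (text) are processed in separate calls, never together.
+    >>> image_inputs = processor(images=Image.new("RGB", (448, 448)))
+    >>> query_inputs = processor(text="When was the model released?")
+    ```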
+ """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + image_processor_class = "SiglipImageProcessor" + tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast") + + visual_prompt_prefix: ClassVar[str] = "Describe the image." + query_prefix: ClassVar[str] = "Question: " + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + **kwargs, + ): + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + if not hasattr(image_processor, "image_seq_length"): + raise ValueError("Image processor is missing an `image_seq_length` attribute.") + + self.image_seq_length = image_processor.image_seq_length + + if not hasattr(tokenizer, "image_token"): + image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True) + tokens_to_add = {"additional_special_tokens": [image_token]} + tokenizer.add_special_tokens(tokens_to_add) + self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + else: + self.image_token_id = tokenizer.image_token_id + + tokenizer.add_tokens(EXTRA_TOKENS) + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[ColPaliProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is custom + wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process + both text and images at the same time. + + When preparing the the text(s), this method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's + [`~LlamaTokenizerFast.__call__`]. + When preparing the the image(s), this method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's + [`~SiglipImageProcessor.__call__`]. + Please refer to the doctsring of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. 
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + ColPaliProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + suffix = output_kwargs["text_kwargs"].pop("suffix", None) + + return_token_type_ids = True if suffix is not None else False + + if text is None and images is None: + raise ValueError("Either text or images must be provided") + if text is not None and images is not None: + raise ValueError("Only one of text or images can be processed at a time") + + if images is not None: + if is_valid_image(images): + images = [images] + elif isinstance(images, list) and is_valid_image(images[0]): + pass + elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])): + raise ValueError("images must be an image, list of images or list of list of images") + + texts_doc = [self.visual_prompt_prefix] * len(images) + images = [image.convert("RGB") for image in images] + + input_strings = [ + build_string_from_input( + prompt=prompt, + bos_token=self.tokenizer.bos_token, + image_seq_len=self.image_seq_length, + image_token=IMAGE_TOKEN, + num_images=len(image_list) if isinstance(image_list, list) else 1, + ) + for prompt, image_list in zip(texts_doc, images) + ] + images = make_batched_images(images) + pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] + + # max_length has to account for the image tokens + if output_kwargs["text_kwargs"].get("max_length", None) is not None: + output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length + + inputs = self.tokenizer( + input_strings, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return_data = {**inputs, "pixel_values": pixel_values} + + if return_token_type_ids: + labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100) + return_data.update({"labels": labels}) + + return BatchFeature(data=return_data) + + elif text is not None: + if isinstance(text, str): + text = [text] + elif not (isinstance(text, list) and isinstance(text[0], str)): + raise ValueError("Text must be a string or a list of strings") + + if suffix is None: + suffix = self.query_augmentation_token * 10 + texts_query: List[str] = [] + + for query in text: + query = self.tokenizer.bos_token + self.query_prefix + query + query += suffix # add suffix (pad tokens) + query += "\n" # make input ISO to PaliGemma's processor + texts_query.append(query) + + output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50) + + batch_query = self.tokenizer( + texts_query, + return_token_type_ids=False, + **output_kwargs["text_kwargs"], + ) + + return batch_query + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + @property + def query_augmentation_token(self) -> str: + """ + Return the query augmentation token. + + Query augmentation buffers are used as reasoning buffers during inference. + """ + return self.tokenizer.pad_token + + def process_images( + self, + images: ImageInput = None, + **kwargs: Unpack[ColPaliProcessorKwargs], + ) -> BatchFeature: + """ + Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's + [`ColPaliProcessor.__call__`]. + + This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`]. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + return self.__call__(images=images, **kwargs) + + def process_queries( + self, + text: Union[TextInput, List[TextInput]], + **kwargs: Unpack[ColPaliProcessorKwargs], + ) -> BatchFeature: + """ + Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's + [`ColPaliProcessor.__call__`]. + + This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`]. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. 
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+        """
+        return self.__call__(text=text, **kwargs)
+
+    def score_retrieval(
+        self,
+        query_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
+        passage_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
+        batch_size: int = 128,
+        output_dtype: Optional["torch.dtype"] = None,
+        output_device: Union["torch.device", str] = "cpu",
+    ) -> "torch.Tensor":
+        """
+        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
+        query embeddings and passage embeddings. For ColPali, a passage is the image of a
+        document page.
+
+        Because the embedding tensors are multi-vector and can thus have different shapes, they
+        should be fed as either:
+        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
+        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
+            obtained by padding the list of tensors.
+
+        Args:
+            query_embeddings (`Union[torch.Tensor, List[torch.Tensor]]`): Query embeddings.
+            passage_embeddings (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings.
+            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
+            output_dtype (`torch.dtype`, *optional*): The dtype of the output tensor.
+                If `None` (the default), the dtype of the input embeddings is used.
+            output_device (`torch.device` or `str`, *optional*, defaults to `"cpu"`): The device of the output tensor.
+
+        Returns:
+            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
+            tensor is placed on `output_device` ("cpu" by default).
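+
+        Example (an illustrative sketch; it assumes `processor` is an instantiated [`ColPaliProcessor`] and
+        uses random tensors as stand-ins for the embeddings returned by `ColPaliForRetrieval`):
+
+        ```python
+        >>> import torch
+
+        >>> query_embeddings = [torch.randn(16, 128), torch.randn(20, 128)]  # 2 queries
+        >>> passage_embeddings = [torch.randn(1030, 128) for _ in range(3)]  # 3 document pages
+        >>> scores = processor.score_retrieval(query_embeddings, passage_embeddings)
+        >>> scores.shape
+        torch.Size([2, 3])
+        ```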
+ """ + + if len(query_embeddings) == 0: + raise ValueError("No queries provided") + if len(passage_embeddings) == 0: + raise ValueError("No passages provided") + + if query_embeddings[0].device != passage_embeddings[0].device: + raise ValueError("Queries and passages must be on the same device") + + if query_embeddings[0].dtype != passage_embeddings[0].dtype: + raise ValueError("Queries and passages must have the same dtype") + + if output_dtype is None: + output_dtype = query_embeddings[0].dtype + + scores: List[torch.Tensor] = [] + + for i in range(0, len(query_embeddings), batch_size): + batch_scores: List[torch.Tensor] = [] + batch_queries = torch.nn.utils.rnn.pad_sequence( + query_embeddings[i : i + batch_size], batch_first=True, padding_value=0 + ) + for j in range(0, len(passage_embeddings), batch_size): + batch_passages = torch.nn.utils.rnn.pad_sequence( + passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0 + ) + batch_scores.append( + torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2) + ) + scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device)) + + return torch.cat(scores, dim=0) + + +__all__ = ["ColPaliProcessor"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c6057088b7d506..823c51a290713d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -813,6 +813,9 @@ def __init__(self, *args, **kwargs): MODEL_FOR_QUESTION_ANSWERING_MAPPING = None +MODEL_FOR_RETRIEVAL_MAPPING = None + + MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = None @@ -2258,6 +2261,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ColPaliForRetrieval(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ColPaliPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ConditionalDetrForObjectDetection(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/colpali/__init__.py b/tests/models/colpali/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/colpali/test_modeling_colpali.py b/tests/models/colpali/test_modeling_colpali.py new file mode 100644 index 00000000000000..646726ac700ee5 --- /dev/null +++ b/tests/models/colpali/test_modeling_colpali.py @@ -0,0 +1,368 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch ColPali model.""" + +import gc +import unittest +from typing import ClassVar + +import torch +from datasets import load_dataset +from parameterized import parameterized + +from tests.test_configuration_common import ConfigTester +from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from transformers import ( + is_torch_available, + is_vision_available, +) +from transformers.models.colpali.configuration_colpali import ColPaliConfig +from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput +from transformers.models.colpali.processing_colpali import ColPaliProcessor +from transformers.testing_utils import ( + require_torch, + require_torch_sdpa, + require_vision, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + +if is_vision_available(): + pass + + +class ColPaliForRetrievalModelTester: + def __init__( + self, + parent, + ignore_index=-100, + image_token_index=0, + projector_hidden_act="gelu", + seq_length=25, + vision_feature_select_strategy="default", + vision_feature_layer=-1, + projection_dim=32, + text_config={ + "model_type": "gemma", + "seq_length": 128, + "is_training": True, + "use_token_type_ids": False, + "use_labels": True, + "vocab_size": 99, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "head_dim": 8, + "intermediate_size": 37, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 16, + "type_sequence_label_size": 2, + "initializer_range": 0.02, + "num_labels": 3, + "num_choices": 4, + "pad_token_id": 1, + }, + is_training=False, + vision_config={ + "use_labels": True, + "image_size": 20, + "patch_size": 5, + "num_image_tokens": 4, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "projection_dim": 32, + "num_key_value_heads": 1, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + use_cache=False, + embedding_dim=128, + ): + self.parent = parent + self.ignore_index = ignore_index + # `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.text_config = text_config + self.vision_config = vision_config + self.seq_length = seq_length + self.projection_dim = projection_dim + self.pad_token_id = text_config["pad_token_id"] + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = vision_config["num_channels"] + self.image_size = vision_config["image_size"] + self.encoder_seq_length = seq_length + self.use_cache = use_cache + + self.embedding_dim = embedding_dim + self.vlm_config = { + "model_type": "paligemma", + "text_config": self.text_config, + "vision_config": self.vision_config, + "ignore_index": self.ignore_index, + "image_token_index": self.image_token_index, + "projector_hidden_act": self.projector_hidden_act, + "projection_dim": self.projection_dim, + 
"vision_feature_select_strategy": self.vision_feature_select_strategy, + "vision_feature_layer": self.vision_feature_layer, + } + + def get_config(self): + return ColPaliConfig( + vlm_config=self.vlm_config, + embedding_dim=self.embedding_dim, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.vlm_config.text_config.vocab_size - 1) + 1 + attention_mask = input_ids.ne(1).to(torch_device) + # set the 16 first tokens to be image, and ensure that no other tokens are image tokens + # do not change this unless you modified image size or patch size + input_ids[input_ids == config.vlm_config.image_token_index] = self.pad_token_id + input_ids[:, :16] = config.vlm_config.image_token_index + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": input_ids, + "token_type_ids": torch.zeros_like(input_ids), + } + return config, inputs_dict + + +@require_torch +class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase): + """ + Model tester for `ColPaliForRetrieval`. + """ + + all_model_classes = (ColPaliForRetrieval,) if is_torch_available() else () + fx_compatible = False + test_torchscript = False + test_pruning = False + test_resize_embeddings = True + test_head_masking = False + + def setUp(self): + self.model_tester = ColPaliForRetrievalModelTester(self) + self.config_tester = ConfigTester(self, config_class=ColPaliConfig, has_text_modality=False) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(input_ids) + + with torch.no_grad(): + model(**inputs) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + # while some other models require pixel_values to be present + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + self.assertTrue(torch.allclose(out_embeds, out_ids)) + + @slow + @require_vision + def test_colpali_forward_inputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + 
model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+
+            with torch.no_grad():
+                outputs = model(**inputs, return_dict=True)
+
+            self.assertIsInstance(outputs, ColPaliForRetrievalOutput)
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
+    @require_torch_sdpa
+    @slow
+    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+        self.skipTest(
+            "Due to the custom causal mask, there is a slightly too large difference between eager and sdpa in bfloat16."
+        )
+
+    @unittest.skip(
+        reason="From PaliGemma: Some undefined behavior encountered with test versions of this model. Skip for now."
+    )
+    def test_model_parallelism(self):
+        pass
+
+    @unittest.skip(
+        reason="PaliGemma's SigLip encoder uses the same initialization scheme as the original Flax implementation"
+    )
+    def test_initialization(self):
+        pass
+
+    # TODO extend valid outputs to include this test @Molbap
+    @unittest.skip(reason="PaliGemma currently has only one output format.")
+    def test_model_outputs_equivalence(self):
+        pass
+
+    @unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
+    def test_sdpa_can_dispatch_on_flash(self):
+        pass
+
+    @unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
+    def test_sdpa_can_compile_dynamic(self):
+        pass
+
+
+@require_torch
+class ColPaliModelIntegrationTest(unittest.TestCase):
+    model_name: ClassVar[str] = "vidore/colpali-v1.2-hf"
+
+    def setUp(self):
+        self.processor = ColPaliProcessor.from_pretrained(self.model_name)
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    @slow
+    def test_model_integration_test(self):
+        """
+        Test if the model is able to retrieve the correct pages for a small and easy dataset.
+ """ + model = ColPaliForRetrieval.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16, + device_map=torch_device, + ).eval() + + # Load the test dataset + ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test") + + # Preprocess the examples + batch_images = self.processor(images=ds["image"]).to(torch_device) + batch_queries = self.processor(text=ds["query"]).to(torch_device) + + # Run inference + with torch.inference_mode(): + image_embeddings = model(**batch_images).embeddings + query_embeddings = model(**batch_queries).embeddings + + # Compute retrieval scores + scores = self.processor.score_retrieval( + query_embeddings=query_embeddings, + passage_embeddings=image_embeddings, + ) # (len(qs), len(ps)) + + assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}" + assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}" + + # Check if the maximum scores per row are in the diagonal of the matrix score + self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all()) + + # Further validation: fine-grained check, with a hardcoded score from the original implementation + expected_scores = torch.tensor( + [ + [15.5625, 6.5938, 14.4375], + [12.2500, 16.2500, 11.0000], + [15.0625, 11.7500, 21.0000], + ], + dtype=scores.dtype, + ) + + assert torch.allclose(scores, expected_scores, atol=1), f"Expected scores {expected_scores}, got {scores}" diff --git a/tests/models/colpali/test_processing_colpali.py b/tests/models/colpali/test_processing_colpali.py new file mode 100644 index 00000000000000..42592460fa28ed --- /dev/null +++ b/tests/models/colpali/test_processing_colpali.py @@ -0,0 +1,247 @@ +import shutil +import tempfile +import unittest + +import torch + +from transformers import GemmaTokenizer +from transformers.models.colpali.processing_colpali import ColPaliProcessor +from transformers.testing_utils import get_tests_dir, require_torch, require_vision +from transformers.utils import is_vision_available +from transformers.utils.dummy_vision_objects import SiglipImageProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import ( + ColPaliProcessor, + PaliGemmaProcessor, + SiglipImageProcessor, + ) + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + + +@require_vision +class ColPaliProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = ColPaliProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384") + image_processor.image_seq_length = 0 + tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True) + processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer) + processor.save_pretrained(self.tmpdirname) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + @require_torch + @require_vision + def test_process_images(self): + # Processor configuration + image_input = self.prepare_image_inputs() + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length") + image_processor.image_seq_length = 14 + + # Get the processor + processor = self.processor_class( + tokenizer=tokenizer, + image_processor=image_processor, + ) + + # Process the image + batch_feature = processor.process_images(images=image_input, return_tensors="pt") + + # Assertions + 
self.assertIn("pixel_values", batch_feature) + self.assertEqual(batch_feature["pixel_values"].shape, torch.Size([1, 3, 384, 384])) + + @require_torch + @require_vision + def test_process_queries(self): + # Inputs + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + # Processor configuration + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length") + image_processor.image_seq_length = 14 + + # Get the processor + processor = self.processor_class( + tokenizer=tokenizer, + image_processor=image_processor, + ) + + # Process the image + batch_feature = processor.process_queries(text=queries, return_tensors="pt") + + # Assertions + self.assertIn("input_ids", batch_feature) + self.assertIsInstance(batch_feature["input_ids"], torch.Tensor) + self.assertEqual(batch_feature["input_ids"].shape[0], len(queries)) + + # The following tests are overwritten as ColPaliProcessor can only take one of images or text as input at a time + + def test_tokenizer_defaults_preserved_by_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + input_str = self.prepare_text_inputs() + inputs = processor(text=input_str, return_tensors="pt") + self.assertEqual(inputs[self.text_input_name].shape[-1], 117) + + def test_image_processor_defaults_preserved_by_image_kwargs(self): + """ + We use do_rescale=True, rescale_factor=-1 to ensure that image_processor kwargs are preserved in the processor. + We then check that the mean of the pixel_values is less than or equal to 0 after processing. + Since the original pixel_values are in [0, 255], this is a good indicator that the rescale_factor is indeed applied. 
+ """ + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component( + "image_processor", do_rescale=True, rescale_factor=-1 + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + image_input = self.prepare_image_inputs() + + inputs = processor(images=image_input, return_tensors="pt") + self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0) + + def test_kwargs_overrides_default_tokenizer_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["tokenizer"] = self.get_component("tokenizer", padding="longest") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + input_str = self.prepare_text_inputs() + inputs = processor(text=input_str, return_tensors="pt", max_length=112, padding="max_length") + self.assertEqual(inputs[self.text_input_name].shape[-1], 112) + + def test_kwargs_overrides_default_image_processor_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component( + "image_processor", do_rescale=True, rescale_factor=1 + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + image_input = self.prepare_image_inputs() + + inputs = processor(images=image_input, do_rescale=True, rescale_factor=-1, return_tensors="pt") + self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0) + + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + inputs = processor( + text=input_str, + return_tensors="pt", + do_rescale=True, + rescale_factor=-1, + padding="max_length", + max_length=76, + ) + + self.assertEqual(inputs[self.text_input_name].shape[-1], 76) + + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + image_input = self.prepare_image_inputs(batch_size=2) + inputs = processor( + images=image_input, + return_tensors="pt", + do_rescale=True, + rescale_factor=-1, + padding="longest", + max_length=76, + ) + + self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0) + + def test_doubly_passed_kwargs(self): + if "image_processor" not in 
self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + image_input = self.prepare_image_inputs() + with self.assertRaises(ValueError): + _ = processor( + images=image_input, + images_kwargs={"do_rescale": True, "rescale_factor": -1}, + do_rescale=True, + return_tensors="pt", + ) + + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"do_rescale": True, "rescale_factor": -1}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(inputs[self.text_input_name].shape[-1], 76) + + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"do_rescale": True, "rescale_factor": -1}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(images=image_input, **all_kwargs) + self.assertEqual(inputs[self.text_input_name].shape[-1], 76) diff --git a/utils/check_table.py b/utils/check_table.py index 5876818449558e..4a392a58fd0500 100644 --- a/utils/check_table.py +++ b/utils/check_table.py @@ -87,7 +87,7 @@ def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> str _re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") # Will match any TF or Flax model too so need to be in an else branch after the two previous regexes. -_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") +_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)") # This is to make sure the transformers module imported is the one in the repo. diff --git a/utils/update_metadata.py b/utils/update_metadata.py index b6ee1e7c8c13c2..8e4a7e3fe5340e 100755 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -56,7 +56,7 @@ _re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") # Will match any TF or Flax model too so need to be in an else branch afterthe two previous regexes. 
-_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") +_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)") # Fill this with tuples (pipeline_tag, model_mapping, auto_model)
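For illustration only (not part of the patch): the widened `_re_pt_models` pattern in the two hunks above is what
lets the table and metadata scripts pick up retrieval heads such as `ColPaliForRetrieval`. A quick sketch of its
behaviour:

```python
import re

# Same pattern as in utils/check_table.py and utils/update_metadata.py after this change,
# with the added `ForRetrieval` alternative.
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)")

print(_re_pt_models.match("ColPaliForRetrieval").groups())  # -> ('ColPali',)
```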