Add ColPali to 🤗 transformers (#33736)
* feat: run `add-new-model-like`

* feat: add paligemma code with "copied from"

* feat: add ColPaliProcessor

* feat: add ColPaliModel

* feat: add ColPaliConfig

* feat: rename `ColPaliForConditionalGeneration` to `ColPaliModel`

* fixup modeling colpali

* fix: fix root import shortcuts

* fix: fix `modeling_auto` dict

* feat: comment out ColPali test file

* fix: fix typos from `add-new-model-like`

* feat: explicit the forward input args

* feat: move everything to `modular_colpali.py`

* fix: put back ColPaliProcessor

* feat: add auto-generated files

* fix: run `fix-copies`

* fix: remove DOCSTRING constants to make modular converter work

* fix: fix typo + modular converter

* fix: add missing imports

* feat: no more errors when loading ColPaliModel

* fix: remove unused args in forward + tweak doc

* feat: rename `ColPaliModel` to `ColPaliForRetrieval`

* fix: apply `fix-copies`

* feat: add ColPaliProcessor to `modular_colpali`

* fix: run make quality + make style

* fix: remove duplicate line in configuration_auto

* feat: make ColPaliModel inherit from PaliGemmaForConditionalGeneration

* fix: tweak and use ColPaliConfig

* feat: rename `score` to `post_process_retrieval`

* build: run modular formatter + make style

* feat: convert colpali weights + fixes

* feat: remove old weight converter file

* feat: add and validate tests

* feat: replace hardcoded path to "vidore/colpali-v1.2-hf" in tests

* fix: add bfloat16 conversion in weight converter

* feat: replace pytest with unittest in modeling colpali test

* feat: add sanity check for weight conversion (doesn't work yet)

* feat: add shape sanity check in weight converter

* feat: make ColPaliProcessor args explicit

* doc: add doc for ColPali

* fix: trying to fix output mismatch

* feat: tweaks

* fix: ColPaliModelOutput inherits from ModelOutput instead of PaliGemmaCausalLMOutputWithPast

* fix: address comments on PR

* fix: adapt tests to the HF norm

* wip: try things

* feat: add `__call__` method to `ColPaliProcessor`

* feat: remove need for dummy image in `process_queries`

* build: run new modular converter

* fix: fix incorrect method override

* Fix tests, processing, modular, convert

* fix tokenization auto

* hotfix: manually fix processor -> fixme once convert modular is fixed

* fix: convert weights working

* feat: rename and improve convert weight script

* feat: tweaks

* feat: remove `device` input for `post_process_retrieval`

* refactor: remove unused `get_torch_device`

* Fix all tests

* docs: update ColPali model doc

* wip: fix convert weights to hf

* fix logging modular

* docs: add acknowledgements in model doc

* docs: add missing docstring to ColPaliProcessor

* docs: tweak

* docs: add doc for `ColPaliForRetrievalOutput.forward`

* feat: add modifications from colpali-engine v0.3.2 in ColPaliProcessor

* fix: fix and upload colpali hf weights

* refactor: rename `post_process_retrieval` to `score_retrieval`

* fix: fix wrong typing for `score_retrieval`

* test: add integration test for ColPali

* chore: rerun convert modular

* build: fix root imports

* Update docs/source/en/index.md

Co-authored-by: Yoni Gozlan <[email protected]>

* fix: address PR comments

* wip: reduce the prediction gap in weight conversion

* docs: add comment in weight conversion script

* docs: add example for `ColPaliForRetrieval.forward`

* tests: change dataset path to the new one in hf-internal

* fix: colpali weight conversion works

* test: add fine-grained check for ColPali integration test

* fix: fix typos in convert weight script

* docs: move input docstring in a variable

* fix: remove hardcoded torch device in test

* fix: run the new modular refactor

* docs: fix python example for ColPali

* feat: add option to choose `score_retrieval`'s output dtype and device

* docs: update doc for `score_retrieval`

* feat: add `patch_size` property in ColPali model

* chore: run `make fix-copies`

* docs: update description for ColPali cookbooks

* fix: remove `ignore_index` methods

* feat: remove non-transformers specific methods

* feat: update `__init__.py` to new hf format

* fix: fix root imports in transformers

* feat: remove ColPali's inheritance from PaliGemma

* Fix CI issues

* nit remove prints

* feat: remove ColPali config and model from `modular_colpali.py`

* feat: add `ColPaliPreTrainedModel` and update modeling and configuration code

* fix: fix auto-removed imports in root `__init__.py`

* fix: various fixes

* fix: fix `_init_weight`

* temp: comment `AutoModel.from_config` for experiments

* fix: add missing `output_attentions` arg in ColPali's forward

* fix: fix `resize_token_embeddings`

* fix: make `input_ids` optional in forward

* feat: rename `projection_layer` to `embedding_proj_layer`

* wip: fix convert colpali weight script

* fix tests and convert weights from original repo

* fix unprotected import

* fix unprotected torch import

* fix style

* change vlm_backbone_config to vlm_config

* fix unprotected import in modular this time

* fix: load config from Hub + tweaks in convert weight script

* docs: move example usage from model docstring to model markdown

* docs: fix input docstring for ColPali's forward method

* fix: use `sub_configs` for ColPaliConfig

* fix: remove non-needed sanity checks in weight conversion script + tweaks

* fix: fix issue with `replace_return_docstrings` in ColPali's `forward`

* docs: update docstring for `ColPaliConfig`

* test: change model path in ColPali test

* fix: fix ColPaliConfig

* fix: fix weight conversion script

* test: fix expected weights for ColPali model

* docs: update ColPali markdown

* docs: fix minor typo in ColPaliProcessor

* Fix tests and add _no_split_modules

* add text_config to colpali config

* [run slow] colpali

* move inputs to torch_device in integration test

* skip test_model_parallelism

* docs: clarify quickstart snippet in ColPali's model card

* docs: update ColPali's model card

---------

Co-authored-by: yonigozlan <[email protected]>
Co-authored-by: Yoni Gozlan <[email protected]>
3 people authored Dec 17, 2024
1 parent a7f5479 commit f33a0ce
Showing 22 changed files with 2,204 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -834,6 +834,8 @@
      title: CLIPSeg
    - local: model_doc/clvp
      title: CLVP
    - local: model_doc/colpali
      title: ColPali
    - local: model_doc/data2vec
      title: Data2Vec
    - local: model_doc/deplot
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -100,6 +100,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CodeLlama](model_doc/code_llama) ||||
| [Cohere](model_doc/cohere) ||||
| [Cohere2](model_doc/cohere2) ||||
| [ColPali](model_doc/colpali) ||||
| [Conditional DETR](model_doc/conditional_detr) ||||
| [ConvBERT](model_doc/convbert) ||||
| [ConvNeXT](model_doc/convnext) ||||
95 changes: 95 additions & 0 deletions docs/source/en/model_doc/colpali.md
@@ -0,0 +1,95 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# ColPali

## Overview

The ColPali model was proposed in [ColPali: Efficient Document Retrieval with Vision Language Models](https://doi.org/10.48550/arXiv.2407.01449) by **Manuel Faysse***, **Hugues Sibille***, **Tony Wu***, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo (* denotes equal contribution).

With our new model *ColPali*, we propose to leverage VLMs to construct efficient multi-vector embeddings in the visual space for document retrieval. By feeding the ViT output patches from PaliGemma-3B to a linear projection, we create a multi-vector representation of documents. We train the model to maximize the similarity between these document embeddings and the query embeddings, following the ColBERT method.

ColPali replaces potentially complex and brittle layout recognition and OCR pipelines with a single model that takes into account both the textual and visual content (layout, charts, ...) of a document. ColPali is also highly interpretable: similarity maps can be obtained between image patches and query tokens. These maps highlight ColPali's strong OCR capabilities and chart understanding.
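
To make the late-interaction scoring concrete, here is a minimal stand-alone sketch of the ColBERT-style "MaxSim" operation that ColPali uses. The function and tensor shapes below are illustrative assumptions for exposition, not part of the `transformers` API:

```python
import torch


def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
    """Illustrative late-interaction score between one query and one document page.

    query_emb: (num_query_tokens, dim) multi-vector query embedding.
    doc_emb: (num_patches, dim) multi-vector document embedding.
    Both are assumed to be L2-normalized along the last dimension.
    """
    # Cosine similarity between every query token and every document patch
    sim = query_emb @ doc_emb.T  # (num_query_tokens, num_patches)
    # For each query token, keep its best-matching patch, then sum over query tokens
    return sim.max(dim=1).values.sum()
```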

**Paper abstract:**

> Documents are visually rich structures that convey information through text, but also figures, page layouts, tables, or even fonts. Since modern retrieval systems mainly rely on the textual information they extract from document pages to index documents (often through lengthy and brittle processes), they struggle to exploit key visual cues efficiently. This limits their capabilities in many practical document retrieval applications such as Retrieval Augmented Generation (RAG). To benchmark current systems on visually rich document retrieval, we introduce the Visual Document Retrieval Benchmark *ViDoRe*, composed of various page-level retrieval tasks spanning multiple domains, languages, and practical settings. The inherent complexity and performance shortcomings of modern systems motivate a new concept: doing document retrieval by directly embedding the images of the document pages. We release *ColPali*, a Vision Language Model trained to produce high-quality multi-vector embeddings from images of document pages. Combined with a late interaction matching mechanism, *ColPali* largely outperforms modern document retrieval pipelines while being drastically simpler, faster and end-to-end trainable.
>
> We release models, data, code and benchmarks under open licenses at [https://huggingface.co/vidore](https://huggingface.co/vidore).

## Resources

- The official blog post detailing ColPali can be found [here](https://huggingface.co/blog/manu/colpali). 📝
- The original model implementation code for the ColPali model and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎
- Cookbooks for learning to use the transformers-native version of ColPali, fine-tuning it, and generating similarity maps can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚

This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) and [@yonigozlan](https://huggingface.co/yonigozlan).

## Usage

This example demonstrates how to use ColPali to embed both queries and images, calculate their similarity scores, and identify the most relevant matches. For a specific query, you can retrieve the top-k most similar images by selecting the ones with the highest similarity scores.

```python
import torch
from PIL import Image

from transformers import ColPaliForRetrieval, ColPaliProcessor

model_name = "vidore/colpali-v1.2-hf"

model = ColPaliForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
).eval()

processor = ColPaliProcessor.from_pretrained(model_name)

# Your inputs (replace dummy images with screenshots of your documents)
images = [
    Image.new("RGB", (32, 32), color="white"),
    Image.new("RGB", (16, 16), color="black"),
]
queries = [
    "What is the organizational structure for our R&D department?",
    "Can you provide a breakdown of last year’s financial performance?",
]

# Process the inputs
batch_images = processor(images=images).to(model.device)
batch_queries = processor(text=queries).to(model.device)

# Forward pass
with torch.no_grad():
    image_embeddings = model(**batch_images).embeddings
    query_embeddings = model(**batch_queries).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)
```
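
For a given query, the top-k most relevant images can then be read off the score matrix with `torch.topk`. This is a small follow-up sketch that assumes `scores` is a `(num_queries, num_images)` tensor (check the `score_retrieval` docstring for the exact return type):

```python
# Retrieve the top-k images per query (illustrative follow-up to the snippet above)
k = min(2, len(images))
topk_scores, topk_indices = scores.topk(k, dim=-1)

for query, values, indices in zip(queries, topk_scores, topk_indices):
    print(f"Query: {query}")
    for score, idx in zip(values.tolist(), indices.tolist()):
        print(f"  image {idx} (score: {score:.2f})")
```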

## ColPaliConfig

[[autodoc]] ColPaliConfig
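
To illustrate the configuration options, a `ColPaliForRetrieval` can be instantiated from a default `ColPaliConfig` (PaliGemma backbone, 128-dimensional multi-vector embeddings); the `embedding_dim=64` override below is an arbitrary example value, not a recommended setting:

```python
from transformers import ColPaliConfig, ColPaliForRetrieval

# Default configuration: PaliGemma VLM backbone, embedding_dim=128
config = ColPaliConfig()
model = ColPaliForRetrieval(config)

# The embedding dimension can be overridden (64 here is purely illustrative)
custom_config = ColPaliConfig(embedding_dim=64)
```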

## ColPaliProcessor

[[autodoc]] ColPaliProcessor

## ColPaliForRetrieval

[[autodoc]] ColPaliForRetrieval
- forward
20 changes: 20 additions & 0 deletions src/transformers/__init__.py
@@ -306,6 +306,10 @@
    ],
    "models.cohere": ["CohereConfig"],
    "models.cohere2": ["Cohere2Config"],
    "models.colpali": [
        "ColPaliConfig",
        "ColPaliProcessor",
    ],
    "models.conditional_detr": ["ConditionalDetrConfig"],
    "models.convbert": [
        "ConvBertConfig",
@@ -1468,6 +1472,7 @@
"MODEL_FOR_OBJECT_DETECTION_MAPPING",
"MODEL_FOR_PRETRAINING_MAPPING",
"MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_RETRIEVAL_MAPPING",
"MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
@@ -1789,6 +1794,12 @@
    )
    _import_structure["models.cohere"].extend(["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"])
    _import_structure["models.cohere2"].extend(["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"])
    _import_structure["models.colpali"].extend(
        [
            "ColPaliForRetrieval",
            "ColPaliPreTrainedModel",
        ]
    )
    _import_structure["models.conditional_detr"].extend(
        [
            "ConditionalDetrForObjectDetection",
@@ -5207,6 +5218,10 @@
    )
    from .models.cohere import CohereConfig
    from .models.cohere2 import Cohere2Config
    from .models.colpali import (
        ColPaliConfig,
        ColPaliProcessor,
    )
    from .models.conditional_detr import (
        ConditionalDetrConfig,
    )
@@ -6413,6 +6428,7 @@
            MODEL_FOR_OBJECT_DETECTION_MAPPING,
            MODEL_FOR_PRETRAINING_MAPPING,
            MODEL_FOR_QUESTION_ANSWERING_MAPPING,
            MODEL_FOR_RETRIEVAL_MAPPING,
            MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
            MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
            MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -6689,6 +6705,10 @@
            Cohere2Model,
            Cohere2PreTrainedModel,
        )
        from .models.colpali import (
            ColPaliForRetrieval,
            ColPaliPreTrainedModel,
        )
        from .models.conditional_detr import (
            ConditionalDetrForObjectDetection,
            ConditionalDetrForSegmentation,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -53,6 +53,7 @@
    codegen,
    cohere,
    cohere2,
    colpali,
    conditional_detr,
    convbert,
    convnext,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
@@ -74,6 +74,7 @@
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_FOR_RETRIEVAL_MAPPING",
"MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
@@ -252,6 +253,7 @@
            MODEL_FOR_OBJECT_DETECTION_MAPPING,
            MODEL_FOR_PRETRAINING_MAPPING,
            MODEL_FOR_QUESTION_ANSWERING_MAPPING,
            MODEL_FOR_RETRIEVAL_MAPPING,
            MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
            MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
            MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -70,6 +70,7 @@
("codegen", "CodeGenConfig"),
("cohere", "CohereConfig"),
("cohere2", "Cohere2Config"),
("colpali", "ColPaliConfig"),
("conditional_detr", "ConditionalDetrConfig"),
("convbert", "ConvBertConfig"),
("convnext", "ConvNextConfig"),
@@ -373,6 +374,7 @@
("codegen", "CodeGen"),
("cohere", "Cohere"),
("cohere2", "Cohere2"),
("colpali", "ColPali"),
("conditional_detr", "Conditional DETR"),
("convbert", "ConvBERT"),
("convnext", "ConvNeXT"),
8 changes: 8 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -306,6 +306,7 @@
("big_bird", "BigBirdForPreTraining"),
("bloom", "BloomForCausalLM"),
("camembert", "CamembertForMaskedLM"),
("colpali", "ColPaliForRetrieval"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
@@ -775,6 +776,12 @@
    ]
)

MODEL_FOR_RETRIEVAL_MAPPING_NAMES = OrderedDict(
    [
        ("colpali", "ColPaliForRetrieval"),
    ]
)

MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
    [
        ("aria", "AriaForConditionalGeneration"),
@@ -1473,6 +1480,7 @@
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
)
MODEL_FOR_RETRIEVAL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_RETRIEVAL_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -58,6 +58,7 @@
("clip", "CLIPProcessor"),
("clipseg", "CLIPSegProcessor"),
("clvp", "ClvpProcessor"),
("colpali", "ColPaliProcessor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),
("git", "GitProcessor"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -148,6 +148,7 @@
("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"cpm",
28 changes: 28 additions & 0 deletions src/transformers/models/colpali/__init__.py
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_colpali import *
    from .modeling_colpali import *
    from .processing_colpali import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
106 changes: 106 additions & 0 deletions src/transformers/models/colpali/configuration_colpali.py
@@ -0,0 +1,106 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ColPali model configuration"""

import logging
from copy import deepcopy

from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.getLogger(__name__)


class ColPaliConfig(PretrainedConfig):
r"""
Configuration class to store the configuration of a [`ColPaliForRetrieval`]. It is used to instantiate an instance
of `ColPaliForRetrieval` according to the specified arguments, defining the model architecture following the methodology
from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
Creating a configuration with the default settings will result in a configuration where the VLM backbone is set to the
default PaliGemma configuration, i.e the one from [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2).
The ColPali config is very similar to [`PaligemmaConfig`], but with an extra attribute defining the embedding dimension.
Note that contrarily to what the class name suggests (actually the name refers to the ColPali **methodology**), you can
use a different VLM backbone model than PaliGemma by passing the corresponding VLM configuration to the class constructor.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vlm_config (`PretrainedConfig`, *optional*):
Configuration of the VLM backbone model.
text_config (`PretrainedConfig`, *optional*):
Configuration of the text backbone model. Overrides the `text_config` attribute of the `vlm_config` if provided.
embedding_dim (`int`, *optional*, defaults to 128):
Dimension of the multi-vector embeddings produced by the model.
Example:
```python
from transformers.models.colpali import ColPaliConfig, ColPaliForRetrieval
config = ColPaliConfig()
model = ColPaliForRetrieval(config)
```
"""

model_type = "colpali"
sub_configs = {"vlm_config": PretrainedConfig, "text_config": AutoConfig}

def __init__(
self,
vlm_config=None,
text_config=None,
embedding_dim: int = 128,
**kwargs,
):
if vlm_config is None:
vlm_config = CONFIG_MAPPING["paligemma"]()
logger.info(
"`vlm_config` is `None`. Initializing `vlm_config` with the `PaliGemmaConfig` with default values."
)
elif isinstance(vlm_config, dict):
vlm_config = deepcopy(vlm_config)
if "model_type" not in vlm_config:
raise KeyError(
"The `model_type` key is missing in the `vlm_config` dictionary. Please provide the model type."
)
elif vlm_config["model_type"] not in CONFIG_MAPPING:
raise ValueError(
f"The model type `{vlm_config['model_type']}` is not supported. Please provide a valid model type."
)
vlm_config = CONFIG_MAPPING[vlm_config["model_type"]](**vlm_config)
elif isinstance(vlm_config, PretrainedConfig):
vlm_config = vlm_config
else:
raise TypeError(
f"Invalid type for `vlm_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(vlm_config)}."
)

self.vlm_config = vlm_config
self.text_config = text_config = text_config if text_config is not None else vlm_config.text_config
if isinstance(self.text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)

self.embedding_dim = embedding_dim

super().__init__(**kwargs)


__all__ = ["ColPaliConfig"]