diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 0d2b752d5ad93f..3521d4ccfed894 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -657,6 +657,8 @@
title: GLPN
- local: model_doc/hiera
title: Hiera
+ - local: model_doc/ijepa
+ title: I-JEPA
- local: model_doc/imagegpt
title: ImageGPT
- local: model_doc/levit
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 8a9ccf45b69c26..3cad4e663f23fd 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -168,6 +168,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ |
| [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ |
| [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ |
+| [I-JEPA](model_doc/ijepa) | ✅ | ❌ | ❌ |
| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ |
| [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ |
| [Idefics3](model_doc/idefics3) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md
new file mode 100644
index 00000000000000..9a0cd368a8188f
--- /dev/null
+++ b/docs/source/en/model_doc/ijepa.md
@@ -0,0 +1,78 @@
+
+
+# I-JEPA
+
+## Overview
+
+The I-JEPA model was proposed in [Image-based Joint-Embedding Predictive Architecture](https://arxiv.org/pdf/2301.08243.pdf) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas.
+I-JEPA is a self-supervised learning method that predicts the representations of one part of an image based on other parts of the same image. This approach focuses on learning semantic features without relying on pre-defined invariances from hand-crafted data transformations, which can bias specific tasks, or on filling in pixel-level details, which often leads to less meaningful representations.
+
+The abstract from the paper is the following:
+
+This paper demonstrates an approach for learning highly semantic image representations without relying on hand-crafted data-augmentations. We introduce the Image- based Joint-Embedding Predictive Architecture (I-JEPA), a non-generative approach for self-supervised learning from images. The idea behind I-JEPA is simple: from a single context block, predict the representations of various target blocks in the same image. A core design choice to guide I-JEPA towards producing semantic representations is the masking strategy; specifically, it is crucial to (a) sample tar- get blocks with sufficiently large scale (semantic), and to (b) use a sufficiently informative (spatially distributed) context block. Empirically, when combined with Vision Transform- ers, we find I-JEPA to be highly scalable. For instance, we train a ViT-Huge/14 on ImageNet using 16 A100 GPUs in under 72 hours to achieve strong downstream performance across a wide range of tasks, from linear classification to object counting and depth prediction.
+
+This model was contributed by [jmtzt](https://huggingface.co/jmtzt).
+The original code can be found [here](https://github.com/facebookresearch/ijepa).
+
+## How to use
+
+Here is how to use this model for image feature extraction:
+
+```python
+import requests
+import torch
+from PIL import Image
+from torch.nn.functional import cosine_similarity
+
+from transformers import AutoModel, AutoProcessor
+
+url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
+url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg"
+image_1 = Image.open(requests.get(url_1, stream=True).raw)
+image_2 = Image.open(requests.get(url_2, stream=True).raw)
+
+model_id = "jmtzt/ijepa_vith14_1k"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModel.from_pretrained(model_id)
+
+@torch.no_grad()
+def infer(image):
+ inputs = processor(image, return_tensors="pt")
+ outputs = model(**inputs)
+ return outputs.last_hidden_state.mean(dim=1)
+
+
+embed_1 = infer(image_1)
+embed_2 = infer(image_2)
+
+similarity = cosine_similarity(embed_1, embed_2)
+print(similarity)
+```
+
+## IJepaConfig
+
+[[autodoc]] IJepaConfig
+
+## IJepaModel
+
+[[autodoc]] IJepaModel
+ - forward
+
+## IJepaForImageClassification
+
+[[autodoc]] IJepaForImageClassification
+ - forward
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 12f492ff29a5ee..ec8dea2735b531 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -235,6 +235,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
+* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
@@ -242,7 +243,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model)
-* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
+* [I-JEPA](https://huggingface.co/docs/transformers/model_doc/ijepa#transformers.IJepaModel)
* [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index e1ca1956807318..625936a45869c8 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -485,6 +485,7 @@
"models.idefics": ["IdeficsConfig"],
"models.idefics2": ["Idefics2Config"],
"models.idefics3": ["Idefics3Config"],
+ "models.ijepa": ["IJepaConfig"],
"models.imagegpt": ["ImageGPTConfig"],
"models.informer": ["InformerConfig"],
"models.instructblip": [
@@ -2462,6 +2463,13 @@
"Idefics3Processor",
]
)
+ _import_structure["models.ijepa"].extend(
+ [
+ "IJepaForImageClassification",
+ "IJepaModel",
+ "IJepaPreTrainedModel",
+ ]
+ )
_import_structure["models.imagegpt"].extend(
[
"ImageGPTForCausalImageModeling",
@@ -5368,6 +5376,7 @@
)
from .models.idefics2 import Idefics2Config
from .models.idefics3 import Idefics3Config
+ from .models.ijepa import IJepaConfig
from .models.imagegpt import ImageGPTConfig
from .models.informer import InformerConfig
from .models.instructblip import (
@@ -7181,6 +7190,11 @@
Idefics3PreTrainedModel,
Idefics3Processor,
)
+ from .models.ijepa import (
+ IJepaForImageClassification,
+ IJepaModel,
+ IJepaPreTrainedModel,
+ )
from .models.imagegpt import (
ImageGPTForCausalImageModeling,
ImageGPTForImageClassification,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 2d2a3b41d4378b..e957d802d80e71 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -117,6 +117,7 @@
idefics,
idefics2,
idefics3,
+ ijepa,
imagegpt,
informer,
instructblip,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 4ab6d392282657..c1f2d689df7095 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -135,6 +135,7 @@
("idefics", "IdeficsConfig"),
("idefics2", "Idefics2Config"),
("idefics3", "Idefics3Config"),
+ ("ijepa", "IJepaConfig"),
("imagegpt", "ImageGPTConfig"),
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
@@ -440,6 +441,7 @@
("idefics", "IDEFICS"),
("idefics2", "Idefics2"),
("idefics3", "Idefics3"),
+ ("ijepa", "I-JEPA"),
("imagegpt", "ImageGPT"),
("informer", "Informer"),
("instructblip", "InstructBLIP"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 11ae15ca461e79..e19c8efd205552 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -90,6 +90,7 @@
("idefics", ("IdeficsImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor",)),
("idefics3", ("Idefics3ImageProcessor",)),
+ ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor",)),
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
@@ -433,7 +434,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
if image_processor_class is None and image_processor_auto_map is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ pretrained_model_name_or_path,
+ trust_remote_code=trust_remote_code,
+ **kwargs,
)
# It could be in `config.image_processor_type``
image_processor_class = getattr(config, "image_processor_type", None)
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 2c519a7dc42ca5..7a7cd9d475884c 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -132,6 +132,7 @@
("idefics", "IdeficsModel"),
("idefics2", "Idefics2Model"),
("idefics3", "Idefics3Model"),
+ ("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jamba", "JambaModel"),
@@ -578,6 +579,7 @@
("focalnet", "FocalNetModel"),
("glpn", "GLPNModel"),
("hiera", "HieraModel"),
+ ("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("mllama", "MllamaVisionModel"),
@@ -655,6 +657,7 @@
("efficientnet", "EfficientNetForImageClassification"),
("focalnet", "FocalNetForImageClassification"),
("hiera", "HieraForImageClassification"),
+ ("ijepa", "IJepaForImageClassification"),
("imagegpt", "ImageGPTForImageClassification"),
(
"levit",
diff --git a/src/transformers/models/ijepa/__init__.py b/src/transformers/models/ijepa/__init__.py
new file mode 100644
index 00000000000000..efc8c90b17628d
--- /dev/null
+++ b/src/transformers/models/ijepa/__init__.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_torch_available,
+)
+
+
+_import_structure = {"configuration_ijepa": ["IJepaConfig"]}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_ijepa"] = [
+ "IJepaForImageClassification",
+ "IJepaModel",
+ "IJepaPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_ijepa import IJepaConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_ijepa import (
+ IJepaForImageClassification,
+ IJepaModel,
+ IJepaPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/ijepa/configuration_ijepa.py b/src/transformers/models/ijepa/configuration_ijepa.py
new file mode 100644
index 00000000000000..26378e6e81d9ce
--- /dev/null
+++ b/src/transformers/models/ijepa/configuration_ijepa.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""I-JEPA model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+
+
+class IJepaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`IJepaModel`]. It is used to instantiate an IJEPA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the I-JEPA
+ [google/ijepa-base-patch16-224](https://huggingface.co/google/ijepa-base-patch16-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+
+ Example:
+
+ ```python
+ >>> from transformers import IJepaConfig, IJepaModel
+
+ >>> # Initializing a IJEPA ijepa-base-patch16-224 style configuration
+ >>> configuration = IJepaConfig()
+
+ >>> # Initializing a model (with random weights) from the ijepa-base-patch16-224 style configuration
+ >>> model = IJepaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "ijepa"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
diff --git a/src/transformers/models/ijepa/convert_ijepa_to_hf.py b/src/transformers/models/ijepa/convert_ijepa_to_hf.py
new file mode 100644
index 00000000000000..5c15a72ff88847
--- /dev/null
+++ b/src/transformers/models/ijepa/convert_ijepa_to_hf.py
@@ -0,0 +1,267 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert IJEPA checkpoints from the original repository.
+
+URL: https://github.com/facebookresearch/ijepa
+"""
+
+import argparse
+import gc
+import re
+from pathlib import Path
+
+import requests
+import torch
+from PIL import Image
+
+from transformers import (
+ IJepaConfig,
+ IJepaModel,
+ ViTImageProcessor,
+)
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+# fmt: off
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+ # Projection layer + position embeddings
+ r"pos_embed": r"embeddings.position_embeddings",
+ r"patch_embed.proj.weight": r"embeddings.patch_embeddings.projection.weight",
+ r"patch_embed.proj.bias": r"embeddings.patch_embeddings.projection.bias",
+
+ # Encoder layers: Layernorms, Attention, Feedforward layers
+ r"blocks.(\d+).norm1.weight": r"encoder.layer.\1.layernorm_before.weight",
+ r"blocks.(\d+).norm1.bias": r"encoder.layer.\1.layernorm_before.bias",
+ r"blocks.(\d+).attn.proj.weight": r"encoder.layer.\1.attention.output.dense.weight",
+ r"blocks.(\d+).attn.proj.bias": r"encoder.layer.\1.attention.output.dense.bias",
+ r"blocks.(\d+).norm2.weight": r"encoder.layer.\1.layernorm_after.weight",
+ r"blocks.(\d+).norm2.bias": r"encoder.layer.\1.layernorm_after.bias",
+ r"blocks.(\d+).mlp.fc1.weight": r"encoder.layer.\1.intermediate.dense.weight",
+ r"blocks.(\d+).mlp.fc1.bias": r"encoder.layer.\1.intermediate.dense.bias",
+ r"blocks.(\d+).mlp.fc2.weight": r"encoder.layer.\1.output.dense.weight",
+ r"blocks.(\d+).mlp.fc2.bias": r"encoder.layer.\1.output.dense.bias",
+
+ # Layernorm + pooler
+ r"norm.weight": r"layernorm.weight",
+ r"norm.bias": r"layernorm.bias",
+}
+# fmt: on
+
+
+def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
+ """
+ Converts old keys to new keys using the mapping and dynamically removes the 'ijepa.' prefix if necessary.
+
+ Args:
+ state_dict_keys (dict): The keys from the state_dict to convert.
+
+ Returns:
+ dict: A mapping from old keys to new keys.
+ """
+ output_dict = {}
+ if state_dict_keys is not None:
+ old_text = "\n".join(state_dict_keys)
+ new_text = old_text
+
+ # Apply regex-based mapping
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
+ if replacement is None:
+ new_text = re.sub(pattern, "", new_text) # Skip the key
+ continue
+ new_text = re.sub(pattern, replacement, new_text)
+
+ output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
+
+ return output_dict
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
+def read_in_q_k_v(state_dict, config):
+ for i in range(config.num_hidden_layers):
+ # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
+ in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
+ in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
+ # next, add query, keys and values (in that order) to the state dict
+ state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :]
+ state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
+ state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
+ config.hidden_size : config.hidden_size * 2, :
+ ]
+ state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
+ config.hidden_size : config.hidden_size * 2
+ ]
+ state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-config.hidden_size :, :]
+ state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
+
+
+def rename_key(dct, old, new):
+ val = dct.pop(old)
+ dct[new] = val
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+def get_ijepa_config(model_name):
+ patch_size = int(model_name.split("_")[1][4:])
+ config = IJepaConfig(patch_size=patch_size)
+ if "vith" in model_name:
+ config.hidden_size = 1280
+ config.num_hidden_layers = 32
+ config.num_attention_heads = 16
+ config.layer_norm_eps = 1e-6
+ config.mlp_ratio = 4
+ config.intermediate_size = 5120
+ if model_name == "ijepa_vith16_1k":
+ config.image_size = 448
+ elif "vitg" in model_name:
+ config.hidden_size = 1408
+ config.num_hidden_layers = 40
+ config.num_attention_heads = 16
+ config.layer_norm_eps = 1e-6
+ config.mlp_ratio = 48 / 11
+ config.intermediate_size = 6144
+ else:
+ raise ValueError("Model not supported, only supports huge and giant models.")
+ return config
+
+
+@torch.no_grad()
+def write_model(model_name, output_dir, safe_serialization, push_to_hub, verify_logits):
+ """
+ Copy/paste/tweak model's weights to our IJEPA structure.
+ """
+
+ # define default IJEPA configuration
+ config = get_ijepa_config(model_name)
+
+ checkpoint_mapping = {
+ "ijepa_vith14_1k": "https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar",
+ "ijepa_vith14_22k": "https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar",
+ "ijepa_vith16_1k": "https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar",
+ "ijepa_vitg16_22k": "https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar",
+ }
+
+ # Load original checkpoint
+ checkpoint_url = checkpoint_mapping[model_name]
+ original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["encoder"]
+ original_state_dict = {k.replace("module.", ""): v for k, v in original_state_dict.items()}
+
+ # Rename keys
+ state_dict = original_state_dict.copy()
+ new_keys = convert_old_keys_to_new_keys(state_dict.keys())
+ for old_key, new_key in new_keys.items():
+ rename_key(state_dict, old_key, new_key)
+ read_in_q_k_v(state_dict, config)
+
+ # load HuggingFace model
+ model = IJepaModel(config, add_pooling_layer=False).eval()
+ model.load_state_dict(state_dict)
+ size = {"height": config.image_size, "width": config.image_size}
+ image_processor = ViTImageProcessor(size=size)
+
+ if verify_logits:
+ # Check outputs on an image, prepared by ViTImageProcessor
+ encoding = image_processor(images=prepare_img(), return_tensors="pt")
+ pixel_values = encoding["pixel_values"]
+ with torch.no_grad():
+ outputs = model(pixel_values)
+
+ expected_slices = {
+ "ijepa_vith14_1k": torch.Tensor(
+ [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]]
+ ),
+ "ijepa_vith14_22k": torch.Tensor(
+ [[0.0358, -0.0045, -0.2154], [0.0418, -0.0246, 0.0108], [0.2529, -0.0345, -0.0246]]
+ ),
+ "ijepa_vith16_1k": torch.Tensor(
+ [[0.5145, -0.1259, 0.0615], [0.1132, 0.0028, -0.0496], [1.1586, -0.0056, -0.0387]]
+ ),
+ "ijepa_vitg16_22k": torch.Tensor(
+ [[0.0512, -0.0510, -0.0649], [0.1972, 0.0380, -0.0790], [0.1667, -0.0834, -0.1240]]
+ ),
+ }
+
+ assert torch.allclose(
+ expected_slices[model_name],
+ outputs.last_hidden_state[0, :3, :3],
+ atol=1e-4,
+ )
+
+ if output_dir:
+ Path(output_dir).mkdir(exist_ok=True)
+ print(f"Saving model {model_name} to {output_dir}")
+ image_processor.save_pretrained(output_dir, safe_serialization=safe_serialization)
+ model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+
+ if push_to_hub:
+ image_processor.push_to_hub(repo_id=f"jmtzt/{model_name}", safe_serialization=safe_serialization)
+ model.push_to_hub(repo_id=f"jmtzt/{model_name}", safe_serialization=safe_serialization)
+
+ if output_dir:
+ del model, state_dict
+ gc.collect()
+ print("Reloading the model to check if it's saved correctly.")
+ IJepaModel.from_pretrained(output_dir, device_map="auto")
+ print("Model reloaded successfully.")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default="ijepa_vith14_1k",
+ type=str,
+ choices=[
+ "ijepa_vith14_1k",
+ "ijepa_vith14_22k",
+ "ijepa_vith16_1k",
+ "ijepa_vitg16_22k",
+ ],
+ help="Name of the model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument(
+ "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether or not to push the model to the 🤗 Hub.",
+ )
+ parser.add_argument(
+ "--verify_logits", action="store_false", help="Whether or not to verify logits after conversion."
+ )
+
+ parser.set_defaults()
+ args = parser.parse_args()
+ write_model(args.model_name, args.output_dir, args.safe_serialization, args.push_to_hub, args.verify_logits)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py
new file mode 100644
index 00000000000000..df254455bad5ab
--- /dev/null
+++ b/src/transformers/models/ijepa/modeling_ijepa.py
@@ -0,0 +1,751 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/ijepa/modular_ijepa.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_ijepa.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+import collections.abc
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ torch_int,
+)
+from .configuration_ijepa import IJepaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CHECKPOINT_FOR_DOC = "facebook/ijepa_vith14_1k"
+
+# General docstring
+_CONFIG_FOR_DOC = "IJepaConfig"
+
+
+class IJepaPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ f" Expected {self.num_channels} but got {num_channels}."
+ )
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+class IJepaEmbeddings(nn.Module):
+ """
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None:
+ super().__init__()
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+ self.patch_embeddings = IJepaPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+ images. This method is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1]
+ num_positions = self.position_embeddings.shape[1]
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ patch_pos_embed = self.position_embeddings
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return patch_pos_embed
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ if bool_masked_pos is not None:
+ seq_length = embeddings.shape[1]
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class IJepaPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = IJepaConfig
+ base_model_prefix = "ijepa"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
+ _supports_sdpa = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, IJepaEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+
+class IJepaSelfAttention(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class IJepaSdpaSelfAttention(IJepaSelfAttention):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__(config)
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ if output_attentions or head_mask is not None:
+ logger.warning_once(
+ "`IJepaSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
+ "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ )
+
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ head_mask,
+ self.attention_probs_dropout_prob if self.training else 0.0,
+ is_causal=False,
+ scale=None,
+ )
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ return context_layer, None
+
+
+class IJepaSelfOutput(nn.Module):
+ """
+ The residual connection is defined in IJepaLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class IJepaAttention(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.attention = IJepaSelfAttention(config)
+ self.output = IJepaSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class IJepaSdpaAttention(IJepaAttention):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__(config)
+ self.attention = IJepaSdpaSelfAttention(config)
+
+
+class IJepaIntermediate(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+class IJepaOutput(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+IJEPA_ATTENTION_CLASSES = {
+ "eager": IJepaAttention,
+ "sdpa": IJepaSdpaAttention,
+}
+
+
+class IJepaLayer(nn.Module):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation](config)
+ self.intermediate = IJepaIntermediate(config)
+ self.output = IJepaOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in IJepa, layernorm is applied before self-attention
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in IJepa, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+class IJepaEncoder(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([IJepaLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer_module.__call__,
+ hidden_states,
+ layer_head_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class IJepaPooler(nn.Module):
+ def __init__(self, config: IJepaConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+IJEPA_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`IJepaImageProcessor.__call__`]
+ for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
+
+
+IJEPA_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`IJepaConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare IJepa Model transformer outputting raw hidden-states without any specific head on top.",
+ IJEPA_START_DOCSTRING,
+)
+class IJepaModel(IJepaPreTrainedModel):
+ def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
+ self.encoder = IJepaEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = IJepaPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> IJepaPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+ if pixel_values.dtype != expected_dtype:
+ pixel_values = pixel_values.to(expected_dtype)
+
+ embedding_output = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
+
+
+@add_start_docstrings(
+ """
+ IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states)
+ e.g. for ImageNet.
+
+
+
+ Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+ position embeddings to the higher resolution.
+
+
+ """,
+ IJEPA_START_DOCSTRING,
+)
+class IJepaForImageClassification(IJepaPreTrainedModel):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.ijepa = IJepaModel(config, add_pooling_layer=False)
+
+ # Classifier head
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=ImageClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.ijepa(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(sequence_output.mean(dim=1))
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["IJepaPreTrainedModel", "IJepaModel", "IJepaForImageClassification"]
diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py
new file mode 100644
index 00000000000000..efbd71d91342fd
--- /dev/null
+++ b/src/transformers/models/ijepa/modular_ijepa.py
@@ -0,0 +1,255 @@
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.models.ijepa.configuration_ijepa import IJepaConfig
+
+from ...modeling_outputs import ImageClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_start_docstrings,
+ torch_int,
+)
+from ..vit.modeling_vit import (
+ ViTEmbeddings,
+ ViTForImageClassification,
+ ViTModel,
+)
+
+
+_CHECKPOINT_FOR_DOC = "facebook/ijepa_vith14_1k"
+
+
+class IJepaEmbeddings(ViTEmbeddings):
+ def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None:
+ super().__init__(config, use_mask_token)
+ # Remove cls_token from IJepaEmbeddings, as it is not used in the model
+ del self.cls_token
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+ images. This method is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1]
+ num_positions = self.position_embeddings.shape[1]
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ patch_pos_embed = self.position_embeddings
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return patch_pos_embed
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ if bool_masked_pos is not None:
+ seq_length = embeddings.shape[1]
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class IJepaPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = IJepaConfig
+ base_model_prefix = "ijepa"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
+ _supports_sdpa = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, IJepaEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+
+_EXPECTED_OUTPUT_SHAPE = [1, 256, 1280]
+
+IJEPA_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`IJepaConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare IJepa Model transformer outputting raw hidden-states without any specific head on top.",
+ IJEPA_START_DOCSTRING,
+)
+class IJepaModel(IJepaPreTrainedModel, ViTModel):
+ def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
+
+
+_IMAGE_CLASS_CHECKPOINT = "jmtzt/ijepa_vith14_1k"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
+
+
+@add_start_docstrings(
+ """
+ IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states)
+ e.g. for ImageNet.
+
+
+
+ Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+ position embeddings to the higher resolution.
+
+
+ """,
+ IJEPA_START_DOCSTRING,
+)
+class IJepaForImageClassification(IJepaPreTrainedModel, ViTForImageClassification):
+ def __init__(self, config: IJepaConfig):
+ super().__init__(config)
+ self.ijepa = IJepaModel(config, add_pooling_layer=False)
+ self.post_init()
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.ijepa(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(sequence_output.mean(dim=1))
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "IJepaPreTrainedModel",
+ "IJepaModel",
+ "IJepaForImageClassification",
+]
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 1238f058783c18..d770b83df935a5 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -4978,6 +4978,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class IJepaForImageClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class IJepaModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class IJepaPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class ImageGPTForCausalImageModeling(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index 3764f1ee4cef76..101b34182a7309 100755
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -140,6 +140,7 @@ def _generate_supported_model_class_names(
"gptj",
"hiera",
"hubert",
+ "ijepa",
"layoutlm",
"llama",
"cohere",
diff --git a/tests/models/ijepa/__init__.py b/tests/models/ijepa/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/ijepa/test_modeling_ijepa.py b/tests/models/ijepa/test_modeling_ijepa.py
new file mode 100644
index 00000000000000..27a79bc6724285
--- /dev/null
+++ b/tests/models/ijepa/test_modeling_ijepa.py
@@ -0,0 +1,341 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch IJEPA model."""
+
+import unittest
+
+from transformers import IJepaConfig
+from transformers.testing_utils import (
+ require_accelerate,
+ require_torch,
+ require_torch_accelerator,
+ require_torch_fp16,
+ require_vision,
+ slow,
+ torch_device,
+)
+from transformers.utils import (
+ cached_property,
+ is_torch_available,
+ is_vision_available,
+)
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import IJepaForImageClassification, IJepaModel
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import ViTImageProcessor
+
+
+class IJepaModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ scope=None,
+ encoder_stride=2,
+ mask_ratio=0.5,
+ attn_implementation="eager",
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.scope = scope
+ self.encoder_stride = encoder_stride
+ self.attn_implementation = attn_implementation
+
+ # in IJEPA, the seq length equals the number of patches (we don't add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = num_patches
+ self.mask_ratio = mask_ratio
+ self.num_masks = int(mask_ratio * self.seq_length)
+ self.mask_length = num_patches
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor(
+ [
+ self.batch_size,
+ self.num_channels,
+ self.image_size,
+ self.image_size,
+ ]
+ )
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return IJepaConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ encoder_stride=self.encoder_stride,
+ attn_implementation=self.attn_implementation,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = IJepaModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.seq_length, self.hidden_size),
+ )
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.type_sequence_label_size
+ model = IJepaForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(
+ result.logits.shape,
+ (self.batch_size, self.type_sequence_label_size),
+ )
+
+ # test greyscale images
+ config.num_channels = 1
+ model = IJepaForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape,
+ (self.batch_size, self.type_sequence_label_size),
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ pixel_values,
+ labels,
+ ) = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class IJepaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as IJEPA does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (
+ (
+ IJepaModel,
+ IJepaForImageClassification,
+ )
+ if is_torch_available()
+ else ()
+ )
+ pipeline_model_mapping = (
+ {"image-feature-extraction": IJepaModel, "image-classification": IJepaForImageClassification}
+ if is_torch_available()
+ else {}
+ )
+ fx_compatible = True
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = IJepaModelTester(self)
+ self.config_tester = ConfigTester(
+ self,
+ config_class=IJepaConfig,
+ has_text_modality=False,
+ hidden_size=37,
+ )
+
+ @unittest.skip(
+ "Since `torch==2.3+cu121`, although this test passes, many subsequent tests have `CUDA error: misaligned address`."
+ "If `nvidia-xxx-cu118` are also installed, no failure (even with `torch==2.3+cu121`)."
+ )
+ def test_multi_gpu_data_parallel_forward(self):
+ super().test_multi_gpu_data_parallel_forward()
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="IJEPA does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_get_set_embeddings(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ model_name = "jmtzt/ijepa_vith14_1k"
+ model = IJepaModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class IJepaModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_image_processor(self):
+ return ViTImageProcessor.from_pretrained("jmtzt/ijepa_vith14_1k") if is_vision_available() else None
+
+ @slow
+ def test_inference_no_head(self):
+ model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device)
+
+ image_processor = self.default_image_processor
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the last hidden state
+ expected_shape = torch.Size((1, 256, 1280))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+ expected_slice = torch.Tensor(
+ [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
+
+ @slow
+ @require_accelerate
+ @require_torch_accelerator
+ @require_torch_fp16
+ def test_inference_fp16(self):
+ r"""
+ A small test to make sure that inference work in half precision without any problem.
+ """
+ model = IJepaModel.from_pretrained(
+ "jmtzt/ijepa_vith14_1k",
+ torch_dtype=torch.float16,
+ device_map="auto",
+ )
+ image_processor = self.default_image_processor
+
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt")
+ pixel_values = inputs.pixel_values.to(torch_device)
+
+ # forward pass to make sure inference works in fp16
+ with torch.no_grad():
+ _ = model(pixel_values)
+
+ @slow
+ def test_inference_interpolate_pos_encoding(self):
+ # I-JEPA, similar to ViT models have an `interpolate_pos_encoding` argument in their forward method,
+ # allowing to interpolate the pre-trained position embeddings in order to use
+ # the model on higher resolutions. The DINO model by Facebook AI leverages this
+ # to visualize self-attention on higher resolution images.
+ model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device)
+
+ image_processor = self.default_image_processor
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt")
+ pixel_values = inputs.pixel_values.to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values, interpolate_pos_encoding=True)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 256, 1280))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index 0be960f4a33e6d..a2ea05edce8063 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -331,6 +331,7 @@
"IBertModel",
"IdeficsConfig",
"IdeficsProcessor",
+ "IJepaModel",
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageGPTConfig",