diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 967049d89cbe12..130775f6420f9e 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -122,6 +122,7 @@ Flax), PyTorch, and/or TensorFlow.
| [DeiT](model_doc/deit) | ✅ | ✅ | ❌ |
| [DePlot](model_doc/deplot) | ✅ | ❌ | ❌ |
| [Depth Anything](model_doc/depth_anything) | ✅ | ❌ | ❌ |
+| [DepthPro](model_doc/depth_pro) | ✅ | ❌ | ❌ |
| [DETA](model_doc/deta) | ✅ | ❌ | ❌ |
| [DETR](model_doc/detr) | ✅ | ❌ | ❌ |
| [DialoGPT](model_doc/dialogpt) | ✅ | ✅ | ✅ |
diff --git a/docs/source/en/model_doc/depth_pro.md b/docs/source/en/model_doc/depth_pro.md
new file mode 100644
index 00000000000000..9019547434af84
--- /dev/null
+++ b/docs/source/en/model_doc/depth_pro.md
@@ -0,0 +1,123 @@
+
+
+# DepthPro
+
+## Overview
+
+The DepthPro model was proposed in [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://arxiv.org/abs/2410.02073) by Aleksei Bochkovskii, Amaël Delaunoy, Hugo Germain, Marcel Santos, Yichao Zhou, Stephan R. Richter, Vladlen Koltun.
+
+DepthPro leverages a multi-scale [Vision Transformer (ViT)](vit) optimized for dense predictions. The input image is downsampled at several scales. At each scale, it is split into patches, which are processed by a ViT-based [Dinov2](dinov2) patch encoder with weights shared across scales. The patches are merged into feature maps, upsampled, and fused via a [DPT](dpt)-like decoder.
+
+The abstract from the paper is the following:
+
+*We present a foundation model for zero-shot metric monocular depth estimation. Our model, Depth Pro, synthesizes high-resolution depth maps with unparalleled sharpness and high-frequency details. The predictions are metric, with absolute scale, without relying on the availability of metadata such as camera intrinsics. And the model is fast, producing a 2.25-megapixel depth map in 0.3 seconds on a standard GPU. These characteristics are enabled by a number of technical contributions, including an efficient multi-scale vision transformer for dense prediction, a training protocol that combines real and synthetic datasets to achieve high metric accuracy alongside fine boundary tracing, dedicated evaluation metrics for boundary accuracy in estimated depth maps, and state-of-the-art focal length estimation from a single image. Extensive experiments analyze specific design choices and demonstrate that Depth Pro outperforms prior work along multiple dimensions.*
+
+
+
+ DepthPro architecture. Taken from the original paper.
+
+This model was contributed by [geetu040](https://github.com/geetu040). The original code can be found [here](https://github.com/apple/ml-depth-pro).
+
+
+
+## Usage tips
+
+```python
+from transformers import DepthProConfig, DepthProForDepthEstimation
+
+config = DepthProConfig()
+model = DepthProForDepthEstimation(config=config)
+```
+
+- By default, the model takes an input image of size `1536`. This can be changed via the config, and the model is compatible with images of different widths and heights.
+- The input image is scaled with the ratios specified in `scaled_images_ratios`; each scaled image is then split into patches of size `patch_size` with the overlap ratios given in `scaled_images_overlap_ratios`.
+- These patches go through a `DinoV2 (ViT)`-based encoder and are reassembled via a `DPT`-based decoder.
+- `DepthProForDepthEstimation` can also predict the field of view (`FOV`) if `use_fov_model` is set to `True` in the config.
+- `DepthProImageProcessor` can be used for preprocessing the inputs and postprocessing the outputs. `DepthProImageProcessor.post_process_depth_estimation` interpolates the `predicted_depth` back to the input image size. A minimal end-to-end example is sketched after this list.
+- To generate a `predicted_depth` of the same size as the input image, make sure the config is created such that
+```
+image_size / 2**(n_fusion_blocks+1) == patch_size / patch_embeddings_size
+
+where
+n_fusion_blocks = len(intermediate_hook_ids) + len(scaled_images_ratios)
+```
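+
+With the default configuration (image size `1536`, `patch_size=384`, `patch_embeddings_size=16`, `len(intermediate_hook_ids)=2`, `len(scaled_images_ratios)=3`, hence `n_fusion_blocks=5`), both sides of this equation evaluate to `24`, so the predicted depth matches the input resolution.
+
+Below is a minimal inference sketch that ties these pieces together. It assumes the `geetu040/DepthPro` checkpoint used in the SDPA example below and an arbitrary example image of your choice:
+
+```python
+import torch
+from PIL import Image
+
+from transformers import DepthProForDepthEstimation, DepthProImageProcessor
+
+image = Image.open("your_image.jpg")  # any RGB image
+
+image_processor = DepthProImageProcessor.from_pretrained("geetu040/DepthPro")
+model = DepthProForDepthEstimation.from_pretrained("geetu040/DepthPro")
+
+# preprocess: rescale, normalize, then resize to 1536x1536
+inputs = image_processor(images=image, return_tensors="pt")
+
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# postprocess: rescale by the estimated field of view, interpolate back to the
+# original resolution and invert to obtain metric depth
+post_processed = image_processor.post_process_depth_estimation(
+    outputs, target_sizes=[(image.height, image.width)]
+)
+depth = post_processed[0]["predicted_depth"]  # depth map, same size as the input image
+fov = post_processed[0]["fov"]  # estimated field of view (requires use_fov_model=True)
+```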
+
+
+### Using Scaled Dot Product Attention (SDPA)
+
+PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
+encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
+[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
+page for more information.
+
+SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
+`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
+
+```py
+import torch
+from transformers import DepthProForDepthEstimation
+
+model = DepthProForDepthEstimation.from_pretrained("geetu040/DepthPro", attn_implementation="sdpa", torch_dtype=torch.float16)
+...
+```
+
+For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
+
+On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-base-patch16-224` model, we saw the following speedups during inference.
+
+| Batch size | Average inference time (ms), eager mode | Average inference time (ms), SDPA | Speedup, SDPA vs eager (x) |
+|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
+| 1 | 7 | 6 | 1.17 |
+| 2 | 8 | 6 | 1.33 |
+| 4 | 8 | 6 | 1.33 |
+| 8 | 8 | 6 | 1.33 |
+
+## Resources
+
+- Research Paper: [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://arxiv.org/pdf/2410.02073)
+
+- Official Implementation: [apple/ml-depth-pro](https://github.com/apple/ml-depth-pro)
+
+
+
+If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
+
+## DepthProConfig
+
+[[autodoc]] DepthProConfig
+
+## DepthProImageProcessor
+
+[[autodoc]] DepthProImageProcessor
+ - preprocess
+ - post_process_depth_estimation
+
+## DepthProImageProcessorFast
+
+[[autodoc]] DepthProImageProcessorFast
+ - preprocess
+ - post_process_depth_estimation
+
+## DepthProModel
+
+[[autodoc]] DepthProModel
+ - forward
+
+## DepthProForDepthEstimation
+
+[[autodoc]] DepthProForDepthEstimation
+ - forward
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 930f41b6fefba7..9ead5ac276693f 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -237,6 +237,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [data2vec_vision](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecVisionModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
+* [DepthPro](https://huggingface.co/docs/transformers/model_doc/depth_pro#transformers.DepthProModel)
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index ef140cc6d3a843..ebcfe53848502e 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -400,6 +400,7 @@
"models.deprecated.vit_hybrid": ["ViTHybridConfig"],
"models.deprecated.xlm_prophetnet": ["XLMProphetNetConfig"],
"models.depth_anything": ["DepthAnythingConfig"],
+ "models.depth_pro": ["DepthProConfig"],
"models.detr": ["DetrConfig"],
"models.dialogpt": [],
"models.dinat": ["DinatConfig"],
@@ -1212,6 +1213,7 @@
_import_structure["models.deprecated.efficientformer"].append("EfficientFormerImageProcessor")
_import_structure["models.deprecated.tvlt"].append("TvltImageProcessor")
_import_structure["models.deprecated.vit_hybrid"].extend(["ViTHybridImageProcessor"])
+ _import_structure["models.depth_pro"].extend(["DepthProImageProcessor", "DepthProImageProcessorFast"])
_import_structure["models.detr"].extend(["DetrFeatureExtractor", "DetrImageProcessor"])
_import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
_import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
@@ -2136,6 +2138,13 @@
"DepthAnythingPreTrainedModel",
]
)
+ _import_structure["models.depth_pro"].extend(
+ [
+ "DepthProForDepthEstimation",
+ "DepthProModel",
+ "DepthProPreTrainedModel",
+ ]
+ )
_import_structure["models.detr"].extend(
[
"DetrForObjectDetection",
@@ -5359,6 +5368,7 @@
XLMProphetNetConfig,
)
from .models.depth_anything import DepthAnythingConfig
+ from .models.depth_pro import DepthProConfig
from .models.detr import DetrConfig
from .models.dinat import DinatConfig
from .models.dinov2 import Dinov2Config
@@ -6207,6 +6217,7 @@
from .models.deprecated.efficientformer import EfficientFormerImageProcessor
from .models.deprecated.tvlt import TvltImageProcessor
from .models.deprecated.vit_hybrid import ViTHybridImageProcessor
+ from .models.depth_pro import DepthProImageProcessor, DepthProImageProcessorFast
from .models.detr import DetrFeatureExtractor, DetrImageProcessor
from .models.donut import DonutFeatureExtractor, DonutImageProcessor
from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
@@ -7001,6 +7012,11 @@
DepthAnythingForDepthEstimation,
DepthAnythingPreTrainedModel,
)
+ from .models.depth_pro import (
+ DepthProForDepthEstimation,
+ DepthProModel,
+ DepthProPreTrainedModel,
+ )
from .models.detr import (
DetrForObjectDetection,
DetrForSegmentation,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 7fcaddde704cf7..9030f178ab0cfa 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -73,6 +73,7 @@
deit,
deprecated,
depth_anything,
+ depth_pro,
detr,
dialogpt,
dinat,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 69ce8efa10c76c..b4b920b6d5886f 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -90,6 +90,7 @@
("deformable_detr", "DeformableDetrConfig"),
("deit", "DeiTConfig"),
("depth_anything", "DepthAnythingConfig"),
+ ("depth_pro", "DepthProConfig"),
("deta", "DetaConfig"),
("detr", "DetrConfig"),
("dinat", "DinatConfig"),
@@ -399,6 +400,7 @@
("deplot", "DePlot"),
("depth_anything", "Depth Anything"),
("depth_anything_v2", "Depth Anything V2"),
+ ("depth_pro", "DepthPro"),
("deta", "DETA"),
("detr", "DETR"),
("dialogpt", "DialoGPT"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index db25591eaa3544..41dbb4ab9e5be5 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -74,6 +74,7 @@
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
("deit", ("DeiTImageProcessor",)),
("depth_anything", ("DPTImageProcessor",)),
+ ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
("deta", ("DetaImageProcessor",)),
("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index e8a2dece432476..dcf3ee974d2bf1 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -88,6 +88,7 @@
("decision_transformer", "DecisionTransformerModel"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
+ ("depth_pro", "DepthProModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
("dinat", "DinatModel"),
@@ -580,6 +581,7 @@
("data2vec-vision", "Data2VecVisionModel"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
+ ("depth_pro", "DepthProModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
("dinat", "DinatModel"),
@@ -891,6 +893,7 @@
[
# Model for depth estimation mapping
("depth_anything", "DepthAnythingForDepthEstimation"),
+ ("depth_pro", "DepthProForDepthEstimation"),
("dpt", "DPTForDepthEstimation"),
("glpn", "GLPNForDepthEstimation"),
("zoedepth", "ZoeDepthForDepthEstimation"),
diff --git a/src/transformers/models/depth_pro/__init__.py b/src/transformers/models/depth_pro/__init__.py
new file mode 100644
index 00000000000000..6fa380d6420834
--- /dev/null
+++ b/src/transformers/models/depth_pro/__init__.py
@@ -0,0 +1,72 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable
+
+
+_import_structure = {"configuration_depth_pro": ["DepthProConfig"]}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["image_processing_depth_pro"] = ["DepthProImageProcessor"]
+ _import_structure["image_processing_depth_pro_fast"] = ["DepthProImageProcessorFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_depth_pro"] = [
+ "DepthProForDepthEstimation",
+ "DepthProModel",
+ "DepthProPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_depth_pro import DepthProConfig
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .image_processing_depth_pro import DepthProImageProcessor
+ from .image_processing_depth_pro_fast import DepthProImageProcessorFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_depth_pro import (
+ DepthProForDepthEstimation,
+ DepthProModel,
+ DepthProPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/depth_pro/configuration_depth_pro.py b/src/transformers/models/depth_pro/configuration_depth_pro.py
new file mode 100644
index 00000000000000..206c01eff191bd
--- /dev/null
+++ b/src/transformers/models/depth_pro/configuration_depth_pro.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DepthPro model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DepthProConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DepthProModel`]. It is used to instantiate a
+ DepthPro model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DepthPro
+ [apple/DepthPro](https://huggingface.co/apple/DepthPro) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the encoder layers and the pooler layer.
+ fusion_hidden_size (`int`, *optional*, defaults to 256):
+ The number of channels before fusion.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`int`, *optional*, defaults to 4):
+ Ratio of the hidden size of the MLPs relative to the `hidden_size`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ patch_size (`int`, *optional*, defaults to 384):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ patch_embeddings_size (`int`, *optional*, defaults to 16):
+ kernel_size and stride for convolution in PatchEmbeddings.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ layerscale_value (`float`, *optional*, defaults to 1.0):
+ Initial value to use for layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+ Whether to use the SwiGLU feedforward neural network.
+ intermediate_hook_ids (`List[int]`, *optional*, defaults to `[11, 5]`):
+ Indices of the intermediate hidden states from the patch encoder to use for fusion.
+ intermediate_feature_dims (`List[int]`, *optional*, defaults to `[256, 256]`):
+ Hidden state dimensions during upsampling for each intermediate hidden state in `intermediate_hook_ids`.
+ scaled_images_ratios (`List[float]`, *optional*, defaults to `[0.25, 0.5, 1]`):
+ Ratios of scaled images to be used by the patch encoder.
+ scaled_images_overlap_ratios (`List[float]`, *optional*, defaults to `[0.0, 0.5, 0.25]`):
+ Overlap ratios between patches for each scaled image in `scaled_images_ratios`.
+ scaled_images_feature_dims (`List[int]`, *optional*, defaults to `[1024, 1024, 512]`):
+ Hidden state dimensions during upsampling for each scaled image in `scaled_images_ratios`.
+ use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
+ Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
+ use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the pre-activate residual units of the fusion blocks.
+ use_fov_model (`bool`, *optional*, defaults to `True`):
+ Whether to use `DepthProFOVModel` to generate the field of view.
+ num_fov_head_layers (`int`, *optional*, defaults to 2):
+ Number of convolution layers in the head of `DepthProFOVModel`.
+
+ Example:
+
+ ```python
+ >>> from transformers import DepthProConfig, DepthProModel
+
+ >>> # Initializing a DepthPro apple/DepthPro style configuration
+ >>> configuration = DepthProConfig()
+
+ >>> # Initializing a model (with random weights) from the apple/DepthPro style configuration
+ >>> model = DepthProModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "depth_pro"
+
+ def __init__(
+ self,
+ hidden_size=1024,
+ fusion_hidden_size=256,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ mlp_ratio=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ patch_size=384,
+ num_channels=3,
+ patch_embeddings_size=16,
+ qkv_bias=True,
+ layerscale_value=1.0,
+ drop_path_rate=0.0,
+ use_swiglu_ffn=False,
+ intermediate_hook_ids=[11, 5],
+ intermediate_feature_dims=[256, 256],
+ scaled_images_ratios=[0.25, 0.5, 1],
+ scaled_images_overlap_ratios=[0.0, 0.5, 0.25],
+ scaled_images_feature_dims=[1024, 1024, 512],
+ use_batch_norm_in_fusion_residual=False,
+ use_bias_in_fusion_residual=True,
+ use_fov_model=True,
+ num_fov_head_layers=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ # scaled_images_ratios is sorted
+ if scaled_images_ratios != sorted(scaled_images_ratios):
+ raise ValueError(
+ f"Values in scaled_images_ratios={scaled_images_ratios} " "should be sorted from low to high"
+ )
+
+        # patch_size should be divisible by patch_embeddings_size,
+        # otherwise an exception is raised in DepthProViTPatchEmbeddings
+ if patch_size % patch_embeddings_size != 0:
+ raise ValueError(
+ f"patch_size={patch_size} should be divisible " f"by patch_embeddings_size={patch_embeddings_size}."
+ )
+
+ # scaled_images_ratios, scaled_images_overlap_ratios, scaled_images_feature_dims should be consistent
+ if not (len(scaled_images_ratios) == len(scaled_images_overlap_ratios) == len(scaled_images_feature_dims)):
+ raise ValueError(
+ f"len(scaled_images_ratios)={len(scaled_images_ratios)} and "
+ f"len(scaled_images_overlap_ratios)={len(scaled_images_overlap_ratios)} and "
+ f"len(scaled_images_feature_dims)={len(scaled_images_feature_dims)}, "
+ f"should match in config."
+ )
+
+ # intermediate_hook_ids, intermediate_feature_dims should be consistent
+ if not (len(intermediate_hook_ids) == len(intermediate_feature_dims)):
+ raise ValueError(
+ f"len(intermediate_hook_ids)={len(intermediate_hook_ids)} and "
+ f"len(intermediate_feature_dims)={len(intermediate_feature_dims)}, "
+ f"should match in config."
+ )
+
+ # fusion_hidden_size should be consistent with num_fov_head_layers
+ if fusion_hidden_size // 2**num_fov_head_layers == 0:
+ raise ValueError(
+ f"fusion_hidden_size={fusion_hidden_size} should be consistent with num_fov_head_layers={num_fov_head_layers} "
+ "i.e fusion_hidden_size // 2**num_fov_head_layers > 0"
+ )
+
+ self.hidden_size = hidden_size
+ self.fusion_hidden_size = fusion_hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.patch_embeddings_size = patch_embeddings_size
+ self.qkv_bias = qkv_bias
+ self.layerscale_value = layerscale_value
+ self.drop_path_rate = drop_path_rate
+ self.use_swiglu_ffn = use_swiglu_ffn
+ self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
+ self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
+ self.use_fov_model = use_fov_model
+ self.num_fov_head_layers = num_fov_head_layers
+ self.intermediate_hook_ids = intermediate_hook_ids
+ self.intermediate_feature_dims = intermediate_feature_dims
+ self.scaled_images_ratios = scaled_images_ratios
+ self.scaled_images_overlap_ratios = scaled_images_overlap_ratios
+ self.scaled_images_feature_dims = scaled_images_feature_dims
+
+
+__all__ = ["DepthProConfig"]
diff --git a/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py b/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py
new file mode 100644
index 00000000000000..cca89f6a8b8cac
--- /dev/null
+++ b/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py
@@ -0,0 +1,281 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import os
+
+import regex as re
+import torch
+from huggingface_hub import hf_hub_download
+
+from transformers import (
+ DepthProConfig,
+ DepthProForDepthEstimation,
+ DepthProImageProcessorFast,
+)
+from transformers.image_utils import PILImageResampling
+
+
+# fmt: off
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+
+ # encoder and head
+ r"encoder.(patch|image)_encoder.cls_token": r"depth_pro.encoder.\1_encoder.embeddings.cls_token",
+ r"encoder.(patch|image)_encoder.pos_embed": r"depth_pro.encoder.\1_encoder.embeddings.position_embeddings",
+ r"encoder.(patch|image)_encoder.patch_embed.proj.(weight|bias)": r"depth_pro.encoder.\1_encoder.embeddings.patch_embeddings.projection.\2",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).norm(\d+).(weight|bias)": r"depth_pro.encoder.\1_encoder.encoder.layer.\2.norm\3.\4",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).attn.qkv.(weight|bias)": r"depth_pro.encoder.\1_encoder.encoder.layer.\2.attention.attention.(query|key|value).\3",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).attn.proj.(weight|bias)": r"depth_pro.encoder.\1_encoder.encoder.layer.\2.attention.output.dense.\3",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).ls(\d+).gamma": r"depth_pro.encoder.\1_encoder.encoder.layer.\2.layer_scale\3.lambda1",
+ r"encoder.(patch|image)_encoder.blocks.(\d+).mlp.fc(\d+).(weight|bias)": r"depth_pro.encoder.\1_encoder.encoder.layer.\2.mlp.fc\3.\4",
+ r"encoder.(patch|image)_encoder.norm.(weight|bias)": r"depth_pro.encoder.\1_encoder.layernorm.\2",
+ r"encoder.fuse_lowres.(weight|bias)": r"depth_pro.encoder.fuse_image_with_low_res.\1",
+ r"head.(\d+).(weight|bias)": r"head.head.\1.\2",
+
+ # fov
+ r"fov.encoder.0.cls_token": r"fov_model.encoder.embeddings.cls_token",
+ r"fov.encoder.0.pos_embed": r"fov_model.encoder.embeddings.position_embeddings",
+ r"fov.encoder.0.patch_embed.proj.(weight|bias)": r"fov_model.encoder.embeddings.patch_embeddings.projection.\1",
+ r"fov.encoder.0.blocks.(\d+).norm(\d+).(weight|bias)": r"fov_model.encoder.encoder.layer.\1.norm\2.\3",
+ r"fov.encoder.0.blocks.(\d+).attn.qkv.(weight|bias)": r"fov_model.encoder.encoder.layer.\1.attention.attention.(query|key|value).\2",
+ r"fov.encoder.0.blocks.(\d+).attn.proj.(weight|bias)": r"fov_model.encoder.encoder.layer.\1.attention.output.dense.\2",
+ r"fov.encoder.0.blocks.(\d+).ls(\d+).gamma": r"fov_model.encoder.encoder.layer.\1.layer_scale\2.lambda1",
+ r"fov.encoder.0.blocks.(\d+).mlp.fc(\d+).(weight|bias)": r"fov_model.encoder.encoder.layer.\1.mlp.fc\2.\3",
+ r"fov.encoder.0.norm.(weight|bias)": r"fov_model.encoder.layernorm.\1",
+ r"fov.downsample.(\d+).(weight|bias)": r"fov_model.global_neck.\1.\2",
+ r"fov.encoder.1.(weight|bias)": r"fov_model.encoder_neck.\1",
+ r"fov.head.head.(\d+).(weight|bias)": r"fov_model.head.\1.\2",
+
+ # upsamples (hard coded; regex is not very feasible here)
+ "encoder.upsample_latent0.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.5.0.weight",
+ "encoder.upsample_latent0.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.5.1.weight",
+ "encoder.upsample_latent0.2.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.5.2.weight",
+ "encoder.upsample_latent0.3.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.5.3.weight",
+ "encoder.upsample_latent1.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.4.0.weight",
+ "encoder.upsample_latent1.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.4.1.weight",
+ "encoder.upsample_latent1.2.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.4.2.weight",
+ "encoder.upsample0.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.3.0.weight",
+ "encoder.upsample0.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.3.1.weight",
+ "encoder.upsample1.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.2.0.weight",
+ "encoder.upsample1.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.2.1.weight",
+ "encoder.upsample2.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.1.0.weight",
+ "encoder.upsample2.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.1.1.weight",
+ "encoder.upsample_lowres.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.0.0.weight",
+ "encoder.upsample_lowres.bias": "depth_pro.encoder.feature_upsample.upsample_blocks.0.0.bias",
+
+ # projections between encoder and fusion
+ r"decoder.convs.(\d+).weight": lambda match: (
+ f"depth_pro.encoder.feature_projection.projections.{4-int(match.group(1))}.weight"
+ ),
+
+ # fusion stage
+ r"decoder.fusions.(\d+).resnet(\d+).residual.(\d+).(weight|bias)": lambda match: (
+ f"fusion_stage.layers.{4-int(match.group(1))}.residual_layer{match.group(2)}.convolution{(int(match.group(3))+1)//2}.{match.group(4)}"
+ ),
+ r"decoder.fusions.(\d+).out_conv.(weight|bias)": lambda match: (
+ f"fusion_stage.layers.{4-int(match.group(1))}.projection.{match.group(2)}"
+ ),
+ r"decoder.fusions.(\d+).deconv.(weight|bias)": lambda match: (
+ f"fusion_stage.layers.{4-int(match.group(1))}.deconv.{match.group(2)}"
+ ),
+}
+# fmt: on
+
+
+def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
+ output_dict = {}
+ if state_dict_keys is not None:
+ old_text = "\n".join(state_dict_keys)
+ new_text = old_text
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
+ if replacement is None:
+ new_text = re.sub(pattern, "", new_text) # an empty line
+ continue
+ new_text = re.sub(pattern, replacement, new_text)
+ output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
+ return output_dict
+
+
+def get_qkv_state_dict(key, parameter):
+ """
+    Splits a fused qkv parameter into separate query, key and value tensors.
+
+    A converted key of the form
+        xxxx.(query|key|value).xxxx    with parameter of shape (m, n)
+    is expanded into
+        xxxx.query.xxxx    with shape (m//3, n)
+        xxxx.key.xxxx      with shape (m//3, n)
+        xxxx.value.xxxx    with shape (m//3, n)
+ """
+ qkv_state_dict = {}
+ placeholder = re.search(r"(\(.*?\))", key).group(1) # finds "(query|key|value)"
+ replacements_keys = placeholder[1:-1].split("|") # creates ['query', 'key', 'value']
+ replacements_vals = torch.split(
+ parameter, split_size_or_sections=parameter.size(0) // len(replacements_keys), dim=0
+ )
+ for replacement_key, replacement_val in zip(replacements_keys, replacements_vals):
+ qkv_state_dict[key.replace(placeholder, replacement_key)] = replacement_val
+ return qkv_state_dict
+
+
+def write_model(
+ hf_repo_id: str,
+ output_dir: str,
+ safe_serialization: bool = True,
+):
+ os.makedirs(output_dir, exist_ok=True)
+
+ # ------------------------------------------------------------
+ # Create and save config
+ # ------------------------------------------------------------
+
+ # create config
+ config = DepthProConfig(
+        # this config is the same as the default config and is used for the pre-trained weights
+ hidden_size=1024,
+ fusion_hidden_size=256,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ mlp_ratio=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ patch_size=384,
+ num_channels=3,
+ patch_embeddings_size=16,
+ qkv_bias=True,
+ layerscale_value=1.0,
+ drop_path_rate=0.0,
+ use_swiglu_ffn=False,
+ apply_layernorm=True,
+ reshape_hidden_states=True,
+ intermediate_hook_ids=[11, 5],
+ intermediate_feature_dims=[256, 256],
+ scaled_images_ratios=[0.25, 0.5, 1],
+ scaled_images_overlap_ratios=[0.0, 0.5, 0.25],
+ scaled_images_feature_dims=[1024, 1024, 512],
+ use_batch_norm_in_fusion_residual=False,
+ use_bias_in_fusion_residual=True,
+ use_fov_model=True,
+ num_fov_head_layers=2,
+ )
+
+ # save config
+ config.save_pretrained(output_dir)
+ print("Model config saved successfully...")
+
+ # ------------------------------------------------------------
+ # Convert weights
+ # ------------------------------------------------------------
+
+ # download and load state_dict from hf repo
+ file_path = hf_hub_download(hf_repo_id, "depth_pro.pt")
+ # file_path = "/home/geetu/work/hf/depth_pro/depth_pro.pt" # when you already have the files locally
+ loaded = torch.load(file_path, weights_only=True)
+
+ print("Converting model...")
+ all_keys = list(loaded.keys())
+ new_keys = convert_old_keys_to_new_keys(all_keys)
+
+ state_dict = {}
+ for key in all_keys:
+ new_key = new_keys[key]
+ current_parameter = loaded.pop(key)
+
+ if "qkv" in key:
+ qkv_state_dict = get_qkv_state_dict(new_key, current_parameter)
+ state_dict.update(qkv_state_dict)
+ else:
+ state_dict[new_key] = current_parameter
+
+ print("Loading the checkpoint in a DepthPro model.")
+ model = DepthProForDepthEstimation(config)
+ model.load_state_dict(state_dict, strict=True, assign=True)
+ print("Checkpoint loaded successfully.")
+
+ print("Saving the model.")
+ model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+ del state_dict, model
+
+ # Safety check: reload the converted model
+ gc.collect()
+ print("Reloading the model to check if it's saved correctly.")
+ model = DepthProForDepthEstimation.from_pretrained(output_dir, device_map="auto")
+ print("Model reloaded successfully.")
+ return model
+
+
+def write_image_processor(output_dir: str):
+ image_processor = DepthProImageProcessorFast(
+ do_resize=True,
+ size={"height": 1536, "width": 1536},
+ resample=PILImageResampling.BILINEAR,
+ antialias=False,
+ do_rescale=True,
+ rescale_factor=1 / 255,
+ do_normalize=True,
+ image_mean=0.5,
+ image_std=0.5,
+ )
+ image_processor.save_pretrained(output_dir)
+ return image_processor
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--hf_repo_id",
+ default="apple/DepthPro",
+ help="Location of official weights from apple on HF",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default="apple_DepthPro",
+ help="Location to write the converted model and processor",
+ )
+ parser.add_argument(
+ "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action=argparse.BooleanOptionalAction,
+ help="Whether or not to push the converted model to the huggingface hub.",
+ )
+ parser.add_argument(
+ "--hub_repo_id",
+ default="geetu040/DepthPro",
+ help="Huggingface hub repo to write the converted model and processor",
+ )
+ args = parser.parse_args()
+
+ model = write_model(
+ hf_repo_id=args.hf_repo_id,
+ output_dir=args.output_dir,
+ safe_serialization=args.safe_serialization,
+ )
+
+ image_processor = write_image_processor(
+ output_dir=args.output_dir,
+ )
+
+ if args.push_to_hub:
+ print("Pushing to hub...")
+ model.push_to_hub(args.hub_repo_id)
+ image_processor.push_to_hub(args.hub_repo_id)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/depth_pro/image_processing_depth_pro.py b/src/transformers/models/depth_pro/image_processing_depth_pro.py
new file mode 100644
index 00000000000000..76a12577dd6330
--- /dev/null
+++ b/src/transformers/models/depth_pro/image_processing_depth_pro.py
@@ -0,0 +1,406 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DepthPro."""
+
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+
+if TYPE_CHECKING:
+ from ...modeling_outputs import DepthProDepthEstimatorOutput
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_torch_available,
+ make_list_of_images,
+ pil_torch_interpolation_mapping,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ logging,
+ requires_backends,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class DepthProImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a DepthPro image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 1536, "width": 1536}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ antialias (`bool`, *optional*, defaults to `False`):
+ Whether to apply an anti-aliasing filter when resizing the image. It only affects tensors with
+ bilinear or bicubic modes and it is ignored otherwise.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ antialias: bool = False,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 1536, "width": 1536}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.do_rescale = do_rescale
+ self.do_normalize = do_normalize
+ self.size = size
+ self.resample = resample
+ self.antialias = antialias
+ self.rescale_factor = rescale_factor
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ antialias: bool = False,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ antialias (`bool`, *optional*, defaults to `False`):
+ Whether to apply an anti-aliasing filter when resizing the image. It only affects tensors with
+ bilinear or bicubic modes and it is ignored otherwise.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized images.
+ """
+ requires_backends(self, "torch")
+
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+
+        # We use torch interpolation instead of image.resize because DepthProImageProcessor
+        # rescales and normalizes the image before resizing it, which can produce negative values.
+        # image.resize expects all values to be in the range [0, 1] or [0, 255] and throws an exception
+        # otherwise, whereas pytorch interpolation works with negative values.
+        # Relevant issue: https://github.com/huggingface/transformers/issues/34920
+ return (
+ torch.nn.functional.interpolate(
+ # input should be (B, C, H, W)
+ input=torch.from_numpy(image).unsqueeze(0),
+ size=output_size,
+ mode=pil_torch_interpolation_mapping[resample].value,
+ antialias=antialias,
+ )
+ .squeeze(0)
+ .numpy()
+ )
+
+ def _validate_input_arguments(
+ self,
+ do_resize: bool,
+ size: Dict[str, int],
+ resample: PILImageResampling,
+ antialias: bool,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Union[float, List[float]],
+ image_std: Union[float, List[float]],
+ data_format: Union[str, ChannelDimension],
+ ):
+ if do_resize and None in (size, resample, antialias):
+ raise ValueError("Size, resample and antialias must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and None in (image_mean, image_std):
+ raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ antialias: Optional[bool] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
+ resizing.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+            antialias (`bool`, *optional*, defaults to `self.antialias`):
+ Whether to apply an anti-aliasing filter when resizing the image. It only affects tensors with
+ bilinear or bicubic modes and it is ignored otherwise.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+ antialias = antialias if antialias is not None else self.antialias
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ self._validate_input_arguments(
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ antialias=antialias,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ for image in images:
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ # depth-pro rescales and normalizes the image before resizing it
+ # uses torch interpolation which requires ChannelDimension.FIRST
+ if do_resize:
+ image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
+ image = self.resize(image=image, size=size, resample=resample, antialias=antialias)
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=ChannelDimension.FIRST)
+ else:
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+ all_images.append(image)
+
+ data = {"pixel_values": all_images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def post_process_depth_estimation(
+ self,
+ outputs: "DepthProDepthEstimatorOutput",
+ target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
+ ) -> Dict[str, List[TensorType]]:
+ """
+ Post-processes the raw depth predictions from the model to generate final depth predictions and optionally
+ resizes them to specified target sizes. This function supports scaling based on the field of view (FoV)
+ and adjusts depth values accordingly.
+
+ Args:
+ outputs ([`DepthProDepthEstimatorOutput`]):
+ Raw outputs of the model.
+ target_sizes (`Optional[Union[TensorType, List[Tuple[int, int]], None]]`, *optional*, defaults to `None`):
+ Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
+ or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
+ is performed.
+
+ Returns:
+ `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
+ predictions.
+
+ Raises:
+ `ValueError`:
+ If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
+ """
+ requires_backends(self, "torch")
+
+ predicted_depth = outputs.predicted_depth
+ fov = outputs.fov
+
+ batch_size = len(predicted_depth)
+
+ if target_sizes is not None and batch_size != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
+ )
+
+ results = []
+ fov = [None] * batch_size if fov is None else fov
+ target_sizes = [None] * batch_size if target_sizes is None else target_sizes
+ for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
+ if target_size is not None:
+ # scale image w.r.t fov
+ if fov_value is not None:
+ width = target_size[1]
+ fov_value = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
+ depth = depth * width / fov_value
+
+ # interpolate
+ depth = torch.nn.functional.interpolate(
+ # input should be (B, C, H, W)
+ input=depth.unsqueeze(0).unsqueeze(1),
+ size=target_size,
+ mode=pil_torch_interpolation_mapping[self.resample].value,
+ antialias=self.antialias,
+ ).squeeze()
+
+            # the raw prediction is inverse depth; invert it (with clamping) to obtain the final depth
+ depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)
+
+ results.append(
+ {
+ "predicted_depth": depth,
+ "fov": fov_value,
+ }
+ )
+
+ return results
+
+
+__all__ = ["DepthProImageProcessor"]
diff --git a/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py b/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py
new file mode 100644
index 00000000000000..521e5b8a06282e
--- /dev/null
+++ b/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py
@@ -0,0 +1,386 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for DepthPro."""
+
+import functools
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+
+if TYPE_CHECKING:
+ from ...modeling_outputs import DepthProDepthEstimatorOutput
+
+from ...image_processing_base import BatchFeature
+from ...image_processing_utils import get_size_dict
+from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict
+from ...image_transforms import FusedRescaleNormalize, NumpyToTensor, Rescale
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ ImageType,
+ PILImageResampling,
+ get_image_type,
+ make_list_of_images,
+ pil_torch_interpolation_mapping,
+)
+from ...utils import TensorType, logging, requires_backends
+from ...utils.import_utils import is_torch_available, is_torchvision_available
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_available():
+ import torch
+
+
+if is_torchvision_available():
+ from torchvision.transforms import Compose, Normalize, PILToTensor, Resize
+
+
+class DepthProImageProcessorFast(BaseImageProcessorFast):
+ r"""
+ Constructs a DepthPro image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 1536, "width": 1536}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ antialias (`bool`, *optional*, defaults to `False`):
+ Whether to apply an anti-aliasing filter when resizing the image. It only affects tensors with
+ bilinear or bicubic modes and it is ignored otherwise.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
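+
+ Example (a minimal sketch using a random dummy image; the output shape follows from the default `size` of
+ 1536 and is not tied to any particular checkpoint):
+
+ ```python
+ >>> import torch
+ >>> from transformers import DepthProImageProcessorFast
+
+ >>> image_processor = DepthProImageProcessorFast()
+ >>> image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # dummy RGB image
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> list(inputs["pixel_values"].shape)
+ [1, 3, 1536, 1536]
+ ```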
+ """
+
+ model_input_names = ["pixel_values"]
+ _transform_params = [
+ "do_resize",
+ "do_rescale",
+ "do_normalize",
+ "size",
+ "resample",
+ "antialias",
+ "rescale_factor",
+ "image_mean",
+ "image_std",
+ "image_type",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ antialias: bool = False,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 1536, "width": 1536}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.do_rescale = do_rescale
+ self.do_normalize = do_normalize
+ self.size = size
+ self.resample = resample
+ self.antialias = antialias
+ self.rescale_factor = rescale_factor
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def _build_transforms(
+ self,
+ do_resize: bool,
+ size: Dict[str, int],
+ resample: PILImageResampling,
+ antialias: bool,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Union[float, List[float]],
+ image_std: Union[float, List[float]],
+ image_type: ImageType,
+ ) -> "Compose":
+ """
+ Given the input settings build the image transforms using `torchvision.transforms.Compose`.
+ """
+ transforms = []
+
+ # All PIL and numpy values need to be converted to a torch tensor
+ # to keep cross compatibility with slow image processors
+ if image_type == ImageType.PIL:
+ transforms.append(PILToTensor())
+
+ elif image_type == ImageType.NUMPY:
+ transforms.append(NumpyToTensor())
+
+ # We can combine rescale and normalize into a single operation for speed
+ if do_rescale and do_normalize:
+ transforms.append(FusedRescaleNormalize(image_mean, image_std, rescale_factor=rescale_factor))
+ elif do_rescale:
+ transforms.append(Rescale(rescale_factor=rescale_factor))
+ elif do_normalize:
+ transforms.append(Normalize(image_mean, image_std))
+
+ # DepthPro rescales and normalizes pixel values before resizing, so the resize transform comes last
+ if do_resize:
+ transforms.append(
+ Resize(
+ (size["height"], size["width"]),
+ interpolation=pil_torch_interpolation_mapping[resample],
+ antialias=antialias,
+ )
+ )
+
+ return Compose(transforms)
+
+ @functools.lru_cache(maxsize=1)
+ def _validate_input_arguments(
+ self,
+ return_tensors: Union[str, TensorType],
+ do_resize: bool,
+ size: Dict[str, int],
+ resample: PILImageResampling,
+ antialias: bool,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Union[float, List[float]],
+ image_std: Union[float, List[float]],
+ data_format: Union[str, ChannelDimension],
+ image_type: ImageType,
+ ):
+ if return_tensors != "pt":
+ raise ValueError("Only returning PyTorch tensors is currently supported.")
+
+ if data_format != ChannelDimension.FIRST:
+ raise ValueError("Only channel first data format is currently supported.")
+
+ if do_resize and None in (size, resample, antialias):
+ raise ValueError("Size, resample and antialias must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and None in (image_mean, image_std):
+ raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ antialias: Optional[bool] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
+ resizing.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+ antialias (`bool`, *optional*, defaults to `self.antialias`):
+ Whether to apply an anti-aliasing filter when resizing the image. It only affects tensors with
+ bilinear or bicubic modes and it is ignored otherwise.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Only `"pt"` is currently supported.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. The following formats are currently supported:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+ antialias = antialias if antialias is not None else self.antialias
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ size = size if size is not None else self.size
+ # Make hashable for cache
+ size = SizeDict(**size)
+ image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
+ image_std = tuple(image_std) if isinstance(image_std, list) else image_std
+
+ images = make_list_of_images(images)
+ image_type = get_image_type(images[0])
+
+ if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
+ raise ValueError(f"Unsupported input image type {image_type}")
+
+ self._validate_input_arguments(
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ antialias=antialias,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ return_tensors=return_tensors,
+ data_format=data_format,
+ image_type=image_type,
+ )
+
+ transforms = self.get_transforms(
+ do_resize=do_resize,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ size=size,
+ resample=resample,
+ antialias=antialias,
+ rescale_factor=rescale_factor,
+ image_mean=image_mean,
+ image_std=image_std,
+ image_type=image_type,
+ )
+ transformed_images = [transforms(image) for image in images]
+
+ data = {"pixel_values": torch.stack(transformed_images, dim=0)}
+ return BatchFeature(data, tensor_type=return_tensors)
+
+ # Copied from transformers.models.depth_pro.image_processing_depth_pro.DepthProImageProcessor.post_process_depth_estimation
+ def post_process_depth_estimation(
+ self,
+ outputs: "DepthProDepthEstimatorOutput",
+ target_sizes: Optional[Union[TensorType, List[Tuple[int, int]]]] = None,
+ ) -> List[Dict[str, TensorType]]:
+ """
+ Post-processes the raw depth predictions from the model to generate final depth predictions and optionally
+ resizes them to specified target sizes. This function supports scaling based on the field of view (FoV)
+ and adjusts depth values accordingly.
+
+ Args:
+ outputs ([`DepthProDepthEstimatorOutput`]):
+ Raw outputs of the model.
+ target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
+ Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
+ or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
+ is performed.
+
+ Returns:
+ `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
+ predictions.
+
+ Raises:
+ `ValueError`:
+ If the lengths of `predicted_depth`, `fov`, or `target_sizes` are mismatched.
+ """
+ requires_backends(self, "torch")
+
+ predicted_depth = outputs.predicted_depth
+ fov = outputs.fov
+
+ batch_size = len(predicted_depth)
+
+ if target_sizes is not None and batch_size != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
+ )
+
+ results = []
+ fov = [None] * batch_size if fov is None else fov
+ target_sizes = [None] * batch_size if target_sizes is None else target_sizes
+ for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
+ if target_size is not None:
+ # convert the horizontal FoV (degrees) to a focal length in pixels, f = (width / 2) / tan(fov / 2), and rescale the prediction
+ if fov_value is not None:
+ width = target_size[1]
+ fov_value = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
+ depth = depth * width / fov_value
+
+ # interpolate
+ depth = torch.nn.functional.interpolate(
+ # input should be (B, C, H, W)
+ input=depth.unsqueeze(0).unsqueeze(1),
+ size=target_size,
+ mode=pil_torch_interpolation_mapping[self.resample].value,
+ antialias=self.antialias,
+ ).squeeze()
+
+ # the raw prediction behaves as inverse depth; invert and clamp it to obtain the final depth map
+ depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)
+
+ results.append(
+ {
+ "predicted_depth": depth,
+ "fov": fov_value,
+ }
+ )
+
+ return results
+
+
+__all__ = ["DepthProImageProcessorFast"]
diff --git a/src/transformers/models/depth_pro/modeling_depth_pro.py b/src/transformers/models/depth_pro/modeling_depth_pro.py
new file mode 100644
index 00000000000000..633d765b49f3f0
--- /dev/null
+++ b/src/transformers/models/depth_pro/modeling_depth_pro.py
@@ -0,0 +1,1686 @@
+# coding=utf-8
+# Copyright 2024 The Apple Research Team Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DepthPro model."""
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Set, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+ torch_int,
+)
+from .configuration_depth_pro import DepthProConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "DepthProConfig"
+
+
+DEPTH_PRO_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`DepthProConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DEPTH_PRO_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
+ for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+DEPTH_PRO_FOR_DEPTH_ESTIMATION_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`DepthProConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ use_fov_model (`bool`, *optional*, defaults to `True`):
+ Whether to use `DepthProFOVModel` to generate the field of view.
+"""
+
+
+@dataclass
+class DepthProOutput(ModelOutput):
+ """
+ Base class for DepthPro's outputs.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ features (`List[torch.FloatTensor]`, *optional*):
+ Features computed from the scaled images and intermediate hidden states.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer and the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, n_patches_per_batch, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ features: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class DepthProDepthEstimatorOutput(ModelOutput):
+ """
+ Base class for DepthProForDepthEstimation's output.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Depth estimation loss.
+ predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
+ Predicted depth for each pixel.
+ fov (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned when `use_fov_model=True`):
+ Predicted field of view (in degrees), which can be used to scale the predicted depth.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer and the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, n_patches_per_batch, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ predicted_depth: torch.FloatTensor = None
+ fov: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+def patch_to_batch(data: torch.Tensor, batch_size: int) -> torch.Tensor:
+ """
+ Converts tensor from shape:
+ (num_patches, seq_len, hidden_size) -> (batch_size, n_patches_per_batch, seq_len, hidden_size)
+ """
+ data = data.reshape(-1, batch_size, *data.shape[1:])
+ data = data.transpose(0, 1)
+ return data
+
+
+def batch_to_patch(data: torch.Tensor) -> torch.Tensor:
+ """
+ Converts tensor from shape:
+ (batch_size, n_patches_per_batch, seq_len, hidden_size) -> (num_patches, seq_len, hidden_size)
+ """
+ data = data.transpose(0, 1)
+ data = data.reshape(-1, *data.shape[2:])
+ return data
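+
+
+# A quick shape sanity check for the two helpers above (illustrative numbers only):
+#   x = torch.randn(12, 577, 1024)                           # 12 patches coming from a batch of 4 images
+#   patch_to_batch(x, batch_size=4).shape                    # -> (4, 3, 577, 1024)
+#   batch_to_patch(patch_to_batch(x, batch_size=4)).shape    # -> (12, 577, 1024)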
+
+
+# Copied from transformers.models.dinov2.modeling_dinov2.Dinov2PatchEmbeddings with Dinov2->DepthProViT
+class DepthProViTPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ # Ignore copy
+ # addition of config parameter patch_embeddings_size
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.in_channels = config.num_channels
+ self.out_channels = config.hidden_size
+ self.patch_embeddings_size = config.patch_embeddings_size
+ self.num_channels = config.num_channels
+
+ self.projection = nn.Conv2d(
+ self.in_channels,
+ self.out_channels,
+ kernel_size=(self.patch_embeddings_size, self.patch_embeddings_size),
+ stride=(self.patch_embeddings_size, self.patch_embeddings_size),
+ )
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ f" Expected {self.num_channels} but got {num_channels}."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+class DepthProViTEmbeddings(nn.Module):
+ """
+ Copied from transformers.models.dinov2.modeling_dinov2.Dinov2Embeddings, with three changes:
+ `antialias=True` when interpolating position encodings, removal of the mask_token,
+ and support for dynamic embedding sizes.
+ """
+
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+
+ self.config = config
+ self.seq_len = (config.patch_size // config.patch_embeddings_size) ** 2
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.patch_embeddings = DepthProViTPatchEmbeddings(config)
+ self.position_embeddings = nn.Parameter(torch.zeros(1, self.seq_len + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+ images. It is also adapted to support torch.jit tracing and performs the interpolation at torch.float32 precision.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_positions = embeddings.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and self.seq_len == num_positions and height == width:
+ return self.position_embeddings
+
+ class_pos_embed = self.position_embeddings[:, :1]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.config.patch_embeddings_size
+ new_width = width // self.config.patch_embeddings_size
+
+ patch_pos_embed_size = torch_int(patch_pos_embed.shape[1] ** 0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, patch_pos_embed_size, patch_pos_embed_size, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ target_dtype = patch_pos_embed.dtype
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.to(torch.float32),
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ antialias=True, # this is the only difference from transformers.models.dinov2.modeling_dinov2.Dinov2Embeddings.interpolate_pos_encoding
+ ).to(dtype=target_dtype)
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ batch_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ n, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embeddings.projection.weight.dtype
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(n, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+ embeddings = self.dropout(embeddings)
+
+ if batch_size is not None:
+ embeddings = patch_to_batch(embeddings, batch_size)
+
+ return embeddings
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViTConfig->DepthProConfig, ViT->DepthProViT
+class DepthProViTSelfAttention(nn.Module):
+ def __init__(self, config: DepthProConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ # Ignore copy
+ # addition of parameter batch_size
+ def forward(
+ self,
+ hidden_states,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ batch_size: Optional[int] = None,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ if batch_size is not None:
+ attention_probs_batched = patch_to_batch(attention_probs, batch_size)
+ attention_probs_patched = batch_to_patch(attention_probs_batched)
+ else:
+ attention_probs_patched = attention_probs_batched = attention_probs
+
+ context_layer = torch.matmul(attention_probs_patched, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs_batched) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.dinov2.modeling_dinov2.Dinov2SdpaSelfAttention with Dinov2Config->DepthProConfig, Dinov2->DepthProViT
+class DepthProViTSdpaSelfAttention(DepthProViTSelfAttention):
+ def __init__(self, config: DepthProConfig) -> None:
+ super().__init__(config)
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+ # Ignore copy
+ # addition of `batch_size`
+ def forward(
+ self,
+ hidden_states,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ batch_size: Optional[int] = None,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "DepthProViTModel is using DepthProViTSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ batch_size=batch_size,
+ )
+
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ head_mask,
+ self.attention_probs_dropout_prob if self.training else 0.0,
+ is_causal=False,
+ scale=None,
+ )
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ return context_layer, None
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViTConfig->DepthProConfig, ViT->DepthProViT
+class DepthProViTSelfOutput(nn.Module):
+ """
+ The residual connection is defined in DepthProViTLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: DepthProConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViTConfig->DepthProConfig, ViT->DepthProViT
+class DepthProViTAttention(nn.Module):
+ def __init__(self, config: DepthProConfig) -> None:
+ super().__init__()
+ self.attention = DepthProViTSelfAttention(config)
+ self.output = DepthProViTSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ # Ignore copy
+ # addition of `batch_size`
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ batch_size: Optional[int] = None,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions, batch_size)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViTConfig->DepthProConfig, ViT->DepthProViT
+class DepthProViTSdpaAttention(DepthProViTAttention):
+ def __init__(self, config: DepthProConfig) -> None:
+ super().__init__(config)
+ self.attention = DepthProViTSdpaSelfAttention(config)
+
+
+# Copied from transformers.models.dinov2.modeling_dinov2.Dinov2LayerScale with Dinov2Config->DepthProConfig, Dinov2->DepthProViT
+class DepthProViTLayerScale(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ return hidden_state * self.lambda1
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath
+class DepthProViTDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+# Copied from transformers.models.dinov2.modeling_dinov2.Dinov2MLP with Dinov2->DepthPro
+class DepthProViTMLP(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
+ if isinstance(config.hidden_act, str):
+ self.activation = ACT2FN[config.hidden_act]
+ else:
+ self.activation = config.hidden_act
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
+
+
+# Copied from transformers.models.dinov2.modeling_dinov2.Dinov2SwiGLUFFN with Dinov2->DepthPro
+class DepthProViTSwiGLUFFN(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
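+ # the two lines above take 2/3 of the MLP width and round it up to the nearest multiple of 8,
+ # e.g. (illustrative) hidden_size=1024 and mlp_ratio=4 give 4096 -> int(4096 * 2 / 3) = 2730 -> 2736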
+
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.weights_in(hidden_state)
+ x1, x2 = hidden_state.chunk(2, dim=-1)
+ hidden = nn.functional.silu(x1) * x2
+ return self.weights_out(hidden)
+
+
+DEPTHPROVIT_ATTENTION_CLASSES = {
+ "eager": DepthProViTAttention,
+ "sdpa": DepthProViTSdpaAttention,
+}
+
+
+# Copied from transformers.models.dinov2.modeling_dinov2.Dinov2Layer with Dinov2Config->DepthProConfig, Dinov2->DepthProViT all-casing
+class DepthProViTLayer(nn.Module):
+ """This corresponds to the Block class in the original implementation."""
+
+ def __init__(self, config: DepthProConfig) -> None:
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attention = DEPTHPROVIT_ATTENTION_CLASSES[config._attn_implementation](config)
+ self.layer_scale1 = DepthProViTLayerScale(config)
+ self.drop_path = DepthProViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if config.use_swiglu_ffn:
+ self.mlp = DepthProViTSwiGLUFFN(config)
+ else:
+ self.mlp = DepthProViTMLP(config)
+ self.layer_scale2 = DepthProViTLayerScale(config)
+
+ # Ignore copy
+ # addition of `batch_size`
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ batch_size: Optional[int] = None,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ if batch_size is not None:
+ hidden_states = batch_to_patch(hidden_states)
+
+ self_attention_outputs = self.attention(
+ self.norm1(hidden_states), # in DepthProViT, layernorm is applied before self-attention
+ head_mask,
+ output_attentions=output_attentions,
+ batch_size=batch_size,
+ )
+ attention_output = self_attention_outputs[0]
+
+ attention_output = self.layer_scale1(attention_output)
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = self.drop_path(attention_output) + hidden_states
+
+ # in DepthProViT, layernorm is also applied after self-attention
+ layer_output = self.norm2(hidden_states)
+ layer_output = self.mlp(layer_output)
+ layer_output = self.layer_scale2(layer_output)
+
+ # second residual connection
+ layer_output = self.drop_path(layer_output) + hidden_states
+
+ if batch_size is not None:
+ layer_output = patch_to_batch(layer_output, batch_size)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViTConfig->DepthProConfig, ViT->DepthProViT
+class DepthProViTEncoder(nn.Module):
+ def __init__(self, config: DepthProConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([DepthProViTLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ # addition of `batch_size`
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ batch_size: Optional[int] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer_module.__call__,
+ hidden_states,
+ layer_head_mask,
+ output_attentions,
+ batch_size,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, batch_size)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class DepthProViT(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = DepthProViTEmbeddings(config)
+ self.encoder = DepthProViTEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ batch_size: Optional[int] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.embeddings(pixel_values, batch_size=batch_size)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ batch_size=batch_size,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+
+ if not return_dict:
+ head_outputs = (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class DepthProFeatureUpsample(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+
+ self.upsample_blocks = nn.ModuleList()
+
+ # for image_features
+ self.upsample_blocks.append(
+ self._create_upsample_block(
+ input_dims=config.hidden_size,
+ intermediate_dims=config.hidden_size,
+ output_dims=config.scaled_images_feature_dims[0],
+ n_upsample_layers=1,
+ use_proj=False,
+ bias=True,
+ )
+ )
+
+ # for scaled_images_features
+ for i, feature_dims in enumerate(config.scaled_images_feature_dims):
+ upsample_block = self._create_upsample_block(
+ input_dims=config.hidden_size,
+ intermediate_dims=feature_dims,
+ output_dims=feature_dims,
+ n_upsample_layers=1,
+ )
+ self.upsample_blocks.append(upsample_block)
+
+ # for intermediate_features
+ for i, feature_dims in enumerate(config.intermediate_feature_dims):
+ intermediate_dims = config.fusion_hidden_size if i == 0 else feature_dims
+ upsample_block = self._create_upsample_block(
+ input_dims=config.hidden_size,
+ intermediate_dims=intermediate_dims,
+ output_dims=feature_dims,
+ n_upsample_layers=2 + i,
+ )
+ self.upsample_blocks.append(upsample_block)
+
+ def _create_upsample_block(
+ self,
+ input_dims: int,
+ intermediate_dims: int,
+ output_dims: int,
+ n_upsample_layers: int,
+ use_proj: bool = True,
+ bias: bool = False,
+ ) -> nn.Module:
+ upsample_block = nn.Sequential()
+
+ # create first projection layer
+ if use_proj:
+ proj = nn.Conv2d(
+ in_channels=input_dims,
+ out_channels=intermediate_dims,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ upsample_block.append(proj)
+
+ # create following upsample layers
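+ # each ConvTranspose2d below uses kernel_size=2 and stride=2, so every layer doubles the spatial
+ # resolution and the block upsamples by a factor of 2**n_upsample_layers overall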
+ for i in range(n_upsample_layers):
+ in_channels = intermediate_dims if i == 0 else output_dims
+ layer = nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=output_dims,
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=bias,
+ )
+ upsample_block.append(layer)
+
+ return upsample_block
+
+ def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+ upsampled_features = []
+ for i, upsample_block in enumerate(self.upsample_blocks):
+ upsampled_feature = upsample_block(features[i])
+ upsampled_features.append(upsampled_feature)
+ return upsampled_features
+
+
+class DepthProFeatureProjection(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+
+ combined_feature_dims = config.scaled_images_feature_dims + config.intermediate_feature_dims
+ self.projections = nn.ModuleList()
+ for i, in_channels in enumerate(combined_feature_dims):
+ if i == len(combined_feature_dims) - 1 and in_channels == config.fusion_hidden_size:
+ # projection for last layer can be ignored if input and output channels already match
+ self.projections.append(nn.Identity())
+ else:
+ self.projections.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ )
+ )
+
+ def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+ projected_features = []
+ for i, projection in enumerate(self.projections):
+ upsampled_feature = projection(features[i])
+ projected_features.append(upsampled_feature)
+ return projected_features
+
+
+def interpolate(
+ pixel_values: torch.Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None
+) -> torch.Tensor:
+ return nn.functional.interpolate(
+ pixel_values,
+ size=size,
+ scale_factor=scale_factor,
+ mode="bilinear",
+ align_corners=False,
+ )
+
+
+def patch(pixel_values: torch.Tensor, patch_size: int, overlap_ratio: float) -> torch.Tensor:
+ """Creates Patches from Batch."""
+ batch_size, num_channels, height, width = pixel_values.shape
+
+ if height == width == patch_size:
+ # no patching is needed when the scaled image already matches the patch size
+ return pixel_values
+
+ stride = int(patch_size * (1 - overlap_ratio))
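+ # e.g. (illustrative values, not fixed by this function): with patch_size=384 and overlap_ratio=0.25,
+ # stride = int(384 * 0.75) = 288, so a 1536x1536 scaled image unfolds into
+ # ((1536 - 384) // 288 + 1) ** 2 = 25 overlapping patches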
+
+ # (batch_size, num_channels, height, width)
+ patches = torch.nn.functional.unfold(pixel_values, kernel_size=(patch_size, patch_size), stride=(stride, stride))
+ # patches.shape (batch_size, patch_size**2 * num_channels, n_patches_per_batch)
+ patches = patches.permute(2, 0, 1)
+ # patches.shape (n_patches_per_batch, batch_size, patch_size**2 * C)
+ patches = patches.reshape(-1, num_channels, patch_size, patch_size)
+ # patches.shape (n_patches, num_channels, patch_size, patch_size)
+
+ return patches
+
+
+def reshape_feature(hidden_states: torch.Tensor) -> torch.Tensor:
+ """Discard class token and reshape 1D feature map to a 2D grid."""
+ n_samples, seq_len, hidden_size = hidden_states.shape
+ size = int(math.sqrt(seq_len))
+
+ # (n_samples, seq_len, hidden_size)
+ hidden_states = hidden_states[:, 1:, :] # remove class token
+ # (n_samples, seq_len - 1, hidden_size)
+ hidden_states = hidden_states.reshape(n_samples, size, size, hidden_size)
+ # (n_samples, size, size, hidden_size)
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ # (n_samples, hidden_size, size, size)
+ return hidden_states
+
+
+def merge(patches: torch.Tensor, batch_size: int, merge_out_size: int) -> torch.Tensor:
+ n_patches, hidden_size, out_size, out_size = patches.shape
+ n_patches_per_batch = n_patches // batch_size
+ sqrt_n_patches_per_batch = int(math.sqrt(n_patches_per_batch))
+ new_out_size = sqrt_n_patches_per_batch * out_size
+
+ if n_patches == batch_size:
+ # merge only if the patches were created from scaled image
+ # patches are not created when scaled image size is equal to patch size
+ return patches
+
+ # calculate padding using the formula
+ # merge_out_size = (sqrt_n_patches_per_batch - 2) * (out_size - 2 * padding) + 2 * (out_size - padding)
+ padding = (sqrt_n_patches_per_batch * out_size - merge_out_size) // (2 * sqrt_n_patches_per_batch - 2)
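+ # e.g. (illustrative, assuming out_size=24 and a 5x5 grid of patches with merge_out_size=96):
+ # padding = (5 * 24 - 96) // (2 * 5 - 2) = 3, i.e. 3 border pixels are trimmed on each side of every patch seam below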
+
+ # patches.shape (n_patches, hidden_size, out_size, out_size)
+
+ merged = patches.reshape(n_patches_per_batch, batch_size, hidden_size, out_size, out_size)
+ # (n_patches_per_batch, batch_size, hidden_size, out_size, out_size)
+ merged = merged.permute(1, 2, 0, 3, 4)
+ # (batch_size, hidden_size, n_patches_per_batch, out_size, out_size)
+
+ merged = merged[:, :, : sqrt_n_patches_per_batch**2, :, :]
+ # (batch_size, hidden_size, n_patches_per_batch, out_size, out_size)
+
+ merged = merged.reshape(
+ batch_size, hidden_size, sqrt_n_patches_per_batch, sqrt_n_patches_per_batch, out_size, out_size
+ )
+ # (batch_size, hidden_size, sqrt_n_patches_per_batch, sqrt_n_patches_per_batch, out_size, out_size)
+ merged = merged.permute(0, 1, 2, 4, 3, 5)
+ # (batch_size, hidden_size, sqrt_n_patches_per_batch, out_size, sqrt_n_patches_per_batch, out_size)
+ merged = merged.reshape(batch_size, hidden_size, new_out_size, new_out_size)
+ # (batch_size, hidden_size, sqrt_n_patches_per_batch * out_size, sqrt_n_patches_per_batch * out_size)
+
+ if padding != 0:
+ padding_mask = torch.ones((new_out_size, new_out_size), dtype=torch.bool)
+ starting_index = torch.arange(start=out_size - padding, end=new_out_size - padding, step=out_size)
+ for index in starting_index:
+ padding_mask[index : index + padding * 2, :] = False
+ padding_mask[:, index : index + padding * 2] = False
+
+ merged = merged[:, :, padding_mask]
+ final_out_size = int(math.sqrt(merged.shape[-1]))
+ merged = merged.reshape(*merged.shape[:2], final_out_size, final_out_size)
+
+ return merged
+
+
+class DepthProEncoder(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.fusion_hidden_size = config.fusion_hidden_size
+
+ self.intermediate_hook_ids = config.intermediate_hook_ids
+ self.intermediate_feature_dims = config.intermediate_feature_dims
+ self.scaled_images_ratios = config.scaled_images_ratios
+ self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
+ self.scaled_images_feature_dims = config.scaled_images_feature_dims
+
+ self.n_scaled_images = len(self.scaled_images_ratios)
+ self.n_intermediate_hooks = len(self.intermediate_hook_ids)
+ self.out_size = config.patch_size // config.patch_embeddings_size
+ self.seq_len = self.out_size**2 # each patch is flattened
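+ # e.g. (illustrative, assuming patch_size=384 and patch_embeddings_size=16): out_size = 24 and seq_len = 576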
+
+ # patch encoder
+ self.patch_encoder = DepthProViT(config)
+
+ # image encoder
+ self.image_encoder = DepthProViT(config)
+
+ # upsample features
+ self.feature_upsample = DepthProFeatureUpsample(config)
+
+ # for STEP 7: fuse low_res and image features
+ self.fuse_image_with_low_res = nn.Conv2d(
+ in_channels=config.scaled_images_feature_dims[0] * 2,
+ out_channels=config.scaled_images_feature_dims[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+
+ # project features
+ self.feature_projection = DepthProFeatureProjection(config)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, DepthProOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values.dim() != 4:
+ raise ValueError("Input tensor must have shape (batch_size, num_channels, height, width).")
+
+ batch_size, num_channels, height, width = pixel_values.shape
+
+ if num_channels != self.config.num_channels:
+ raise ValueError(
+ f"Found {num_channels} channels in the image, but the configuration expects {self.config.num_channels} channels."
+ )
+
+ if min(self.scaled_images_ratios) * min(height, width) < self.config.patch_size:
+ raise ValueError(
+ f"Image size {height}x{width} is too small to be scaled "
+ f"with scaled_images_ratios={self.scaled_images_ratios} "
+ f"when patch_size={self.config.patch_size}."
+ )
+
+ # pixel_values.shape (batch_size, num_channels, height, width)
+
+ # STEP 1: create 3-level image
+
+ scaled_images = []
+ for ratio in self.scaled_images_ratios:
+ scaled_images.append(interpolate(pixel_values, scale_factor=ratio))
+ # (batch_size, num_channels, height*ratio, width*ratio)
+
+ # STEP 2: create patches
+
+ for i in range(self.n_scaled_images):
+ scaled_images[i] = patch(
+ scaled_images[i],
+ patch_size=self.config.patch_size,
+ overlap_ratio=self.scaled_images_overlap_ratios[i],
+ )
+ # (n_patches_per_scaled_image[i], num_channels, patch_size, patch_size)
+ n_patches_per_scaled_image = [len(i) for i in scaled_images]
+ patches = torch.cat(scaled_images[::-1], dim=0) # -1 as patch encoder expects high res patches first
+ # (n_patches, num_channels, patch_size, patch_size)
+
+ # STEP 3: apply patch and image encoder
+
+ patch_encodings = self.patch_encoder(
+ patches,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ # required for intermediate features
+ output_hidden_states=self.n_intermediate_hooks or output_hidden_states,
+ return_dict=True,
+ batch_size=batch_size,
+ )
+ # patch_encodings.last_hidden_state (batch_size, n_patches/batch_size, seq_len, hidden_size)
+ # patch_encodings.hidden_states[i] (batch_size, n_patches/batch_size, seq_len, hidden_size)
+ # patch_encodings.attentions[i] (batch_size, n_patches/batch_size, num_heads, seq_len, seq_len)
+
+ last_hidden_state = patch_encodings.last_hidden_state
+ # (batch_size, n_patches/batch_size, seq_len, hidden_size)
+ last_hidden_state = batch_to_patch(last_hidden_state)
+ # (n_patches, seq_len, hidden_size)
+ scaled_images_last_hidden_state = torch.split_with_sizes(last_hidden_state, n_patches_per_scaled_image[::-1])
+ # (n_patches_per_scaled_image[i], seq_len, hidden_size)
+ scaled_images_last_hidden_state = scaled_images_last_hidden_state[::-1]
+ # (n_patches_per_scaled_image[i], seq_len, hidden_size)
+ # -1 (reverse list) as patch encoder expects high res patches first
+
+ # scale the image to patch size for image_encoder
+ image_scaled_to_patch_size = interpolate(
+ pixel_values,
+ size=(self.config.patch_size, self.config.patch_size),
+ )
+ image_encodings = self.image_encoder(
+ pixel_values=image_scaled_to_patch_size,
+ head_mask=head_mask,
+ )
+ # image_encodings.last_hidden_state (batch_size, seq_len, hidden_size)
+ # image_encodings.hidden_states[i] (batch_size, seq_len, hidden_size)
+ # image_encodings.attentions[i] (batch_size, num_heads, seq_len, seq_len)
+
+ # STEP 4: get patch features (high_res, med_res, low_res) - (3-5) in diagram
+
+ exponent_value = int(math.log2(width / self.out_size))
+ base_height = height // 2**exponent_value
+ base_width = width // 2**exponent_value
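+ # e.g. (illustrative, for a 1536x1536 input and out_size=24): exponent_value = int(log2(1536 / 24)) = 6,
+ # so base_height = base_width = 1536 // 2**6 = 24, the resolution of a single patch feature map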
+
+ scaled_images_features = []
+ for i in range(self.n_scaled_images):
+ # a. extract hidden_state
+ hidden_state = scaled_images_last_hidden_state[i]
+ # (n_patches_per_scaled_image[i], seq_len, hidden_size)
+
+ # b. reshape back to image like
+ features = reshape_feature(hidden_state)
+ # (n_patches_per_scaled_image[i], hidden_size, out_size, out_size)
+
+ # c. merge patches back together
+ features = merge(
+ features, batch_size=batch_size, merge_out_size=self.out_size * 2**i
+ ) # (batch_size, hidden_size, out_size*2**i, out_size*2**i)
+
+ # d. interpolate patches to base size
+ features = interpolate(features, size=(base_height * 2**i, base_width * 2**i))
+ # (batch_size, hidden_size, base_height*2**i, base_width*2**i)
+
+ scaled_images_features.append(features)
+
+ # STEP 5: get intermediate features - (1-2) in diagram
+
+ intermediate_features = []
+ for i in range(self.n_intermediate_hooks):
+ # a. extract hidden_state
+ layer_id = (
+ self.intermediate_hook_ids[i] + 1
+ ) # +1 to correct index position as hidden_states contain embedding output as well
+ hidden_state = patch_encodings.hidden_states[layer_id]
+ hidden_state = batch_to_patch(hidden_state)
+ hidden_state = hidden_state[
+ : n_patches_per_scaled_image[-1]
+ ] # number of patches to be of same length as highest resolution
+ # (n_patches_per_scaled_image[-1], seq_len, hidden_size)
+
+ # b. reshape back to image like
+ features = reshape_feature(hidden_state)
+ # (n_patches_per_scaled_image[-1], hidden_size, out_size, out_size)
+
+ # c. merge patches back together
+ features = merge(
+ features,
+ batch_size=batch_size,
+ merge_out_size=self.out_size * 2 ** (self.n_scaled_images - 1),
+ ) # (batch_size, hidden_size, out_size*2**(n_scaled_images-1), out_size*2**(n_scaled_images-1))
+
+ # d. interpolate patches to base size
+ features = interpolate(
+ features,
+ size=(base_height * 2 ** (self.n_scaled_images - 1), base_width * 2 ** (self.n_scaled_images - 1)),
+ )
+ # (batch_size, hidden_size, base_height*2**(n_scaled_images - 1), base_width*2**(n_scaled_images - 1))
+
+ intermediate_features.append(features)
+
+ # STEP 6: get image features - (6) in diagram
+
+ # a. extract hidden_state
+ hidden_state = image_encodings.last_hidden_state # (batch_size, seq_len, hidden_size)
+
+ # b. reshape back to image like
+ image_features = reshape_feature(hidden_state)
+ # (batch_size, hidden_size, out_size, out_size)
+
+ # c. merge patches back together
+ # no merge required for image_features as they are already in batches instead of patches
+
+ # d. interpolate patches to base size
+ image_features = interpolate(image_features, size=(base_height, base_width))
+ # (batch_size, hidden_size, base_height, base_width)
+
+ # STEP 7: combine all features
+ features = [
+ image_features,
+ # (batch_size, scaled_images_feature_dims[0], base_height*2, base_width*2)
+ *scaled_images_features,
+ # (batch_size, scaled_images_feature_dims[i], base_height*2**(i+1), base_width*2**(i+1))
+ *intermediate_features,
+ # (batch_size, intermediate_feature_dims[i], base_height*2**(n_scaled_images+i+1), base_width*2**(n_scaled_images+i+1))
+ ]
+
+ # STEP 8: upsample features
+ features = self.feature_upsample(features)
+
+ # STEP 9: apply fusion
+ # (global features = low res features + image features)
+ # fuses image_features with the lowest-resolution features, as they have the same spatial size
+ global_features = torch.cat((features[1], features[0]), dim=1)
+ global_features = self.fuse_image_with_low_res(global_features)
+ features = [global_features, *features[2:]]
+
+ # STEP 10: project features
+ features = self.feature_projection(features)
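+ # a sketch of the resulting pyramid under the toy configuration used in the accompanying tests
+ # (scaled_images_ratios=[0.5, 1.0], intermediate_hook_ids=[0], patch_size=8,
+ # patch_embeddings_size=4, 64x64 inputs): per the shape comments above, `features` now holds
+ # len(scaled_images_ratios) + len(intermediate_hook_ids) = 3 maps at roughly 4x4, 8x8 and
+ # 16x16 resolution, which the fusion stage consumes lowest resolution first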
+
+ # STEP 11: return output
+
+ last_hidden_state = patch_encodings.last_hidden_state
+ hidden_states = patch_encodings.hidden_states if output_hidden_states else None
+ attentions = patch_encodings.attentions if output_attentions else None
+
+ if not return_dict:
+ return tuple(v for v in [last_hidden_state, features, hidden_states, attentions] if v is not None)
+
+ return DepthProOutput(
+ last_hidden_state=last_hidden_state,
+ features=features,
+ hidden_states=hidden_states,
+ attentions=attentions,
+ )
+
+
+class DepthProPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DepthProConfig
+ base_model_prefix = "depth_pro"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DepthProViTSwiGLUFFN"]
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@add_start_docstrings(
+ "The bare DepthPro Model transformer outputting raw hidden-states without any specific head on top.",
+ DEPTH_PRO_START_DOCSTRING,
+)
+class DepthProModel(DepthProPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.encoder = DepthProEncoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.encoder.patch_encoder.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.patch_encoder.encoder.layer[layer].attention.prune_heads(heads)
+ self.encoder.image_encoder.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(DEPTH_PRO_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, DepthProOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, DepthProModel
+
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> checkpoint = "geetu040/DepthPro"
+ >>> processor = AutoProcessor.from_pretrained(checkpoint)
+ >>> model = DepthProModel.from_pretrained(checkpoint)
+
+ >>> # prepare image for the model
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... output = model(**inputs)
+
+ >>> output.last_hidden_state.shape
+ torch.Size([1, 35, 577, 1024])
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encodings = self.encoder(
+ pixel_values,
+ head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
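+ # encodings.last_hidden_state is 4D, (batch_size, num_patches, seq_len, hidden_size),
+ # since every patch of every scaled image is encoded; num_patches therefore depends on
+ # scaled_images_ratios and scaled_images_overlap_ratios (35 for the default configuration,
+ # as in the docstring example above)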
+
+ return encodings
+
+
+# Copied from transformers.models.dpt.modeling_dpt.DPTPreActResidualLayer DPT->DepthPro
+class DepthProPreActResidualLayer(nn.Module):
+ """
+ ResidualConvUnit, pre-activate residual unit.
+
+ Args:
+ config (`[DepthProConfig]`):
+ Model configuration class defining the model architecture.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.use_batch_norm = config.use_batch_norm_in_fusion_residual
+ use_bias_in_fusion_residual = (
+ config.use_bias_in_fusion_residual
+ if config.use_bias_in_fusion_residual is not None
+ else not self.use_batch_norm
+ )
+
+ self.activation1 = nn.ReLU()
+ self.convolution1 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias_in_fusion_residual,
+ )
+
+ self.activation2 = nn.ReLU()
+ self.convolution2 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias_in_fusion_residual,
+ )
+
+ if self.use_batch_norm:
+ self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
+ self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ residual = hidden_state
+ hidden_state = self.activation1(hidden_state)
+
+ hidden_state = self.convolution1(hidden_state)
+
+ if self.use_batch_norm:
+ hidden_state = self.batch_norm1(hidden_state)
+
+ hidden_state = self.activation2(hidden_state)
+ hidden_state = self.convolution2(hidden_state)
+
+ if self.use_batch_norm:
+ hidden_state = self.batch_norm2(hidden_state)
+
+ return hidden_state + residual
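+ # note on the block above: this is a pre-activation unit, so each ReLU runs before its
+ # convolution, and both 3x3 convolutions keep fusion_hidden_size channels and the spatial
+ # size unchanged, which is what makes the identity skip connection valid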
+
+
+# Taken from transformers.models.dpt.modeling_dpt.DPTFeatureFusionLayer
+# except it uses deconv and skip_add and needs no interpolation
+class DepthProFeatureFusionLayer(nn.Module):
+ def __init__(self, config: DepthProConfig, use_deconv: bool = True):
+ super().__init__()
+ self.config = config
+ self.use_deconv = use_deconv
+
+ self.residual_layer1 = DepthProPreActResidualLayer(config)
+ self.residual_layer2 = DepthProPreActResidualLayer(config)
+
+ if self.use_deconv:
+ self.deconv = nn.ConvTranspose2d(
+ in_channels=config.fusion_hidden_size,
+ out_channels=config.fusion_hidden_size,
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=False,
+ )
+
+ self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if residual is not None:
+ hidden_state = self.skip_add.add(hidden_state, self.residual_layer1(residual))
+
+ hidden_state = self.residual_layer2(hidden_state)
+ if self.use_deconv:
+ hidden_state = self.deconv(hidden_state)
+ hidden_state = self.projection(hidden_state)
+
+ return hidden_state
+
+
+# Taken from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage with DPT->DepthPro
+# with deconv and reversed layers
+class DepthProFeatureFusionStage(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ self.num_layers = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
+ self.layers = nn.ModuleList()
+ for _ in range(self.num_layers - 1):
+ self.layers.append(DepthProFeatureFusionLayer(config))
+ # final layer does not require deconvolution
+ self.layers.append(DepthProFeatureFusionLayer(config, use_deconv=False))
+
+ def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
+ if self.num_layers != len(hidden_states):
+ raise ValueError(
+ f"num_layers={self.num_layers} in DepthProFeatureFusionStage"
+ f"doesnot match len(hidden_states)={len(hidden_states)}"
+ )
+
+ fused_hidden_states = []
+ fused_hidden_state = None
+ for hidden_state, layer in zip(hidden_states, self.layers):
+ if fused_hidden_state is None:
+ # the first layer has nothing to fuse yet, so it only uses its own hidden_state
+ fused_hidden_state = layer(hidden_state)
+ else:
+ fused_hidden_state = layer(fused_hidden_state, hidden_state)
+ fused_hidden_states.append(fused_hidden_state)
+
+ return fused_hidden_states
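+ # a sketch of the fusion schedule, assuming the toy test configuration where the encoder
+ # outputs features at 4x4, 8x8 and 16x16 (3 fusion layers): layer 0 refines the 4x4 global
+ # features and deconvolves them to 8x8, layer 1 adds the 8x8 feature and deconvolves to
+ # 16x16, and the final layer adds the 16x16 feature without deconvolution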
+
+
+class DepthProFOVModel(nn.Module):
+ def __init__(self, config: DepthProConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.fusion_hidden_size = config.fusion_hidden_size
+
+ self.out_size = config.patch_size // config.patch_embeddings_size
+
+ self.encoder = DepthProViT(config)
+ self.encoder_neck = nn.Linear(self.hidden_size, self.fusion_hidden_size // 2)
+ self.global_neck = nn.Sequential(
+ nn.Conv2d(self.fusion_hidden_size, self.fusion_hidden_size // 2, kernel_size=3, stride=2, padding=1),
+ nn.ReLU(True),
+ )
+
+ # create initial head layers
+ self.head = nn.Sequential()
+ for i in range(config.num_fov_head_layers):
+ self.head.append(
+ nn.Conv2d(
+ math.ceil(self.fusion_hidden_size / 2 ** (i + 1)),
+ math.ceil(self.fusion_hidden_size / 2 ** (i + 2)),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ )
+ )
+ self.head.append(nn.ReLU(True))
+ # calculate the expected shapes so that the final head layer produces a scalar output
+ final_in_channels = math.ceil(self.fusion_hidden_size / 2 ** (config.num_fov_head_layers + 1))
+ final_kernel_size = int((self.out_size - 1) / 2**config.num_fov_head_layers + 1)
+ self.head.append(
+ nn.Conv2d(
+ in_channels=final_in_channels, out_channels=1, kernel_size=final_kernel_size, stride=1, padding=0
+ )
+ )
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ global_features: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+
+ # follow the same steps as for the image features in DepthProEncoder,
+ # except for the extra encoder_neck layer applied below
+
+ image_scaled_to_patch_size = interpolate(
+ pixel_values,
+ size=(self.config.patch_size, self.config.patch_size),
+ )
+ encodings = self.encoder(
+ image_scaled_to_patch_size,
+ head_mask=head_mask,
+ )
+
+ # a. extract hidden_state
+ hidden_state = encodings.last_hidden_state # (batch_size, seq_len, hidden_size)
+ # extra step
+ hidden_state = self.encoder_neck(hidden_state)
+ # (batch_size, seq_len, fusion_hidden_size//2)
+
+ # b. reshape back to image like
+ fov_features = reshape_feature(hidden_state)
+ # (batch_size, fusion_hidden_size//2, out_size, out_size)
+
+ # c. merge patches back together
+ # no merge required for fov_features as they are already in batches instead of patches
+
+ # d. interpolate patches to base size
+ # skip; instead interpolate the global features
+
+ global_features = self.global_neck(global_features)
+ global_features = interpolate(global_features, size=(self.out_size, self.out_size))
+
+ fov_features = fov_features + global_features
+ fov_output = self.head(fov_features)
+ fov_output = fov_output.reshape(batch_size)
+
+ return fov_output
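+ # head geometry, as a rough check: each of the num_fov_head_layers convolutions roughly
+ # halves the spatial size (kernel 3, stride 2, padding 1), so a feature map of out_size
+ # shrinks to about out_size / 2**num_fov_head_layers; final_kernel_size = int((out_size - 1)
+ # / 2**num_fov_head_layers + 1) matches that remaining size, letting the last convolution
+ # collapse it to a single value per image. E.g. out_size=24 with num_fov_head_layers=2
+ # (hypothetical values) gives 24 -> 12 -> 6 and a final kernel of int(23 / 4 + 1) = 6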
+
+
+class DepthProDepthEstimationHead(nn.Module):
+ """
+ The DepthProDepthEstimationHead module serves as the output head for depth estimation tasks.
+ This module comprises a sequence of convolutional and transposed convolutional layers
+ that process the feature map from the fusion stage to produce a single-channel depth map.
+ Key operations include dimensionality reduction and upsampling to match the input resolution.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ features = config.fusion_hidden_size
+ self.head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ nn.ConvTranspose2d(
+ in_channels=features // 2, out_channels=features // 2, kernel_size=2, stride=2, padding=0, bias=True
+ ),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(),
+ )
+
+ def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
+ predicted_depth = self.head(hidden_states)
+ predicted_depth = predicted_depth.squeeze(dim=1)
+ return predicted_depth
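+ # the single transposed convolution (kernel 2, stride 2) doubles the spatial size of the
+ # fused feature map, so under the toy test configuration sketched above the 16x16 output of
+ # the fusion stage would become a 32x32 depth map, consistent with expected_depth_size
+ # = 2 ** (n_fusion_blocks + 1) * patch_size // patch_embeddings_size used in the tests below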
+
+
+@add_start_docstrings(
+ """
+ DepthPro Model with a depth estimation head on top (consisting of 3 convolutional layers).
+ """,
+ DEPTH_PRO_FOR_DEPTH_ESTIMATION_START_DOCSTRING,
+)
+class DepthProForDepthEstimation(DepthProPreTrainedModel):
+ def __init__(self, config, use_fov_model=None):
+ super().__init__(config)
+ self.config = config
+ self.use_fov_model = use_fov_model if use_fov_model is not None else self.config.use_fov_model
+
+ # dinov2 (vit) like encoders
+ self.depth_pro = DepthProModel(config)
+
+ # dpt like fusion stage
+ self.fusion_stage = DepthProFeatureFusionStage(config)
+
+ # depth estimation head
+ self.head = DepthProDepthEstimationHead(config)
+
+ # dinov2 (vit) like encoder
+ self.fov_model = DepthProFOVModel(config) if self.use_fov_model else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(DEPTH_PRO_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=DepthProDepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], DepthProDepthEstimatorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth depth estimation maps for computing the loss.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, DepthProForDepthEstimation
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> checkpoint = "geetu040/DepthPro"
+ >>> processor = AutoImageProcessor.from_pretrained(checkpoint)
+ >>> model = DepthProForDepthEstimation.from_pretrained(checkpoint)
+
+ >>> # prepare image for the model
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> # interpolate to original size
+ >>> post_processed_output = processor.post_process_depth_estimation(
+ ... outputs, target_sizes=[(image.height, image.width)],
+ ... )
+
+ >>> # visualize the prediction
+ >>> predicted_depth = post_processed_output[0]["predicted_depth"]
+ >>> depth = predicted_depth * 255 / predicted_depth.max()
+ >>> depth = depth.detach().cpu().numpy()
+ >>> depth = Image.fromarray(depth.astype("uint8"))
+ ```"""
+ loss = None
+ if labels is not None:
+ raise NotImplementedError("Training is not implemented yet")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ depth_pro_outputs = self.depth_pro(
+ pixel_values=pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+ features = depth_pro_outputs.features
+ fused_hidden_states = self.fusion_stage(features)
+ predicted_depth = self.head(fused_hidden_states[-1])
+
+ fov = (
+ self.fov_model(
+ pixel_values=pixel_values,
+ # frozen features from the encoder are used
+ global_features=features[0].detach(),
+ head_mask=head_mask,
+ )
+ if self.use_fov_model
+ else None
+ )
+
+ if not return_dict:
+ outputs = [loss, predicted_depth, fov, depth_pro_outputs.hidden_states, depth_pro_outputs.attentions]
+ return tuple(v for v in outputs if v is not None)
+
+ return DepthProDepthEstimatorOutput(
+ loss=loss,
+ predicted_depth=predicted_depth,
+ fov=fov,
+ hidden_states=depth_pro_outputs.hidden_states,
+ attentions=depth_pro_outputs.attentions,
+ )
+
+
+__all__ = ["DepthProPreTrainedModel", "DepthProModel", "DepthProForDepthEstimation"]
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index e3463461ea07e5..419bb07665e858 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -3551,6 +3551,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class DepthProForDepthEstimation(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class DepthProModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class DepthProPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class DetrForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index 3ebda4404aae9c..2b3e46e63cc55b 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -184,6 +184,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class DepthProImageProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
+class DepthProImageProcessorFast(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class DetrFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
diff --git a/tests/models/depth_pro/__init__.py b/tests/models/depth_pro/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/depth_pro/test_image_processing_depth_pro.py b/tests/models/depth_pro/test_image_processing_depth_pro.py
new file mode 100644
index 00000000000000..e9d94151e145ec
--- /dev/null
+++ b/tests/models/depth_pro/test_image_processing_depth_pro.py
@@ -0,0 +1,113 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+from transformers.file_utils import is_vision_available
+from transformers.testing_utils import require_torch, require_vision
+
+from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
+
+
+if is_vision_available():
+ from transformers import DepthProImageProcessor, DepthProImageProcessorFast
+
+
+class DepthProImageProcessingTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=None,
+ do_normalize=True,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ ):
+ super().__init__()
+ size = size if size is not None else {"height": 18, "width": 18}
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def prepare_image_processor_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "size": self.size,
+ }
+
+ def expected_output_image_shape(self, images):
+ return self.num_channels, self.size["height"], self.size["width"]
+
+ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
+ return prepare_image_inputs(
+ batch_size=self.batch_size,
+ num_channels=self.num_channels,
+ min_resolution=self.min_resolution,
+ max_resolution=self.max_resolution,
+ equal_resolution=equal_resolution,
+ numpify=numpify,
+ torchify=torchify,
+ )
+
+
+@require_torch
+@require_vision
+class DepthProImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
+ image_processing_class = DepthProImageProcessor if is_vision_available() else None
+ fast_image_processing_class = DepthProImageProcessorFast if is_vision_available() else None
+
+ def setUp(self):
+ super().setUp()
+ self.image_processor_tester = DepthProImageProcessingTester(self)
+
+ @property
+ def image_processor_dict(self):
+ return self.image_processor_tester.prepare_image_processor_dict()
+
+ def test_image_processor_properties(self):
+ image_processing = self.image_processing_class(**self.image_processor_dict)
+ self.assertTrue(hasattr(image_processing, "image_mean"))
+ self.assertTrue(hasattr(image_processing, "image_std"))
+ self.assertTrue(hasattr(image_processing, "do_normalize"))
+ self.assertTrue(hasattr(image_processing, "do_resize"))
+ self.assertTrue(hasattr(image_processing, "size"))
+ self.assertTrue(hasattr(image_processing, "do_rescale"))
+ self.assertTrue(hasattr(image_processing, "rescale_factor"))
+ self.assertTrue(hasattr(image_processing, "resample"))
+ self.assertTrue(hasattr(image_processing, "antialias"))
+
+ def test_image_processor_from_dict_with_kwargs(self):
+ image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
+ self.assertEqual(image_processor.size, {"height": 18, "width": 18})
+
+ image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
+ self.assertEqual(image_processor.size, {"height": 42, "width": 42})
diff --git a/tests/models/depth_pro/test_modeling_depth_pro.py b/tests/models/depth_pro/test_modeling_depth_pro.py
new file mode 100644
index 00000000000000..ad17476c664dcf
--- /dev/null
+++ b/tests/models/depth_pro/test_modeling_depth_pro.py
@@ -0,0 +1,375 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch DepthPro model."""
+
+import unittest
+
+from transformers import DepthProConfig
+from transformers.file_utils import is_torch_available, is_vision_available
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import DepthProForDepthEstimation, DepthProModel
+ from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import DepthProImageProcessor
+
+
+class DepthProModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=8,
+ image_size=64,
+ patch_size=8,
+ patch_embeddings_size=4,
+ num_channels=3,
+ is_training=True,
+ use_labels=True,
+ hidden_size=32,
+ fusion_hidden_size=16,
+ intermediate_hook_ids=[0],
+ intermediate_feature_dims=[8],
+ scaled_images_ratios=[0.5, 1.0],
+ scaled_images_overlap_ratios=[0.0, 0.2],
+ scaled_images_feature_dims=[12, 12],
+ num_hidden_layers=1,
+ num_attention_heads=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ initializer_range=0.02,
+ use_fov_model=False,
+ num_labels=3,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.patch_embeddings_size = patch_embeddings_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_size = hidden_size
+ self.fusion_hidden_size = fusion_hidden_size
+ self.intermediate_hook_ids = intermediate_hook_ids
+ self.intermediate_feature_dims = intermediate_feature_dims
+ self.scaled_images_ratios = scaled_images_ratios
+ self.scaled_images_overlap_ratios = scaled_images_overlap_ratios
+ self.scaled_images_feature_dims = scaled_images_feature_dims
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.use_fov_model = use_fov_model
+ self.num_labels = num_labels
+
+ self.num_patches = (patch_size // patch_embeddings_size) ** 2
+ self.seq_length = (patch_size // patch_embeddings_size) ** 2 + 1 # we add 1 for the [CLS] token
+
+ n_fusion_blocks = len(intermediate_hook_ids) + len(scaled_images_ratios)
+ self.expected_depth_size = 2 ** (n_fusion_blocks + 1) * patch_size // patch_embeddings_size
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return DepthProConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ patch_embeddings_size=self.patch_embeddings_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ fusion_hidden_size=self.fusion_hidden_size,
+ intermediate_hook_ids=self.intermediate_hook_ids,
+ intermediate_feature_dims=self.intermediate_feature_dims,
+ scaled_images_ratios=self.scaled_images_ratios,
+ scaled_images_overlap_ratios=self.scaled_images_overlap_ratios,
+ scaled_images_feature_dims=self.scaled_images_feature_dims,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ use_fov_model=self.use_fov_model,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = DepthProModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ num_patches = result.last_hidden_state.shape[1] # num_patches is determined dynamically from the image scales
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, num_patches, self.seq_length, self.hidden_size)
+ )
+
+ def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = DepthProForDepthEstimation(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.predicted_depth.shape, (self.batch_size, self.expected_depth_size, self.expected_depth_size)
+ )
+
+ def create_and_check_for_fov(self, config, pixel_values, labels):
+ model = DepthProForDepthEstimation(config, use_fov_model=True)
+ model.to(torch_device)
+ model.eval()
+
+ # check if the fov_model (DinoV2-based encoder) is created
+ self.parent.assertIsNotNone(model.fov_model)
+
+ batched_pixel_values = pixel_values
+ row_pixel_values = pixel_values[:1]
+
+ with torch.no_grad():
+ model_batched_output_fov = model(batched_pixel_values).fov
+ model_row_output_fov = model(row_pixel_values).fov
+
+ # check if fov is returned
+ self.parent.assertIsNotNone(model_batched_output_fov)
+ self.parent.assertIsNotNone(model_row_output_fov)
+
+ # check output shape consistency for fov
+ self.parent.assertEqual(model_batched_output_fov.shape, (self.batch_size,))
+
+ # check equivalence between batched and single row outputs for fov
+ diff = torch.max(torch.abs(model_row_output_fov - model_batched_output_fov[:1]))
+ model_name = model.__class__.__name__
+ self.parent.assertTrue(
+ diff <= 1e-03,
+ msg=(f"Batched and Single row outputs are not equal in {model_name} for fov. " f"Difference={diff}."),
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class DepthProModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as DepthPro does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (DepthProModel, DepthProForDepthEstimation) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "depth-estimation": DepthProForDepthEstimation,
+ "image-feature-extraction": DepthProModel,
+ }
+ if is_torch_available()
+ else {}
+ )
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = DepthProModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=DepthProConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="DepthPro does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_get_set_embeddings(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_depth_estimation(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs)
+
+ def test_for_fov(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_fov(*config_and_inputs)
+
+ def test_training(self):
+ for model_class in self.all_model_classes:
+ if model_class.__name__ == "DepthProForDepthEstimation":
+ continue
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ if model_class.__name__ in MODEL_MAPPING_NAMES.values():
+ continue
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ def test_training_gradient_checkpointing(self):
+ for model_class in self.all_model_classes:
+ if model_class.__name__ == "DepthProForDepthEstimation":
+ continue
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.use_cache = False
+ config.return_dict = True
+
+ if model_class.__name__ in MODEL_MAPPING_NAMES.values() or not model_class.supports_gradient_checkpointing:
+ continue
+ model = model_class(config)
+ model.to(torch_device)
+ model.gradient_checkpointing_enable()
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ @unittest.skip(
+ reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant(self):
+ pass
+
+ @unittest.skip(
+ reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant_false(self):
+ pass
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ # Skip the check for the backbone
+ backbone_params = []
+ for name, module in model.named_modules():
+ if module.__class__.__name__ == "DepthProViTHybridEmbeddings":
+ backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
+ break
+
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ if name in backbone_params:
+ continue
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ @slow
+ def test_model_from_pretrained(self):
+ model_path = "geetu040/DepthPro"
+ model = DepthProModel.from_pretrained(model_path)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+@slow
+class DepthProModelIntegrationTest(unittest.TestCase):
+ def test_inference_depth_estimation(self):
+ model_path = "geetu040/DepthPro"
+ image_processor = DepthProImageProcessor.from_pretrained(model_path)
+ model = DepthProForDepthEstimation.from_pretrained(model_path).to(torch_device)
+ config = model.config
+
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+ predicted_depth = outputs.predicted_depth
+
+ # verify the predicted depth
+ n_fusion_blocks = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
+ expected_depth_size = 2 ** (n_fusion_blocks + 1) * config.patch_size // config.patch_embeddings_size
+ expected_shape = torch.Size((1, expected_depth_size, expected_depth_size))
+ self.assertEqual(predicted_depth.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[1.0582, 1.1225, 1.1335], [1.1154, 1.1398, 1.1486], [1.1434, 1.1500, 1.1643]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
+
+ def test_post_processing_depth_estimation(self):
+ model_path = "geetu040/DepthPro"
+ image_processor = DepthProImageProcessor.from_pretrained(model_path)
+ model = DepthProForDepthEstimation.from_pretrained(model_path)
+
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt")
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ outputs = image_processor.post_process_depth_estimation(
+ outputs,
+ target_sizes=[[image.height, image.width]],
+ )
+ predicted_depth = outputs[0]["predicted_depth"]
+ expected_shape = torch.Size((image.height, image.width))
+ self.assertEqual(predicted_depth.shape, expected_shape)