From 8ee6712f5ef56bc4a3d20f17adea1483ef10580d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 14 Dec 2023 17:51:46 +0000 Subject: [PATCH] add new model-like --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/efficientsam.md | 57 + src/transformers/__init__.py | 28 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/processing_auto.py | 1 + .../models/efficientsam/__init__.py | 65 + .../configuration_efficientsam.py | 311 ++++ ...vert_efficientsam_original_to_hf_format.py | 206 +++ .../efficientsam/modeling_efficientsam.py | 1436 +++++++++++++++++ .../test_modeling_efficientsam.py | 735 +++++++++ 13 files changed, 2848 insertions(+) create mode 100644 docs/source/en/model_doc/efficientsam.md create mode 100644 src/transformers/models/efficientsam/__init__.py create mode 100644 src/transformers/models/efficientsam/configuration_efficientsam.py create mode 100644 src/transformers/models/efficientsam/convert_efficientsam_original_to_hf_format.py create mode 100644 src/transformers/models/efficientsam/modeling_efficientsam.py create mode 100644 tests/models/efficientsam/test_modeling_efficientsam.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 09210a471e3acd..2d1fa495ebbf85 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -685,6 +685,8 @@ title: DePlot - local: model_doc/donut title: Donut + - local: model_doc/efficientsam + title: EfficientSam - local: model_doc/flava title: FLAVA - local: model_doc/git diff --git a/docs/source/en/model_doc/efficientsam.md b/docs/source/en/model_doc/efficientsam.md new file mode 100644 index 00000000000000..895a1f9fafe6eb --- /dev/null +++ b/docs/source/en/model_doc/efficientsam.md @@ -0,0 +1,57 @@ + + +# EfficientSam + +## Overview + +The EfficientSam model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). 
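+
+## Usage example
+
+A minimal sketch of point-prompted mask prediction, assuming the `efficientsam/efficientsam-ti` checkpoint and reusing `SamProcessor` for pre- and post-processing, as the conversion script does; a dedicated `EfficientSamProcessor` may replace it once available.
+
+```python
+import requests
+import torch
+from PIL import Image
+
+from transformers import EfficientSamModel, SamProcessor
+
+model = EfficientSamModel.from_pretrained("efficientsam/efficientsam-ti")
+processor = SamProcessor.from_pretrained("efficientsam/efficientsam-ti")
+
+url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
+image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+input_points = [[[400, 650]]]  # a single 2D point prompt
+
+inputs = processor(images=image, input_points=input_points, return_tensors="pt")
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# The model returns low-resolution masks and IoU scores; resize the masks back to the original image.
+masks = processor.image_processor.post_process_masks(
+    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
+)
+scores = outputs.iou_scores
+```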
+ + +## EfficientSamConfig + +[[autodoc]] EfficientSamConfig + +## EfficientSamVisionConfig + +[[autodoc]] EfficientSamVisionConfig + +## EfficientSamMaskDecoderConfig + +[[autodoc]] EfficientSamMaskDecoderConfig + +## EfficientSamPromptEncoderConfig + +[[autodoc]] EfficientSamPromptEncoderConfig + + +## EfficientSamModel + +[[autodoc]] EfficientSamModel + - forward + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 98139511d239c5..7b9e7dbeed1edb 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -743,6 +743,14 @@ "SamPromptEncoderConfig", "SamVisionConfig", ], + "models.efficientsam": [ + "EFFICIENTSAM_PRETRAINED_CONFIG_ARCHIVE_MAP", + "EfficientSamConfig", + "EfficientSamMaskDecoderConfig", + + "EfficientSamPromptEncoderConfig", + "EfficientSamVisionConfig", + ], "models.seamless_m4t": [ "SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP", "SeamlessM4TConfig", @@ -3085,6 +3093,13 @@ "SamPreTrainedModel", ] ) + _import_structure["models.efficientsam"].extend( + [ + "EFFICIENTSAM_PRETRAINED_MODEL_ARCHIVE_LIST", + "EfficientSamModel", + "EfficientSamPreTrainedModel", + ] + ) _import_structure["models.seamless_m4t"].extend( [ "SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -5404,6 +5419,14 @@ SamPromptEncoderConfig, SamVisionConfig, ) + from .models.efficientsam import ( + EFFICIENTSAM_PRETRAINED_CONFIG_ARCHIVE_MAP, + EfficientSamConfig, + EfficientSamMaskDecoderConfig, + + EfficientSamPromptEncoderConfig, + EfficientSamVisionConfig, + ) from .models.seamless_m4t import ( SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP, SeamlessM4TConfig, @@ -7433,6 +7456,11 @@ SamModel, SamPreTrainedModel, ) + from .models.efficientsam import ( + EFFICIENTSAM_PRETRAINED_MODEL_ARCHIVE_LIST, + EfficientSamModel, + EfficientSamPreTrainedModel, + ) # PyTorch model imports from .models.seamless_m4t import ( diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 319c8499319a3f..02826b5cef2eb5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -187,6 +187,7 @@ roformer, rwkv, sam, + efficientsam, seamless_m4t, seamless_m4t_v2, segformer, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index b91226ac877897..c22d14657e0fdd 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -194,6 +194,7 @@ ("roformer", "RoFormerConfig"), ("rwkv", "RwkvConfig"), ("sam", "SamConfig"), + ("efficientsam", "EfficientSamConfig"), ("seamless_m4t", "SeamlessM4TConfig"), ("seamless_m4t_v2", "SeamlessM4Tv2Config"), ("segformer", "SegformerConfig"), @@ -412,6 +413,7 @@ ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("rwkv", "RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sam", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("efficientsam", "EFFICIENTSAM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("seamless_m4t", "SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("seamless_m4t_v2", "SEAMLESS_M4T_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -656,6 +658,7 @@ ("roformer", "RoFormer"), ("rwkv", "RWKV"), ("sam", "SAM"), + ("efficientsam", "EfficientEfficientSam"), ("seamless_m4t", "SeamlessM4T"), ("seamless_m4t_v2", "SeamlessM4Tv2"), ("segformer", "SegFormer"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 446c9adf1b6dc3..1b645f30c01264 100644 --- 
a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -96,6 +96,7 @@ ("regnet", "ConvNextImageProcessor"), ("resnet", "ConvNextImageProcessor"), ("sam", "SamImageProcessor"), + ("efficientsam", "EfficientSamImageProcessor"), ("segformer", "SegformerImageProcessor"), ("swiftformer", "ViTImageProcessor"), ("swin", "ViTImageProcessor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e562bd28bdb3f3..57d5234b8ee304 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -187,6 +187,7 @@ ("roformer", "RoFormerModel"), ("rwkv", "RwkvModel"), ("sam", "SamModel"), + ("efficientsam", "EfficientSamModel"), ("seamless_m4t", "SeamlessM4TModel"), ("seamless_m4t_v2", "SeamlessM4Tv2Model"), ("segformer", "SegformerModel"), @@ -1124,6 +1125,7 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( [ ("sam", "SamModel"), + ("efficientsam", "EfficientSamModel"), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 93dc6ab6050bb9..a4a5c82eb1d78b 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -74,6 +74,7 @@ ("pix2struct", "Pix2StructProcessor"), ("pop2piano", "Pop2PianoProcessor"), ("sam", "SamProcessor"), + ("efficientsam", "EfficientSamProcessor"), ("seamless_m4t", "SeamlessM4TProcessor"), ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), diff --git a/src/transformers/models/efficientsam/__init__.py b/src/transformers/models/efficientsam/__init__.py new file mode 100644 index 00000000000000..78989fd7d074cd --- /dev/null +++ b/src/transformers/models/efficientsam/__init__.py @@ -0,0 +1,65 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
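+
+# This init follows the library's lazy-import pattern: names declared in `_import_structure` are only
+# imported on first attribute access through `_LazyModule`, and the torch-dependent modeling module is
+# guarded by `is_torch_available()`.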
+from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_efficientsam": [ + "EFFICIENTSAM_PRETRAINED_CONFIG_ARCHIVE_MAP", + "EfficientSamConfig", + "EfficientSamMaskDecoderConfig", + "EfficientSamPromptEncoderConfig", + "EfficientSamVisionConfig", + ], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_efficientsam"] = [ + "EFFICIENTSAM_PRETRAINED_MODEL_ARCHIVE_LIST", + "EfficientSamModel", + "EfficientSamPreTrainedModel", + ] +if TYPE_CHECKING: + from .configuration_efficientsam import ( + EFFICIENTSAM_PRETRAINED_CONFIG_ARCHIVE_MAP, + EfficientSamConfig, + EfficientSamMaskDecoderConfig, + EfficientSamPromptEncoderConfig, + EfficientSamVisionConfig, + ) + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_efficientsam import EFFICIENTSAM_PRETRAINED_MODEL_ARCHIVE_LIST, EfficientSamModel, EfficientSamPreTrainedModel + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/efficientsam/configuration_efficientsam.py b/src/transformers/models/efficientsam/configuration_efficientsam.py new file mode 100644 index 00000000000000..f2270f28b5998d --- /dev/null +++ b/src/transformers/models/efficientsam/configuration_efficientsam.py @@ -0,0 +1,311 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" EFFICIENTSAM model configuration""" + + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +EFFICIENTSAM_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "efficientsam/efficientsam-ti": "https://huggingface.co/efficientsam/efficientsam-ti/resolve/main/config.json", +} + + + +class EfficientSamPromptEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EfficientSamPromptEncoder`]. The [`EfficientSamPromptEncoder`] + module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield + a similar configuration to that of the EFFICIENTSAM-vit-h + [efficientsam/efficientsam-ti](https://huggingface.co/efficientsam/efficientsam-ti) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + image_size (`int`, *optional*, defaults to 1024): + The expected output resolution of the image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. 
+ mask_input_channels (`int`, *optional*, defaults to 16): + The number of channels to be fed to the `MaskDecoder` module. + num_point_embeddings (`int`, *optional*, defaults to 4): + The number of point embeddings to be used. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the encoder and pooler. + """ + + def __init__( + self, + hidden_size=256, + image_size=1024, + patch_size=16, + mask_input_channels=16, + num_point_embeddings=4, + hidden_act="gelu", + layer_norm_eps=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.image_embedding_size = image_size // patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + + +class EfficientSamMaskDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EfficientSamMaskDecoder`]. It is used to instantiate a EFFICIENTSAM + mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults + will yield a similar configuration to that of the EFFICIENTSAM-vit-h + [efficientsam/efficientsam-ti](https://huggingface.co/efficientsam/efficientsam-ti) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function used inside the `EfficientSamMaskDecoder` module. + mlp_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 2): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + attention_downefficientsample_rate (`int`, *optional*, defaults to 2): + The downefficientsampling rate of the attention layer. + num_multimask_outputs (`int`, *optional*, defaults to 3): + The number of outputs from the `EfficientSamMaskDecoder` module. In the Segment Anything paper, this is set to 3. + iou_head_depth (`int`, *optional*, defaults to 3): + The number of layers in the IoU head module. + iou_head_hidden_dim (`int`, *optional*, defaults to 256): + The dimensionality of the hidden states in the IoU head module. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. 
+ + """ + + def __init__( + self, + hidden_size=256, + hidden_act="relu", + mlp_dim=2048, + num_hidden_layers=2, + num_attention_heads=8, + attention_downefficientsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=256, + layer_norm_eps=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downefficientsample_rate = attention_downefficientsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + + +class EfficientSamVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EfficientSamVisionModel`]. It is used to instantiate a EFFICIENTSAM + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of the EFFICIENTSAM ViT-h + [efficientsam/efficientsam-ti](https://huggingface.co/efficientsam/efficientsam-ti) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + output_channels (`int`, *optional*, defaults to 256): + Dimensionality of the output channels in the Patch Encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + image_size (`int`, *optional*, defaults to 1024): + Expected resolution. Target size of the resized input image. + patch_size (`int`, *optional*, defaults to 16): + Size of the patches to be extracted from the input image. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to query, key, value projections. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of mlp hidden dim to embedding dim. + use_abs_pos (`bool`, *optional*, defaults to `True`): + Whether to use absolute position embedding. + use_rel_pos (`bool`, *optional*, defaults to `True`): + Whether to use relative position embedding. + window_size (`int`, *optional*, defaults to 14): + Window size for relative position. + global_attn_indexes (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + The indexes of the global attention layers. + num_pos_feats (`int`, *optional*, defaults to 128): + The dimensionality of the position embedding. 
+ mlp_dim (`int`, *optional*): + The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio * + hidden_size`. + """ + + def __init__( + self, + hidden_size=768, + output_channels=256, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-06, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=128, + mlp_dim=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim + + +class EfficientSamConfig(PretrainedConfig): + r""" + [`EfficientSamConfig`] is the configuration class to store the configuration of a [`EfficientSamModel`]. It is used to instantiate a + EFFICIENTSAM model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the + EFFICIENTSAM-ViT-H [efficientsam/efficientsam-ti](https://huggingface.co/efficientsam/efficientsam-ti) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (Union[`dict`, `EfficientSamVisionConfig`], *optional*): + Dictionary of configuration options used to initialize [`EfficientSamVisionConfig`]. + prompt_encoder_config (Union[`dict`, `EfficientSamPromptEncoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`EfficientSamPromptEncoderConfig`]. + mask_decoder_config (Union[`dict`, `EfficientSamMaskDecoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`EfficientSamMaskDecoderConfig`]. + + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ( + ... EfficientSamVisionConfig, + ... EfficientSamPromptEncoderConfig, + ... EfficientSamMaskDecoderConfig, + ... EfficientSamModel, + ... 
) + + >>> # Initializing a EfficientSamConfig with `"efficientsam/efficientsam-ti"` style configuration + >>> configuration = EfficientSamConfig() + + >>> # Initializing a EfficientSamModel (with random weights) from the `"efficientsam/efficientsam-ti"` style configuration + >>> model = EfficientSamModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a EfficientSamConfig from a EfficientSamVisionConfig, EfficientSamPromptEncoderConfig, and EfficientSamMaskDecoderConfig + + >>> # Initializing EFFICIENTSAM vision, EFFICIENTSAM Q-Former and language model configurations + >>> vision_config = EfficientSamVisionConfig() + >>> prompt_encoder_config = EfficientSamPromptEncoderConfig() + >>> mask_decoder_config = EfficientSamMaskDecoderConfig() + + >>> config = EfficientSamConfig(vision_config, prompt_encoder_config, mask_decoder_config) + ```""" + + model_type = "efficientsam" + + def __init__( + self, + vision_config=None, + prompt_encoder_config=None, + mask_decoder_config=None, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + vision_config = vision_config if vision_config is not None else {} + prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {} + mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {} + + if isinstance(vision_config, EfficientSamVisionConfig): + vision_config = vision_config.to_dict() + if isinstance(prompt_encoder_config, EfficientSamPromptEncoderConfig): + prompt_encoder_config = prompt_encoder_config.to_dict() + if isinstance(mask_decoder_config, EfficientSamMaskDecoderConfig): + mask_decoder_config = mask_decoder_config.to_dict() + + self.vision_config = EfficientSamVisionConfig(**vision_config) + self.prompt_encoder_config = EfficientSamPromptEncoderConfig(**prompt_encoder_config) + self.mask_decoder_config = EfficientSamMaskDecoderConfig(**mask_decoder_config) + self.initializer_range = initializer_range diff --git a/src/transformers/models/efficientsam/convert_efficientsam_original_to_hf_format.py b/src/transformers/models/efficientsam/convert_efficientsam_original_to_hf_format.py new file mode 100644 index 00000000000000..bd4ea70daa3fcf --- /dev/null +++ b/src/transformers/models/efficientsam/convert_efficientsam_original_to_hf_format.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert EFFICIENTSAM checkpoints from the original repository. 
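+
+Example invocation (the flags below are the ones defined by the argparse parser at the bottom of this file;
+the output path is a placeholder):
+
+    python convert_efficientsam_original_to_hf_format.py \
+        --model_name efficientsam_vit_h_4b8939 \
+        --pytorch_dump_folder_path /path/to/output \
+        --push_to_hub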
+""" +import argparse +import re + +import numpy as np +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + EfficientSamConfig, + SamImageProcessor, + EfficientSamModel, + SamProcessor, + EfficientSamVisionConfig, +) + + +KEYS_TO_MODIFY_MAPPING = { + "iou_prediction_head.layers.0": "iou_prediction_head.proj_in", + "iou_prediction_head.layers.1": "iou_prediction_head.layers.0", + "iou_prediction_head.layers.2": "iou_prediction_head.proj_out", + "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1", + "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm", + "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2", + "mask_downscaling.0": "mask_embed.conv1", + "mask_downscaling.1": "mask_embed.layer_norm1", + "mask_downscaling.3": "mask_embed.conv2", + "mask_downscaling.4": "mask_embed.layer_norm2", + "mask_downscaling.6": "mask_embed.conv3", + "point_embeddings": "point_embed", + "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding", + "image_encoder": "vision_encoder", + "neck.0": "neck.conv1", + "neck.1": "neck.layer_norm1", + "neck.2": "neck.conv2", + "neck.3": "neck.layer_norm2", + "patch_embed.proj": "patch_embed.projection", + ".norm": ".layer_norm", + "blocks": "layers", +} + + +def replace_keys(state_dict): + model_state_dict = {} + state_dict.pop("pixel_mean", None) + state_dict.pop("pixel_std", None) + + output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*" + + for key, value in state_dict.items(): + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + if re.match(output_hypernetworks_mlps_pattern, key): + layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "layers.0") + elif layer_nb == 2: + key = key.replace("layers.2", "proj_out") + + model_state_dict[key] = value + + model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[ + "prompt_encoder.shared_embedding.positional_embedding" + ] + + return model_state_dict + + +def convert_efficientsam_checkpoint(model_name, pytorch_dump_folder, push_to_hub, model_hub_id="ybelkada/segment-anything"): + checkpoint_path = hf_hub_download(model_hub_id, f"checkpoints/{model_name}.pth") + + if "efficientsam_vit_b" in model_name: + config = EfficientSamConfig() + elif "efficientsam_vit_l" in model_name: + vision_config = EfficientSamVisionConfig( + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + global_attn_indexes=[5, 11, 17, 23], + ) + + config = EfficientSamConfig( + vision_config=vision_config, + ) + elif "efficientsam_vit_h" in model_name: + vision_config = EfficientSamVisionConfig( + hidden_size=1280, + num_hidden_layers=32, + num_attention_heads=16, + global_attn_indexes=[7, 15, 23, 31], + ) + + config = EfficientSamConfig( + vision_config=vision_config, + ) + + state_dict = torch.load(checkpoint_path, map_location="cpu") + state_dict = replace_keys(state_dict) + + image_processor = SamImageProcessor() + + processor = SamProcessor(image_processor=image_processor) + hf_model = EfficientSamModel(config) + + hf_model.load_state_dict(state_dict) + hf_model = hf_model.to("cuda") + + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = 
Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + + input_points = [[[400, 650]]] + input_labels = [[1]] + + inputs = processor(images=np.array(raw_image), return_tensors="pt").to("cuda") + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + if model_name == "efficientsam_vit_h_4b8939": + assert scores[-1].item() == 0.579890251159668 + + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to("cuda") + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + assert scores[-1].item() == 0.9712603092193604 + + input_boxes = ((75, 275, 1725, 850),) + + inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to("cuda") + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + assert scores[-1].item() == 0.8686015605926514 + + # Test with 2 points and 1 image. + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to("cuda") + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + assert scores[-1].item() == 0.9936047792434692 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + choices = ["efficientsam_vit_b_01ec64", "efficientsam_vit_h_4b8939", "efficientsam_vit_l_0b3195"] + parser.add_argument( + "--model_name", + default="efficientsam_vit_h_4b8939", + choices=choices, + type=str, + help="Path to hf config.json of model to convert", + ) + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model and processor to the hub after converting", + ) + parser.add_argument( + "--model_hub_id", + default="ybelkada/segment-anything", + choices=choices, + type=str, + help="Path to hf config.json of model to convert", + ) + + args = parser.parse_args() + + convert_efficientsam_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.model_hub_id) diff --git a/src/transformers/models/efficientsam/modeling_efficientsam.py b/src/transformers/models/efficientsam/modeling_efficientsam.py new file mode 100644 index 00000000000000..8d6104ea31a365 --- /dev/null +++ b/src/transformers/models/efficientsam/modeling_efficientsam.py @@ -0,0 +1,1436 @@ +# coding=utf-8 +# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
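+
+# Architecture note: EfficientSam mirrors SAM: a ViT-style vision encoder, a prompt encoder for point,
+# box and mask prompts, and a two-way-transformer mask decoder. Most classes below carry
+# "Copied from transformers.models.sam.modeling_sam" annotations and only rename Sam -> EfficientSam.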
+""" PyTorch EFFICIENTSAM model.""" + +import collections +import math +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_efficientsam import EfficientSamConfig, EfficientSamMaskDecoderConfig, EfficientSamPromptEncoderConfig, EfficientSamVisionConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "EfficientSamConfig" +_CHECKPOINT_FOR_DOC = "efficientsam/efficientsam-ti" + +EFFICIENTSAM_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "efficientsam/efficientsam-ti", + # See all EfficientSam models at https://huggingface.co/models?filter=efficientsam +] + + + +@dataclass +# Copied from transformers.models.sam.modeling_sam.SamVisionEncoderOutput with Sam->EfficientSam,sam->efficientsam +class EfficientSamVisionEncoderOutput(ModelOutput): + """ + Base class for efficientsam vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.sam.modeling_sam.SamImageSegmentationOutput with Sam->EfficientSam +class EfficientSamImageSegmentationOutput(ModelOutput): + """ + Base class for Segment-Anything model's output + + Args: + iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`): + The iou scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`): + The predicted low resolutions masks. 
Needs to be post-processed by the processor + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + iou_scores: torch.FloatTensor = None + pred_masks: torch.FloatTensor = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor]] = None + mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.sam.modeling_sam.SamPatchEmbeddings with Sam->EfficientSam +class EfficientSamPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." 
+ ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +# Copied from transformers.models.sam.modeling_sam.SamMLPBlock with Sam->EfficientSam +class EfficientSamMLPBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) + self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->EfficientSam +class EfficientSamLayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + self.normalized_shape = (normalized_shape,) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.data_format == "channels_last": + x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + input_dtype = x.dtype + x = x.float() + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = x.to(dtype=input_dtype) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +# Copied from transformers.models.sam.modeling_sam.SamAttention with SAM->EFFICIENTSAM,Sam->EfficientSam,sam->efficientsam +class EfficientSamAttention(nn.Module): + """ + EFFICIENTSAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. 
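+
+    The projected queries, keys and values have dimension `hidden_size // downefficientsample_rate`,
+    where the rate defaults to `config.attention_downefficientsample_rate` and the resulting internal
+    dimension must be divisible by `num_attention_heads`.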
+ """ + + def __init__(self, config, downefficientsample_rate=None): + super().__init__() + self.hidden_size = config.hidden_size + + downefficientsample_rate = config.attention_downefficientsample_rate if downefficientsample_rate is None else downefficientsample_rate + + self.internal_dim = config.hidden_size // downefficientsample_rate + self.num_attention_heads = config.num_attention_heads + if self.internal_dim % config.num_attention_heads != 0: + raise ValueError("num_attention_heads must divide hidden_size.") + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor: + batch, point_batch_size, n_tokens, channel = hidden_states.shape + c_per_head = channel // num_attention_heads + hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + return hidden_states.transpose(1, 2) + + def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: + batch, n_heads, n_tokens, c_per_head = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2) + return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) + + def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # EfficientSamAttention + _, _, _, c_per_head = query.shape + attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + if attention_similarity is not None: + attn = attn + attention_similarity + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ value + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + +# Copied from transformers.models.sam.modeling_sam.SamTwoWayAttentionBlock with Sam->EfficientSam,sam->efficientsam +class EfficientSamTwoWayAttentionBlock(nn.Module): + def __init__(self, config, attention_downefficientsample_rate: int = 2, skip_first_layer_pe: bool = False): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`EfficientSamMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downefficientsample_rate (*optionalk*, int, defaults to 2): + The downefficientsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. 
+ """ + super().__init__() + + self.hidden_size = config.hidden_size + self.layer_norm_eps = config.layer_norm_eps + + self.self_attn = EfficientSamAttention(config, downefficientsample_rate=1) + self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + + self.cross_attn_token_to_image = EfficientSamAttention(config, downefficientsample_rate=attention_downefficientsample_rate) + self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + + self.mlp = EfficientSamMLPBlock(config) + self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + + self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + self.cross_attn_image_to_token = EfficientSamAttention(config, downefficientsample_rate=attention_downefficientsample_rate) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, + queries: Tensor, + keys: Tensor, + query_point_embedding: Tensor, + key_point_embedding: Tensor, + attention_similarity: Tensor, + output_attentions: bool = False, + ): + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_token_to_image( + query=query, key=key, value=keys, attention_similarity=attention_similarity + ) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + + outputs = (queries, keys) + + if output_attentions: + outputs = outputs + (attn_out,) + else: + outputs = outputs + (None,) + + return outputs + + +# Copied from transformers.models.sam.modeling_sam.SamTwoWayTransformer with Sam->EfficientSam +class EfficientSamTwoWayTransformer(nn.Module): + def __init__(self, config: EfficientSamMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append(EfficientSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = EfficientSamAttention(config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_attentions = () + + if 
image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, attention_outputs = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (attention_outputs,) + + # Apply the final attenion layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys, all_attentions + + +# Copied from transformers.models.sam.modeling_sam.SamFeedForward with Sam->EfficientSam +class EfficientSamFeedForward(nn.Module): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False + ): + super().__init__() + self.num_layers = num_layers + self.activation = nn.ReLU() + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +# Copied from transformers.models.sam.modeling_sam.SamMaskDecoder with Sam->EfficientSam +class EfficientSamMaskDecoder(nn.Module): + def __init__(self, config: EfficientSamMaskDecoderConfig): + super().__init__() + + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = EfficientSamTwoWayTransformer(config) + + # should we create a new class for this? 
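+        # Upscaling stage: two stride-2 transposed convolutions upsample the image embeddings 4x,
+        # with a channels-first layer norm and GELU activations in between, before the per-mask-token
+        # hypernetwork MLPs below predict the masks.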
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = EfficientSamLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [EfficientSamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + + self.iou_prediction_head = EfficientSamFeedForward( + self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + output_attentions: Optional[bool] = None, + attention_similarity: torch.Tensor = None, + target_embedding: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + the embeddings from the image encoder + image_positional_embedding (`torch.Tensor`): + positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes + dense_prompt_embeddings (`torch.Tensor`): + the embeddings of the mask inputs + multimask_output (bool): + Whether to return multiple masks or a single mask. + output_attentions (bool, *optional*): + Whether or not to return the attentions tensors of all attention layers. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.sum().item() != 0: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-point + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + + # Run the transformer, image_positional_embedding are consumed + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).reshape( + batch_size * point_batch_size, num_channels, height, width + ) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + hyper_in_list = [] + for i in 
range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + outputs = (masks, iou_pred) + + if output_attentions: + outputs = outputs + (attentions,) + else: + outputs = outputs + (None,) + + return outputs + + +# Copied from transformers.models.sam.modeling_sam.SamPositionalEmbedding with Sam->EfficientSam +class EfficientSamPositionalEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.scale = config.hidden_size // 2 + self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats))) + + def forward(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... 
x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +# Copied from transformers.models.sam.modeling_sam.SamMaskEmbedding with Sam->EfficientSam +class EfficientSamMaskEmbedding(nn.Module): + def __init__(self, config: EfficientSamPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = EfficientSamLayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = EfficientSamLayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +# Copied from transformers.models.sam.modeling_sam.SamPromptEncoder with Sam->EfficientSam +class EfficientSamPromptEncoder(nn.Module): + def __init__(self, config: EfficientSamPromptEncoderConfig, shared_patch_embedding): + super().__init__() + self.shared_embedding = shared_patch_embedding + self.mask_embed = EfficientSamMaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) + self.input_image_size = config.image_size + + self.point_embed = nn.ModuleList( + [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)] + ) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1]) + target_labels_shape = (points.shape[0], points.shape[1], 1) + padding_point = torch.zeros(target_point_shape, device=points.device) + padding_label = -torch.ones(target_labels_shape, device=labels.device) + points = torch.cat([points, padding_point], dim=2) + labels = torch.cat([labels, padding_label], dim=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. 
The dtype, device need to be explicitely + # specificed as otherwise torch.onnx.export interprets as double + point_embedding = torch.where( + labels[..., None] != -10, + point_embedding, + torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device), + ) + + point_embedding = torch.where( + (labels == 0)[:, :, :, None], + point_embedding + self.point_embed[0].weight[None, None, :, :], + point_embedding, + ) + + point_embedding = torch.where( + (labels == 1)[:, :, :, None], + point_embedding + self.point_embed[1].weight[None, None, :, :], + point_embedding, + ) + + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + batch_size, nb_boxes = boxes.shape[:2] + coords = boxes.reshape(batch_size, nb_boxes, 2, 2) + input_shape = (self.input_image_size, self.input_image_size) + corner_embedding = self.shared_embedding(coords, input_shape) + corner_embedding[:, :, 0, :] += self.point_embed[2].weight + corner_embedding[:, :, 1, :] += self.point_embed[3].weight + return corner_embedding + + def forward( + self, + input_points: Optional[Tuple[torch.Tensor, torch.Tensor]], + input_labels: Optional[torch.Tensor], + input_boxes: Optional[torch.Tensor], + input_masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`torch.Tensor`, *optional*): + point coordinates and labels to embed. + boxes (`torch.Tensor`, *optional*): + boxes to embed + masks (`torch.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + target_device = self.shared_embedding.positional_embedding.device + if input_points is not None: + batch_size, point_batch_size = input_points.shape[:2] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + if sparse_embeddings is None: + sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device) + + return sparse_embeddings, dense_embeddings + + +# Copied from transformers.models.sam.modeling_sam.SamVisionAttention with Sam->EfficientSam +class EfficientSamVisionAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size): + super().__init__() + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, 
config.hidden_size) + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def add_decomposed_rel_pos( + self, + attn: torch.Tensor, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`torch.Tensor`): + attention map. + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`torch.Tensor`): + attention map with added relative positional embeddings. 
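        The decomposed relative position trick above is mostly shape bookkeeping. The toy-sized sketch below
        (hypothetical dimensions, not tied to any EfficientSam config, and skipping the `q_size != k_size`
        rescaling since the query and key grids match) traces how `rel_h` and `rel_w` are built and broadcast
        into the attention map:

        ```python
        import torch

        # Hypothetical toy sizes: a 4x4 query/key grid, 8-dim heads, batch of 1.
        batch_size, q_h, q_w, k_h, k_w, dim = 1, 4, 4, 4, 4, 8
        query = torch.randn(batch_size, q_h * q_w, dim)
        rel_pos_h = torch.randn(2 * k_h - 1, dim)  # one embedding per possible height offset
        rel_pos_w = torch.randn(2 * k_w - 1, dim)  # one embedding per possible width offset
        attn = torch.zeros(batch_size, q_h * q_w, k_h * k_w)

        # get_rel_pos: index a (q_size, k_size, dim) table out of the (2 * size - 1, dim) embeddings
        rel_h_table = rel_pos_h[torch.arange(q_h)[:, None] - torch.arange(k_h)[None, :] + k_h - 1]
        rel_w_table = rel_pos_w[torch.arange(q_w)[:, None] - torch.arange(k_w)[None, :] + k_w - 1]

        # add_decomposed_rel_pos: project the query onto each table, then broadcast-add into the attention map
        reshaped_query = query.reshape(batch_size, q_h, q_w, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, rel_h_table)  # (1, q_h, q_w, k_h)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, rel_w_table)  # (1, q_h, q_w, k_w)

        attn = attn.reshape(batch_size, q_h, q_w, k_h, k_w)
        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        print(attn.reshape(batch_size, q_h * q_w, k_h * k_w).shape)  # torch.Size([1, 16, 16])
        ```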
+ """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width) + attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width) + return attn + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + +# Copied from transformers.models.sam.modeling_sam.SamVisionLayer with Sam->EfficientSam +class EfficientSamVisionLayer(nn.Module): + def __init__(self, config, window_size): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = EfficientSamVisionAttention(config, window_size) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = EfficientSamMLPBlock(config) + self.window_size = window_size + + def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Args: + Partition into non-overlapping windows with padding if needed. + hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window + size. + + Returns: + windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel]. 
+ (pad_height, pad_width): padded height and width before partition + """ + batch_size, height, width, channel = hidden_states.shape + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = hidden_states.reshape( + batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel + ) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int] + ) -> torch.Tensor: + """ + Args: + Window unpartition into original sequences and removing padding. + hidden_states (tensor): + input tokens with [batch_size * num_windows, window_size, window_size, channel]. + window_size (int): + window size. + padding_shape (Tuple): + padded height and width (pad_height, pad_width). + original_shape (Tuple): original height and width (height, width) before padding. + + Returns: + hidden_states: unpartitioned sequences with [batch_size, height, width, channel]. + """ + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = windows.reshape( + batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1 + ) + hidden_states = ( + hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) + ) + + hidden_states = hidden_states[:, :height, :width, :].contiguous() + return hidden_states + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + # Reverse window partition + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.sam.modeling_sam.SamVisionNeck with Sam->EfficientSam +class EfficientSamVisionNeck(nn.Module): + def __init__(self, config: EfficientSamVisionConfig): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) + self.layer_norm1 = EfficientSamLayerNorm(config.output_channels, data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) + self.layer_norm2 = EfficientSamLayerNorm(config.output_channels, data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = 
self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + + +# Copied from transformers.models.sam.modeling_sam.SamVisionEncoder with Sam->EfficientSam +class EfficientSamVisionEncoder(nn.Module): + def __init__(self, config: EfficientSamVisionConfig): + super().__init__() + self.config = config + self.image_size = config.image_size + + self.patch_embed = EfficientSamPatchEmbeddings(config) + + self.pos_embed = None + if config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros( + 1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + ) + ) + + self.layers = nn.ModuleList() + for i in range(config.num_hidden_layers): + layer = EfficientSamVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = EfficientSamVisionNeck(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.patch_embed + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, EfficientSamVisionEncoderOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + ) + else: + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_self_attentions,) + return outputs + + return EfficientSamVisionEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.sam.modeling_sam.SamPreTrainedModel with Sam->EfficientSam,sam->efficientsam +class EfficientSamPreTrainedModel(PreTrainedModel): + config_class = EfficientSamConfig + base_model_prefix = "efficientsam" + main_input_name = "pixel_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + 
module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +EFFICIENTSAM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`EfficientSamConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +EFFICIENTSAM_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for + details. + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + EFFICIENTSAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding that will be fed later on to the mask decoder. These masks need to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
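    As a usage sketch of the prompt formats described above (nested Python lists that the processor converts
    into batched tensors; the checkpoint name, image URL and coordinates simply mirror the examples used
    elsewhere in this patch):

    ```python
    import requests
    from PIL import Image
    from transformers import SamProcessor

    processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base")

    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    input_points = [[[820, 1080]]]              # 1 image, 1 prompt point, given as (x, y)
    input_labels = [[1]]                        # 1 = the point lies on the object of interest
    input_boxes = [[[650, 900, 1000, 1250]]]    # 1 image, 1 box, given as (x1, y1, x2, y2)

    inputs = processor(
        images=raw_image,
        input_points=input_points,
        input_labels=input_labels,
        input_boxes=input_boxes,
        return_tensors="pt",
    )
    # `inputs` now holds `pixel_values` plus padded point/label/box tensors ready for the model.
    ```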
+""" + + +@add_start_docstrings( + "Segment Anything Model (EFFICIENTSAM) for generating segmentation masks, given an input image and ", + " optional 2D location and bounding boxes.", + EFFICIENTSAM_START_DOCSTRING, +) +# Copied from transformers.models.sam.modeling_sam.SamModel with SAM->EFFICIENTSAM,Sam->EfficientSam,sam->efficientsam +class EfficientSamModel(EfficientSamPreTrainedModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + + def __init__(self, config): + super().__init__(config) + self.shared_image_embedding = EfficientSamPositionalEmbedding(config.vision_config) + + self.vision_encoder = EfficientSamVisionEncoder(config.vision_config) + self.prompt_encoder = EfficientSamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding) + self.mask_decoder = EfficientSamMaskDecoder(config.mask_decoder_config) + + self.post_init() + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones((size, size), device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + """ + vision_output = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_output[0] + return image_embeddings + + @torch.no_grad() + def get_prompt_embeddings( + self, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. 
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + @add_start_docstrings_to_model_forward(EFFICIENTSAM_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> List[Dict[str, torch.Tensor]]: + r""" + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoModel, AutoProcessor + + >>> model = AutoModel.from_pretrained("facebook/efficientsam-vit-base") + >>> processor = AutoProcessor.from_pretrained("facebook/efficientsam-vit-base") + + >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/efficientsam-car.png" + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + >>> input_points = [[[400, 650]]] # 2D location of a window on the car + >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt") + + >>> # Get segmentation mask + >>> outputs = model(**inputs) + + >>> # Postprocess masks + >>> masks = processor.post_process_masks( + ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] + ... ) + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.", + " got {}.".format(input_points.shape), + ) + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_points must be a 3D tensor. 
Of shape `batch_size`, `nb_boxes`, `4`.", + " got {}.".format(input_boxes.shape), + ) + if input_points is not None and input_boxes is not None: + point_batch_size = input_points.shape[1] + box_batch_size = input_boxes.shape[1] + if point_batch_size != box_batch_size: + raise ValueError( + "You should provide as many bounding boxes as input points per box. Got {} and {}.".format( + point_batch_size, box_batch_size + ) + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_outputs[0] + + if output_hidden_states: + vision_hidden_states = vision_outputs[1] + if output_attentions: + vision_attentions = vision_outputs[-1] + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]: + raise ValueError( + "The batch size of the image embeddings and the input points must be the same. ", + "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]), + " if you want to pass multiple points for the same image, make sure that you passed ", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) and ", + " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + ) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + + if not return_dict: + output = (iou_predictions, low_res_masks) + if output_hidden_states: + output = output + (vision_hidden_states,) + + if output_attentions: + output = output + (vision_attentions, mask_decoder_attentions) + return output + + return EfficientSamImageSegmentationOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + mask_decoder_attentions=mask_decoder_attentions, + ) diff --git a/tests/models/efficientsam/test_modeling_efficientsam.py b/tests/models/efficientsam/test_modeling_efficientsam.py new file mode 100644 index 00000000000000..c39596f5552d77 --- /dev/null +++ b/tests/models/efficientsam/test_modeling_efficientsam.py @@ -0,0 +1,735 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch EFFICIENTSAM model. """ + + +import gc +import unittest + +import requests + +from transformers import EfficientSamConfig, EfficientSamMaskDecoderConfig, EfficientSamPromptEncoderConfig, EfficientSamVisionConfig, pipeline +from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device +from transformers.utils import is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import EfficientSamModel, SamProcessor + from transformers.models.efficientsam.modeling_efficientsam import EFFICIENTSAM_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + +class EfficientSamPromptEncoderTester: + def __init__( + self, + hidden_size=32, + input_image_size=24, + patch_size=2, + mask_input_channels=4, + num_point_embeddings=4, + hidden_act="gelu", + ): + self.hidden_size = hidden_size + self.input_image_size = input_image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + + def get_config(self): + return EfficientSamPromptEncoderConfig( + image_size=self.input_image_size, + patch_size=self.patch_size, + mask_input_channels=self.mask_input_channels, + hidden_size=self.hidden_size, + num_point_embeddings=self.num_point_embeddings, + hidden_act=self.hidden_act, + ) + + def prepare_config_and_inputs(self): + dummy_points = floats_tensor([self.batch_size, 3, 2]) + config = self.get_config() + + return config, dummy_points + + +class EfficientSamMaskDecoderTester: + def __init__( + self, + hidden_size=32, + hidden_act="relu", + mlp_dim=64, + num_hidden_layers=2, + num_attention_heads=4, + attention_downefficientsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=32, + layer_norm_eps=1e-6, + ): + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downefficientsample_rate = attention_downefficientsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + + def get_config(self): + return EfficientSamMaskDecoderConfig( + hidden_size=self.hidden_size, + hidden_act=self.hidden_act, + mlp_dim=self.mlp_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + attention_downefficientsample_rate=self.attention_downefficientsample_rate, + num_multimask_outputs=self.num_multimask_outputs, + iou_head_depth=self.iou_head_depth, + iou_head_hidden_dim=self.iou_head_hidden_dim, + layer_norm_eps=self.layer_norm_eps, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + + dummy_inputs = 
{ + "image_embedding": floats_tensor([self.batch_size, self.hidden_size]), + } + + return config, dummy_inputs + + +class EfficientSamModelTester: + def __init__( + self, + parent, + hidden_size=36, + intermediate_size=72, + projection_dim=62, + output_channels=32, + num_hidden_layers=2, + num_attention_heads=4, + num_channels=3, + image_size=24, + patch_size=2, + hidden_act="gelu", + layer_norm_eps=1e-06, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + rel_pos_zero_init=False, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=16, + mlp_dim=None, + batch_size=2, + ): + self.parent = parent + self.image_size = image_size + self.patch_size = patch_size + self.output_channels = output_channels + self.num_channels = num_channels + self.hidden_size = hidden_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.rel_pos_zero_init = rel_pos_zero_init + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = mlp_dim + self.batch_size = batch_size + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + self.prompt_encoder_tester = EfficientSamPromptEncoderTester() + self.mask_decoder_tester = EfficientSamMaskDecoderTester() + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + vision_config = EfficientSamVisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + projection_dim=self.projection_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + initializer_factor=self.initializer_factor, + output_channels=self.output_channels, + qkv_bias=self.qkv_bias, + mlp_ratio=self.mlp_ratio, + use_abs_pos=self.use_abs_pos, + use_rel_pos=self.use_rel_pos, + rel_pos_zero_init=self.rel_pos_zero_init, + window_size=self.window_size, + global_attn_indexes=self.global_attn_indexes, + num_pos_feats=self.num_pos_feats, + mlp_dim=self.mlp_dim, + ) + + prompt_encoder_config = self.prompt_encoder_tester.get_config() + + mask_decoder_config = self.mask_decoder_tester.get_config() + + return EfficientSamConfig( + vision_config=vision_config, + prompt_encoder_config=prompt_encoder_config, + mask_decoder_config=mask_decoder_config, + ) + + def create_and_check_model(self, config, pixel_values): + model = EfficientSamModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + 
self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3)) + self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3)) + + def create_and_check_get_image_features(self, config, pixel_values): + model = EfficientSamModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model.get_image_embeddings(pixel_values) + self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12)) + + def create_and_check_get_image_hidden_states(self, config, pixel_values): + model = EfficientSamModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=True, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) + + with torch.no_grad(): + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=False, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class EfficientSamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as EFFICIENTSAM's vision encoder does not use input_ids, inputs_embeds, + attention_mask and seq_length. 
+ """ + + all_model_classes = (EfficientSamModel,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_torchscript = False + + @unittest.skip(reason="EFFICIENTSAM's vision encoder does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_get_image_features(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_features(*config_and_inputs) + + def test_image_hidden_states(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + expected_vision_attention_shape = ( + self.model_tester.batch_size * self.model_tester.num_attention_heads, + 196, + 196, + ) + expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + self.assertListEqual( + list(vision_attentions[0].shape[-4:]), + list(expected_vision_attention_shape), + ) + + self.assertListEqual( + list(mask_decoder_attentions[0].shape[-4:]), + list(expected_mask_decoder_attention_shape), + ) + + @unittest.skip(reason="EfficientSamModel does not support training") + def test_training(self): + pass + + @unittest.skip(reason="EfficientSamModel does not support training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This 
architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="EfficientSamModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="EfficientSamModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="EfficientSamModel does not support training") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests") + def test_hidden_states_output(self): + pass + + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): + # Use a slightly higher default tol to make the tests non-flaky + super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes) + + @slow + def test_model_from_pretrained(self): + for model_name in EFFICIENTSAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = EfficientSamModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +def prepare_image(): + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_dog_img(): + img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-efficientsam.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +@slow +class EfficientSamModelIntegrationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + backend_empty_cache(torch_device) + + def test_inference_mask_generation_no_point(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + inputs = processor(images=raw_image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4)) + self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4)) + + def test_inference_mask_generation_one_point_one_bb(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_boxes = [[[650, 900, 1000, 1250]]] + input_points = [[[820, 1080]]] + + inputs = processor( + images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4)) + self.assertTrue( + torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4) + ) + + def 
test_inference_mask_generation_batched_points_batched_images(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_points = [ + [[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + [[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + ] + + inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze().cpu() + masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu() + + EXPECTED_SCORES = torch.tensor( + [ + [ + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + ], + [ + [0.3317, 0.7264, 0.7646], + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + ], + ] + ) + EXPECTED_MASKS = torch.tensor([-2.8550, -2.7988, -2.9625]) + self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) + self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3)) + + def test_inference_mask_generation_one_point_one_bb_zero(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_boxes = [[[620, 900, 1000, 1255]]] + input_points = [[[820, 1080]]] + labels = [[0]] + + inputs = processor( + images=raw_image, + input_boxes=input_boxes, + input_points=input_points, + input_labels=labels, + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4)) + + def test_inference_mask_generation_one_point(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650]]] + input_labels = [[1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4)) + + # With no label + input_points = [[[400, 650]]] + + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4)) + + def test_inference_mask_generation_two_points(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + 
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4)) + + # no labels + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4)) + + def test_inference_mask_generation_two_points_batched(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]], [[400, 650]]] + input_labels = [[1, 1], [1]] + + inputs = processor( + images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4)) + self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4)) + + def test_inference_mask_generation_one_box(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_boxes = [[[75, 275, 1725, 850]]] + + inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4)) + + def test_inference_mask_generation_batched_image_one_point(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + raw_dog_image = prepare_dog_img() + + input_points = [[[820, 1080]], [[220, 470]]] + + inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = model(**inputs) + scores_batched = outputs.iou_scores.squeeze() + + input_points = [[[220, 470]]] + + inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores_single = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4)) + + def test_inference_mask_generation_two_points_point_batch(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu() # fmt: skip + + input_points = input_points.unsqueeze(0) + + inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 2, 3)) + torch.testing.assert_allclose( + iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4 + ) + + def 
test_inference_mask_generation_three_boxes_point_batch(self): + model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base") + processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + # fmt: off + input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu() + EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522], + [0.5996, 0.7661, 0.7937], + [0.5996, 0.7661, 0.7937]]]) + # fmt: on + input_boxes = input_boxes.unsqueeze(0) + + inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 3, 3)) + torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) + + def test_dummy_pipeline_generation(self): + generator = pipeline("mask-generation", model="facebook/efficientsam-vit-base", device=torch_device) + raw_image = prepare_image() + + _ = generator(raw_image, points_per_batch=64)
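Finally, a minimal sketch of the memory-saving flow mentioned in the inputs docstring, assuming the `facebook/efficientsam-vit-base` checkpoint referenced throughout this patch: compute the image embeddings once with `get_image_embeddings`, then pass them via `image_embeddings` instead of `pixel_values` for each new prompt.

```python
import requests
import torch
from PIL import Image
from transformers import EfficientSamModel, SamProcessor

model = EfficientSamModel.from_pretrained("facebook/efficientsam-vit-base").eval()
processor = SamProcessor.from_pretrained("facebook/efficientsam-vit-base")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

inputs = processor(images=raw_image, input_points=[[[400, 650]]], return_tensors="pt")

with torch.no_grad():
    # Run the (expensive) vision encoder once ...
    image_embeddings = model.get_image_embeddings(inputs.pop("pixel_values"))
    # ... and reuse the cached embeddings for this (and any further) prompt.
    outputs = model(image_embeddings=image_embeddings, **inputs)

# Resize the low-resolution mask logits back to the original image size.
masks = processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
```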