diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0d2b752d5ad93f..97b64ebdb31b8b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -902,6 +902,8 @@ title: Qwen2VL - local: model_doc/sam title: Segment Anything + - local: model_doc/sam_hq + title: Segment Anything High Quality - local: model_doc/siglip title: SigLIP - local: model_doc/speech-encoder-decoder diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 8a9ccf45b69c26..6ed17c19a5cfce 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -290,6 +290,7 @@ Flax), PyTorch, and/or TensorFlow. | [RT-DETR-ResNet](model_doc/rt_detr_resnet) | ✅ | ❌ | ❌ | | [RWKV](model_doc/rwkv) | ✅ | ❌ | ❌ | | [SAM](model_doc/sam) | ✅ | ✅ | ❌ | +| [SAM_HQ](model_doc/sam_hq) | ✅ | ❌ | ❌ | | [SeamlessM4T](model_doc/seamless_m4t) | ✅ | ❌ | ❌ | | [SeamlessM4Tv2](model_doc/seamless_m4t_v2) | ✅ | ❌ | ❌ | | [SegFormer](model_doc/segformer) | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/sam_hq.md b/docs/source/en/model_doc/sam_hq.md new file mode 100644 index 00000000000000..e77072292b2620 --- /dev/null +++ b/docs/source/en/model_doc/sam_hq.md @@ -0,0 +1,122 @@ +# SAM-HQ + +## Overview + +SAM-HQ (High-Quality Segment Anything Model) was proposed in [Segment Anything in High Quality](https://arxiv.org/pdf/2306.01567.pdf) by Lei Ke, Mingqiao Ye, Martin Danelljan, Yifan Liu, Yu-Wing Tai, Chi-Keung Tang, Fisher Yu. + +The model is an enhancement to the original SAM model that produces significantly higher quality segmentation masks while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability. + +![example image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-output.png) + +The abstract from the paper is the following: + +*The recent Segment Anything Model (SAM) represents a big leap in scaling up segmentation models, allowing for powerful zero-shot capabilities and flexible prompting. Despite being trained with 1.1 billion masks, SAM's mask prediction quality falls short in many cases, particularly when dealing with objects that have intricate structures. We propose HQ-SAM, equipping SAM with the ability to accurately segment any object, while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability. Our careful design reuses and preserves the pre-trained model weights of SAM, while only introducing minimal additional parameters and computation. We design a learnable High-Quality Output Token, which is injected into SAM's mask decoder and is responsible for predicting the high-quality mask. Instead of only applying it on mask-decoder features, we first fuse them with early and final ViT features for improved mask details. To train our introduced learnable parameters, we compose a dataset of 44K fine-grained masks from several sources. 
HQ-SAM is only trained on the introduced dataset of 44k masks, which takes only 4 hours on 8 GPUs.* + +Tips: + +- SAM-HQ produces higher quality masks than the original SAM model, particularly for objects with intricate structures and fine details +- The model predicts binary masks with more accurate boundaries and better handling of thin structures +- Like SAM, the model performs better with input 2D points and/or input bounding boxes +- You can prompt multiple points for the same image and predict a single high-quality mask +- The model maintains SAM's zero-shot generalization capabilities +- SAM-HQ only adds ~0.5% additional parameters compared to SAM +- Fine-tuning the model is not supported yet + +This model was contributed by [sushmanth](https://huggingface.co/sushmanth). +The original code can be found [here](https://github.com/SysCV/SAM-HQ). + +Below is an example on how to run mask generation given an image and a 2D point: + +```python +import torch +from PIL import Image +import requests +from transformers import SamHQModel, SamHQProcessor + +device = "cuda" if torch.cuda.is_available() else "cpu" +model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device) +processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + +img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" +raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") +input_points = [[[450, 600]]] # 2D location of a window in the image + +inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device) +with torch.no_grad(): + outputs = model(**inputs) + +masks = processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() +) +scores = outputs.iou_scores +``` + +You can also process your own masks alongside the input images in the processor to be passed to the model: + +```python +import torch +from PIL import Image +import requests +from transformers import SamHQModel, SamHQProcessor + +device = "cuda" if torch.cuda.is_available() else "cpu" +model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device) +processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + +img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" +raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") +mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" +segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("1") +input_points = [[[450, 600]]] # 2D location of a window in the image + +inputs = processor(raw_image, input_points=input_points, segmentation_maps=segmentation_map, return_tensors="pt").to(device) +with torch.no_grad(): + outputs = model(**inputs) + +masks = processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() +) +scores = outputs.iou_scores +``` + +## Key Differences from SAM + +SAM-HQ introduces several key improvements over the original SAM model: + +1. High-Quality Output Token: A learnable token injected into SAM's mask decoder for higher quality mask prediction +2. Global-local Feature Fusion: Combines features from different stages of the model for improved mask details +3. Training Data: Uses a carefully curated dataset of 44K high-quality masks instead of SA-1B +4. 
Efficiency: Adds only 0.5% additional parameters while significantly improving mask quality +5. Zero-shot Capability: Maintains SAM's strong zero-shot performance while improving accuracy + +## Resources + +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SAM-HQ: + +- Demo notebook for using the model (coming soon) +- Paper implementation and code: [SAM-HQ GitHub Repository](https://github.com/SysCV/SAM-HQ) + +## SamHQConfig + +[[autodoc]] SamHQConfig + +## SamHQVisionConfig + +[[autodoc]] SamHQVisionConfig + +## SamHQMaskDecoderConfig + +[[autodoc]] SamHQMaskDecoderConfig + +## SamHQPromptEncoderConfig + +[[autodoc]] SamHQPromptEncoderConfig + +## SamHQProcessor + +[[autodoc]] SamHQProcessor + +## SamHQModel + +[[autodoc]] SamHQModel + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e1ca1956807318..7387a6d4e1cb5f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -723,6 +723,13 @@ "SamPromptEncoderConfig", "SamVisionConfig", ], + "models.sam_hq": [ + "SamHQConfig", + "SamHQMaskDecoderConfig", + "SamHQProcessor", + "SamHQVisionConfig", + "SamHQPromptEncoderConfig", + ], "models.seamless_m4t": [ "SeamlessM4TConfig", "SeamlessM4TFeatureExtractor", @@ -1233,6 +1240,7 @@ _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"]) _import_structure["models.sam"].extend(["SamImageProcessor"]) + _import_structure["models.sam"].extend(["SamImageProcessor"]) _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) _import_structure["models.seggpt"].extend(["SegGptImageProcessor"]) _import_structure["models.siglip"].append("SiglipImageProcessor") @@ -3295,6 +3303,12 @@ "SamPreTrainedModel", ] ) + _import_structure["models.sam_hq"].extend( + [ + "SamHQModel", + "SamHQPreTrainedModel", + ] + ) _import_structure["models.seamless_m4t"].extend( [ "SeamlessM4TCodeHifiGan", @@ -5638,6 +5652,13 @@ SamPromptEncoderConfig, SamVisionConfig, ) + from .models.sam_hq import ( + SamHQConfig, + SamHQProcessor, + SamHQMaskDecoderConfig, + SamHQPromptEncoderConfig, + SamHQVisionConfig, + ) from .models.seamless_m4t import ( SeamlessM4TConfig, SeamlessM4TFeatureExtractor, @@ -6167,6 +6188,7 @@ from .models.qwen2_vl import Qwen2VLImageProcessor from .models.rt_detr import RTDetrImageProcessor from .models.sam import SamImageProcessor + from .models.sam import SamImageProcessor from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor from .models.seggpt import SegGptImageProcessor from .models.siglip import SiglipImageProcessor @@ -7827,6 +7849,10 @@ SamModel, SamPreTrainedModel, ) + from .models.sam_hq import ( + SamHQModel, + SamHQPreTrainedModel, + ) from .models.seamless_m4t import ( SeamlessM4TCodeHifiGan, SeamlessM4TForSpeechToSpeech, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 2d2a3b41d4378b..b78114df77c45d 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -221,6 +221,7 @@ rt_detr, rwkv, sam, + sam_hq, seamless_m4t, seamless_m4t_v2, segformer, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 4ab6d392282657..50e4a71766d7f6 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -244,6 
+244,7 @@ ("rt_detr_resnet", "RTDetrResNetConfig"), ("rwkv", "RwkvConfig"), ("sam", "SamConfig"), + ("sam_hq", "SamHQConfig"), ("seamless_m4t", "SeamlessM4TConfig"), ("seamless_m4t_v2", "SeamlessM4Tv2Config"), ("segformer", "SegformerConfig"), @@ -561,6 +562,7 @@ ("rt_detr_resnet", "RT-DETR-ResNet"), ("rwkv", "RWKV"), ("sam", "SAM"), + ("sam_hq", "SAM-HQ"), ("seamless_m4t", "SeamlessM4T"), ("seamless_m4t_v2", "SeamlessM4Tv2"), ("segformer", "SegFormer"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 11ae15ca461e79..8b68ba05ff6078 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -126,6 +126,7 @@ ("resnet", ("ConvNextImageProcessor",)), ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), ("sam", ("SamImageProcessor",)), + ("sam_hq", ("SamImageProcessor",)), ("segformer", ("SegformerImageProcessor",)), ("seggpt", ("SegGptImageProcessor",)), ("siglip", ("SiglipImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2c519a7dc42ca5..d680e28c1edea9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -226,6 +226,7 @@ ("rt_detr", "RTDetrModel"), ("rwkv", "RwkvModel"), ("sam", "SamModel"), + ("sam_hq", "SamHQModel"), ("seamless_m4t", "SeamlessM4TModel"), ("seamless_m4t_v2", "SeamlessM4Tv2Model"), ("segformer", "SegformerModel"), @@ -1367,6 +1368,12 @@ ] ) +MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( + [ + ("sam_hq", "SamHQModel"), + ] +) + MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict( [ diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c1f23bc1cb3f18..3b53efbfbd8435 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -89,6 +89,7 @@ ("qwen2_audio", "Qwen2AudioProcessor"), ("qwen2_vl", "Qwen2VLProcessor"), ("sam", "SamProcessor"), + ("sam-hq", "SamHQProcessor"), ("seamless_m4t", "SeamlessM4TProcessor"), ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), diff --git a/src/transformers/models/sam_hq/__init__.py b/src/transformers/models/sam_hq/__init__.py new file mode 100644 index 00000000000000..516aef0219a498 --- /dev/null +++ b/src/transformers/models/sam_hq/__init__.py @@ -0,0 +1,74 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_vision_available, +) + + +__import_structure = { + "configuration_sam_hq": [ + "SamHQConfig", + "SamHQMaskDecoderConfig", + "SamHQPromptEncoderConfig", + "SamHQVisionConfig", + ], + "processing_samhq": ["SamHQProcessor"], +} + + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + __import_structure["modeling_sam_hq"] = [ + "SamHQModel", + "SamHQPreTrainedModel", + ] + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + __import_structure["image_processing_sam"] = ["SamImageProcessor"] + + +if TYPE_CHECKING: + from .configuration_sam_hq import ( + SamHQConfig, + SamHQMaskDecoderConfig, + SamHQPromptEncoderConfig, + SamHQVisionConfig, + ) + from .processing_samhq import SamHQProcessor + + + try: + if not is_torch_available(): + raise 
OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_sam_hq import SamHQModel, SamHQPreTrainedModel
+
+    try:
+        if not is_vision_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from ..sam.image_processing_sam import SamImageProcessor
+
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], __import_structure, module_spec=__spec__)
\ No newline at end of file
diff --git a/src/transformers/models/sam_hq/configuration_sam_hq.py b/src/transformers/models/sam_hq/configuration_sam_hq.py
new file mode 100644
index 00000000000000..d701493ff3ecda
--- /dev/null
+++ b/src/transformers/models/sam_hq/configuration_sam_hq.py
@@ -0,0 +1,303 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""SAM-HQ configuration."""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+# Copied from transformers.models.sam.configuration_sam.SamPromptEncoderConfig with Sam->SamHQ
+class SamHQPromptEncoderConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`SamHQPromptEncoderModel`]. The
+    [`SamHQPromptEncoderModel`] module is used to encode the input 2D points and bounding boxes. Instantiating a
+    configuration with defaults will yield a similar configuration to that of the SAM-HQ
+    [Uminosachi/sam-hq](https://huggingface.co/Uminosachi/sam-hq) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 256):
+            Dimensionality of the hidden states.
+        image_size (`int`, *optional*, defaults to 1024):
+            The expected output resolution of the image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        mask_input_channels (`int`, *optional*, defaults to 16):
+            The number of channels to be fed to the `MaskDecoder` module.
+        num_point_embeddings (`int`, *optional*, defaults to 4):
+            The number of point embeddings to be used.
+        hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function in the encoder and pooler.
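+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.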
+ """ + + def __init__( + self, + hidden_size=256, + image_size=1024, + patch_size=16, + mask_input_channels=16, + num_point_embeddings=4, + hidden_act="gelu", + layer_norm_eps=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.image_embeddings_size = image_size // patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + +# Copied from transformer.models.sam.configuration_sam.SamVisionConfig with Sam->SamHQ +class SamHQVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamHQVisionModel`]. It is used to instantiate a + SAM HQ vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with defaults will yield a configuration similar to that of the SAM HQ Vision model. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + output_channels (`int`, *optional*, defaults to 256): + Dimensionality of the output channels in the Patch Encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + image_size (`int`, *optional*, defaults to 1024): + Expected resolution. Target size of the resized input image. + patch_size (`int`, *optional*, defaults to 16): + Size of the patches to be extracted from the input image. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string). + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to query, key, value projections. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of mlp hidden dim to embedding dim. + use_abs_pos (`bool`, *optional*, defaults to `True`): + Whether to use absolute position embedding. + use_rel_pos (`bool`, *optional*, defaults to `True`): + Whether to use relative position embedding. + window_size (`int`, *optional*, defaults to 14): + Window size for relative position. + global_attn_indexes (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + The indexes of the global attention layers. + num_pos_feats (`int`, *optional*, defaults to 128): + The dimensionality of the position embedding. + mlp_dim (`int`, *optional*): + The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio * + hidden_size`. 
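+
+    Example (a minimal instantiation sketch; the ViT-L values below mirror the conversion script added in this PR):
+
+    ```python
+    >>> from transformers import SamHQVisionConfig
+
+    >>> # Default (ViT-B sized) vision encoder configuration
+    >>> configuration = SamHQVisionConfig()
+
+    >>> # A ViT-L sized variant
+    >>> configuration_large = SamHQVisionConfig(
+    ...     hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, global_attn_indexes=[5, 11, 17, 23]
+    ... )
+    ```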
+ """ + + def __init__( + self, + hidden_size=768, + output_channels=256, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-6, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=128, + mlp_dim=None, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim + + +class SamHQMaskDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamHQMaskDecoder`]. It is used to instantiate a + SAM HQ mask decoder with the specified arguments, defining the model architecture. Instantiating a configuration + with defaults will yield a configuration similar to that of the SAM HQ variant. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function used inside the `SamHQMaskDecoder` module. + mlp_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 2): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + attention_downsample_rate (`int`, *optional*, defaults to 2): + The downsampling rate of the attention layer. + num_multimask_outputs (`int`, *optional*, defaults to 3): + The number of outputs from the `SamHQMaskDecoder` module. In the Segment Anything paper, this is set to 3. + iou_head_depth (`int`, *optional*, defaults to 3): + The number of layers in the IoU head module. + iou_head_hidden_dim (`int`, *optional*, defaults to 256): + The dimensionality of the hidden states in the IoU head module. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + vit_dim (`int`, *optional*, defaults to 768): + Dimensionality of the Vision Transformer (ViT) used in the `SamHQMaskDecoder` module. 
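+            This is expected to match the `hidden_size` of the accompanying vision encoder (768, 1024 and 1280 for the
+            ViT-B, ViT-L and ViT-H checkpoints handled by the conversion script in this PR), since the HQ mask decoder
+            fuses early and final ViT features.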
+ """ + + def __init__( + self, + hidden_size=256, + hidden_act="relu", + mlp_dim=2048, + num_hidden_layers=2, + num_attention_heads=8, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=256, + layer_norm_eps=1e-6, + vit_dim=768, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downsample_rate = attention_downsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + self.vit_dim = vit_dim + + +# Copied from transformer.models.sam.configuration_sam.SamConfig with Sam->SamHQ +class SamHQConfig(PretrainedConfig): + r""" + [`SamHQConfig`] is the configuration class to store the configuration of a [`SamHQModel`]. It is used to instantiate a + SAM-HQ model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the + SAM-HQ-ViT-H [sushmanth/sam_hq_vit_h](https://huggingface.co/sushmanth/sam_hq_vit_h) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (Union[`dict`, `SamHQVisionConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamHQVisionConfig`]. + prompt_encoder_config (Union[`dict`, `SamHQPromptEncoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamHQPromptEncoderConfig`]. + mask_decoder_config (Union[`dict`, `SamHQMaskDecoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamHQMaskDecoderConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + ```python + >>> from transformers import ( + ... SamHQVisionConfig, + ... SamHQPromptEncoderConfig, + ... SamHQMaskDecoderConfig, + ... SamHQModel, + ... 
)
+
+    >>> # Initializing a SamHQConfig with `"sushmanth/sam_hq_vit_h"` style configuration
+    >>> configuration = SamHQConfig()
+
+    >>> # Initializing a SamHQModel (with random weights) from the `"sushmanth/sam_hq_vit_h"` style configuration
+    >>> model = SamHQModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+
+    >>> # We can also initialize a SamHQConfig from a SamHQVisionConfig, SamHQPromptEncoderConfig, and SamHQMaskDecoderConfig
+    >>> # Initializing SAM-HQ vision, prompt encoder and mask decoder configurations
+    >>> vision_config = SamHQVisionConfig()
+    >>> prompt_encoder_config = SamHQPromptEncoderConfig()
+    >>> mask_decoder_config = SamHQMaskDecoderConfig()
+    >>> config = SamHQConfig(vision_config, prompt_encoder_config, mask_decoder_config)
+    ```"""
+
+    model_type = "sam_hq"
+
+    def __init__(
+        self,
+        vision_config=None,
+        prompt_encoder_config=None,
+        mask_decoder_config=None,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        vision_config = vision_config if vision_config is not None else {}
+        prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+        mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+
+        if isinstance(vision_config, SamHQVisionConfig):
+            vision_config = vision_config.to_dict()
+        if isinstance(prompt_encoder_config, SamHQPromptEncoderConfig):
+            prompt_encoder_config = prompt_encoder_config.to_dict()
+        if isinstance(mask_decoder_config, SamHQMaskDecoderConfig):
+            mask_decoder_config = mask_decoder_config.to_dict()
+
+        self.vision_config = SamHQVisionConfig(**vision_config)
+        self.prompt_encoder_config = SamHQPromptEncoderConfig(**prompt_encoder_config)
+        self.mask_decoder_config = SamHQMaskDecoderConfig(**mask_decoder_config)
+        self.initializer_range = initializer_range
\ No newline at end of file
diff --git a/src/transformers/models/sam_hq/convert_samhq_to_hf.py b/src/transformers/models/sam_hq/convert_samhq_to_hf.py
new file mode 100644
index 00000000000000..eea9ceb7f3a33b
--- /dev/null
+++ b/src/transformers/models/sam_hq/convert_samhq_to_hf.py
@@ -0,0 +1,262 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Convert SAM-HQ checkpoints from the original repository.
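+
+Example (an illustrative invocation; the output folder name is arbitrary, and the checkpoint is downloaded
+automatically from the hub when `--checkpoint_path` is not given):
+
+    python src/transformers/models/sam_hq/convert_samhq_to_hf.py --model_name sam_hq_vit_b --pytorch_dump_folder_path ./sam-hq-vit-b --push_to_hub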
+ +URL: https://github.com/SysCV/sam-hq + +""" + +import argparse +import re + +import numpy as np +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import SamHQConfig, SamHQModel, SamHQVisionConfig, SamHQProcessor, SamImageProcessor + + +def get_config(model_name): + if "sam_hq_vit_b" in model_name: + vision_config = SamHQVisionConfig() + vit_dim = 768 # Base model dimension + elif "sam_hq_vit_l" in model_name: + vision_config = SamHQVisionConfig( + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + global_attn_indexes=[5, 11, 17, 23], + ) + vit_dim = 1024 # Large model dimension + elif "sam_hq_vit_h" in model_name: + vision_config = SamHQVisionConfig( + hidden_size=1280, + num_hidden_layers=32, + num_attention_heads=16, + global_attn_indexes=[7, 15, 23, 31], + ) + vit_dim = 1280 # Huge model dimension + + # Create mask decoder config with appropriate vit_dim + mask_decoder_config = { + "vit_dim": vit_dim + } + + config = SamHQConfig( + vision_config=vision_config, + mask_decoder_config=mask_decoder_config, + ) + + return config + +KEYS_TO_MODIFY_MAPPING = { + "iou_prediction_head.layers.0": "iou_prediction_head.proj_in", + "iou_prediction_head.layers.1": "iou_prediction_head.layers.0", + "iou_prediction_head.layers.2": "iou_prediction_head.proj_out", + "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1", + "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm", + "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2", + "mask_downscaling.0": "mask_embed.conv1", + "mask_downscaling.1": "mask_embed.layer_norm1", + "mask_downscaling.3": "mask_embed.conv2", + "mask_downscaling.4": "mask_embed.layer_norm2", + "mask_downscaling.6": "mask_embed.conv3", + "point_embeddings": "point_embed", + "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding", + "image_encoder": "vision_encoder", + "neck.0": "neck.conv1", + "neck.1": "neck.layer_norm1", + "neck.2": "neck.conv2", + "neck.3": "neck.layer_norm2", + "patch_embed.proj": "patch_embed.projection", + ".norm": ".layer_norm", + "blocks": "layers", + # HQ-specific mappings + "mask_decoder.hf_token": "mask_decoder.hq_token", + "mask_decoder.compress_vit_feat.0": "mask_decoder.compress_vit_conv1", + "mask_decoder.compress_vit_feat.1": "mask_decoder.compress_vit_norm", + "mask_decoder.compress_vit_feat.3": "mask_decoder.compress_vit_conv2", + "mask_decoder.embedding_encoder.0": "mask_decoder.encoder_conv1", + "mask_decoder.embedding_encoder.1": "mask_decoder.encoder_norm", + "mask_decoder.embedding_encoder.3": "mask_decoder.encoder_conv2", + "mask_decoder.embedding_maskfeature.0": "mask_decoder.mask_conv1", + "mask_decoder.embedding_maskfeature.1": "mask_decoder.mask_norm", + "mask_decoder.embedding_maskfeature.3": "mask_decoder.mask_conv2", +} + +def replace_keys(state_dict): + model_state_dict = {} + state_dict.pop("pixel_mean", None) + state_dict.pop("pixel_std", None) + + output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*" + hf_mlp_layers_pattern = r".*hf_mlp.layers.(\d+).*" + + for key, value in state_dict.items(): + new_key = key + for key_to_modify, replacement in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in new_key: + new_key = new_key.replace(key_to_modify, replacement) + + if re.match(output_hypernetworks_mlps_pattern, new_key): + layer_nb = int(re.match(output_hypernetworks_mlps_pattern, new_key).group(2)) + if layer_nb == 0: + new_key = 
new_key.replace("layers.0", "proj_in") + elif layer_nb == 1: + new_key = new_key.replace("layers.1", "layers.0") + elif layer_nb == 2: + new_key = new_key.replace("layers.2", "proj_out") + + # Handle HQ-specific MLP layers + if re.match(hf_mlp_layers_pattern, new_key): + layer_nb = int(re.match(hf_mlp_layers_pattern, new_key).group(1)) + if layer_nb == 0: + new_key = new_key.replace("layers.0", "proj_in") + elif layer_nb == 1: + new_key = new_key.replace("layers.1", "layers.0") + elif layer_nb == 2: + new_key = new_key.replace("layers.2", "proj_out") + + model_state_dict[new_key] = value + + model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[ + "prompt_encoder.shared_embedding.positional_embedding" + ] + + return model_state_dict + +def convert_sam_hq_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub): + config = get_config(model_name) + + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + state_dict = replace_keys(state_dict) + + image_processor = SamImageProcessor() + processor = SamHQProcessor(image_processor=image_processor) + hf_model = SamHQModel(config) + hf_model.eval() + + device = "cuda" if torch.cuda.is_available() else "cpu" + + hf_model.load_state_dict(state_dict) + + hf_model = hf_model.to(device) + + # Test the model with a sample image + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + + input_points = [[[500, 375]]] + input_labels = [[1]] + + # Basic test without prompts + inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device) + + + + with torch.no_grad(): + hf_model(**inputs) + + if model_name == "sam_hq_vit_b": + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + hf_model(**inputs) + + elif model_name == "sam_hq_vit_h": + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + hf_model(**inputs) + + input_boxes = [[[75.0, 275.0, 1725.0, 850.0]]] + + + inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device) + + with torch.no_grad(): + hf_model(**inputs) + + + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + output = hf_model(**inputs) + + + if pytorch_dump_folder is not None: + processor.save_pretrained(pytorch_dump_folder) + hf_model.save_pretrained(pytorch_dump_folder) + + if push_to_hub: + repo_id = f"sushmanth/{model_name}" + processor.push_to_hub(repo_id) + hf_model.push_to_hub(repo_id) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + choices = ["sam_hq_vit_b", "sam_hq_vit_h", "sam_hq_vit_l"] + parser.add_argument( + "--model_name", + choices=choices, + type=str, + required=True, + help="Name of the SAM-HQ model to convert", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=False, + help="Path to the SAM-HQ checkpoint (.pth file)", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + type=str, + default=None, + help="Path to save the converted model", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + 
help="Whether to push the converted model to the hub", + ) + + args = parser.parse_args() + + checkpoint_path = args.checkpoint_path + if checkpoint_path is None: + checkpoint_path = hf_hub_download("Uminosachi/sam-hq", f"{args.model_name}.pth") + + convert_sam_hq_checkpoint( + args.model_name, + checkpoint_path, + args.pytorch_dump_folder_path, + args.push_to_hub, + ) \ No newline at end of file diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py new file mode 100644 index 00000000000000..14e8508b254c81 --- /dev/null +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -0,0 +1,1546 @@ +# coding=utf-8 +# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SAM-HQ model. """ + +import collections +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import Tensor, nn + + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_sam_hq import SamHQConfig, SamHQVisionConfig, SamHQPromptEncoderConfig, SamHQMaskDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SamHQConfig" +_CHECKPOINT_FOR_DOC = "Uminosachi/sam-hq" + + +@dataclass +class SamHQVisionEncoderOutput(ModelOutput): + """ + Base class for SAM-HQ vision model's outputs. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + intermediate_embeddings (`list(torch.FloatTensor)`, *optional*): + A list of intermediate embeddings collected from certain blocks within the model, typically those without + windowed attention. Each element in the list is of shape `(batch_size, sequence_length, hidden_size)`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + intermediate_embeddings: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + +# Copied from transformers.models.sam.modeling_sam.SamImageSegmentationOutput with Sam->SamHQ +@dataclass +class SamHQImageSegmentationOutput(ModelOutput): + """ + Base class for SAM-HQ model's output + + Args: + iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`): + The iou scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`): + The predicted high resolution masks. Needs to be post-processed by the processor. + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + """ + iou_scores: torch.FloatTensor = None + pred_masks: torch.FloatTensor = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + +# Copied from transformers.models.sam.modeling_sam.SamPatchEmbeddings with Sam->SamHQ +class SamHQPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
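+    With the default configuration (`image_size=1024`, `patch_size=16`) this corresponds to a `64 x 64` grid of patch
+    embeddings.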
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size,image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +# Copied from transformers.models.sam.modeling_sam.SamMLPBlock with Sam->SamHQ +class SamHQMLPBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) + self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + +# Copied from transformers.models.sam.modeling_sam.SamVisionAttention with Sam->SamHQ +class SamHQVisionAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size): + super().__init__() + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. 
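+        # Resize the stored relative-position table with linear interpolation so that it covers
+        # 2 * max(q_size, k_size) - 1 positions before it is indexed with the scaled coordinate differences below.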
+ rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def add_decomposed_rel_pos( + self, + attn: torch.Tensor, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`torch.Tensor`): + attention map. + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`torch.Tensor`): + attention map with added relative positional embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width) + attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width) + return attn + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = 
self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + + +# Copied from transformers.models.sam.modeling_sam.SamVisionLayer with Sam->SamHQ +class SamHQVisionLayer(nn.Module): + def __init__(self, config, window_size): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = SamHQVisionAttention(config, window_size) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SamHQMLPBlock(config) + self.window_size = window_size + + + def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Args: + Partition into non-overlapping windows with padding if needed. + hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window + size. + + Returns: + windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel]. + (pad_height, pad_width): padded height and width before partition + """ + batch_size, height, width, channel = hidden_states.shape + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = hidden_states.reshape( + batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel + ) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel) + return windows, (pad_height, pad_width) + + + def window_unpartition( + self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int] + ) -> torch.Tensor: + """ + Args: + Window unpartition into original sequences and removing padding. + hidden_states (tensor): + input tokens with [batch_size * num_windows, window_size, window_size, channel]. + window_size (int): + window size. + padding_shape (Tuple): + padded height and width (pad_height, pad_width). + original_shape (Tuple): original height and width (height, width) before padding. + + Returns: + hidden_states: unpartitioned sequences with [batch_size, height, width, channel]. 
+ """ + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = windows.reshape( + batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1 + ) + hidden_states = ( + hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) + ) + + hidden_states = hidden_states[:, :height, :width, :].contiguous() + return hidden_states + + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + # Reverse window partition + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->SamHQ +class SamHQLayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + self.normalized_shape = (normalized_shape,) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.data_format == "channels_last": + x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + input_dtype = x.dtype + x = x.float() + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = x.to(dtype=input_dtype) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +# Copied from transformers.models.sam.modeling_sam.SamVisionNeck with Sam -> SamHQ +class SamHQVisionNeck(nn.Module): + def __init__(self, config: SamHQVisionConfig): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) + self.layer_norm1 = SamHQLayerNorm(config.output_channels, data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) + self.layer_norm2 = SamHQLayerNorm(config.output_channels, data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + +class SamHQVisionEncoder(nn.Module): + def __init__(self, config: SamHQVisionConfig): + super().__init__() + self.config = config + self.image_size = config.image_size + + self.patch_embed = SamHQPatchEmbeddings(config) + + self.pos_embed = None + + if config.use_abs_pos: + self.pos_embed = nn.Parameter( + torch.zeros(1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + ) + ) + + + self.layers = nn.ModuleList() + for i in range(config.num_hidden_layers): + layer = SamHQVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = SamHQVisionNeck(config) + + self.gradient_checkpointing = False + + + def get_input_embeddings(self): + return self.patch_embed + + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SamHQVisionEncoderOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + 
all_self_attentions = () if output_attentions else None + intermediate_embeddings = [] + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + ) + else: + layer_outputs = layer_module(hidden_states, output_attentions= output_attentions) + + hidden_states = layer_outputs[0] + + # Collect embeddings from non-windowed blocks + if hasattr(layer_module, 'window_size') and layer_module.window_size == 0: + intermediate_embeddings.append(hidden_states) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states, intermediate_embeddings) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_self_attentions,) + return outputs + + return SamHQVisionEncoderOutput( + last_hidden_state=hidden_states, + intermediate_embeddings=intermediate_embeddings, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + +# Copied from transformers.models.sam.modeling_sam.SamMaskEmbedding with Sam->SamHQ +class SamHQMaskEmbedding(nn.Module): + def __init__(self, config: SamHQPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = SamHQLayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = SamHQLayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + +# Copied from transformers.models.sam.modeling_sam.SamPromptEncoder with Sam->SamHQ +class SamHQPromptEncoder(nn.Module): + def __init__(self, config: SamHQPromptEncoderConfig, shared_patch_embedding): + super().__init__() + self.shared_embedding = shared_patch_embedding + self.mask_embed = SamHQMaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_embeddings_size, config.image_embeddings_size) + self.input_image_size = config.image_size + + self.point_embed = nn.ModuleList( + [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)] + ) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + target_point_shape = (points.shape[0], points.shape[1], 
1, points.shape[-1]) + target_labels_shape = (points.shape[0], points.shape[1], 1) + padding_point = torch.zeros(target_point_shape, device=points.device) + padding_label = -torch.ones(target_labels_shape, device=labels.device) + points = torch.cat([points, padding_point], dim=2) + labels = torch.cat([labels, padding_label], dim=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. The dtype, device need to be explicitely + # specificed as otherwise torch.onnx.export interprets as double + point_embedding = torch.where( + labels[..., None] != -10, + point_embedding, + torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device), + ) + + point_embedding = torch.where( + (labels == 0)[:, :, :, None], + point_embedding + self.point_embed[0].weight[None, None, :, :], + point_embedding, + ) + + point_embedding = torch.where( + (labels == 1)[:, :, :, None], + point_embedding + self.point_embed[1].weight[None, None, :, :], + point_embedding, + ) + + return point_embedding + + + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + batch_size, nb_boxes = boxes.shape[:2] + coords = boxes.reshape(batch_size, nb_boxes, 2, 2) + input_shape = (self.input_image_size, self.input_image_size) + corner_embedding = self.shared_embedding(coords, input_shape) + corner_embedding[:, :, 0, :] += self.point_embed[2].weight + corner_embedding[:, :, 1, :] += self.point_embed[3].weight + return corner_embedding + + + def forward( + self, + input_points: Optional[Tuple[torch.Tensor, torch.Tensor]], + input_labels: Optional[torch.Tensor], + input_boxes: Optional[torch.Tensor], + input_masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`torch.Tensor`, *optional*): + point coordinates and labels to embed. 
+ boxes (`torch.Tensor`, *optional*): + boxes to embed + masks (`torch.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + target_device = self.shared_embedding.positional_embedding.device + if input_points is not None: + batch_size, point_batch_size = input_points.shape[:2] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + if sparse_embeddings is None: + sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device) + + return sparse_embeddings, dense_embeddings + +# Copied from transformers.models.sam.modeling_sam.SamAttention with Sam->SamHQ +class SamHQAttention(nn.Module): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. + """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + self.hidden_size = config.hidden_size + + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + if self.internal_dim % config.num_attention_heads != 0: + raise ValueError("num_attention_heads must divide hidden_size.") + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor: + batch, point_batch_size, n_tokens, channel = hidden_states.shape + c_per_head = channel // num_attention_heads + hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + return hidden_states.transpose(1, 2) + + def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: + batch, n_heads, n_tokens, c_per_head = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2) + return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) + + def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # SamAttention + _, _, _, c_per_head = query.shape + attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = attn / (c_per_head**0.5) + attn = 
torch.softmax(attn, dim=-1)
+
+        if attention_similarity is not None:
+            attn = attn + attention_similarity
+            attn = torch.softmax(attn, dim=-1)
+
+        # Get output
+        out = attn @ value
+        out = self._recombine_heads(out, point_batch_size)
+        out = self.out_proj(out)
+
+        return out
+
+# Copied from transformers.models.sam.modeling_sam.SamTwoWayAttentionBlock with Sam->SamHQ
+class SamHQTwoWayAttentionBlock(nn.Module):
+    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
+        """
+        A transformer block with four layers:
+            (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
+            sparse inputs (4) cross attention of dense inputs -> sparse inputs
+
+        Arguments:
+            config (`SamHQMaskDecoderConfig`):
+                The configuration file used to instantiate the block
+            attention_downsample_rate (*optional*, int, defaults to 2):
+                The downsample ratio of the block used to reduce the inner dim of the attention.
+            skip_first_layer_pe (*optional*, bool, defaults to `False`):
+                Whether or not to skip the addition of the query_point_embedding on the first layer.
+        """
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        self.layer_norm_eps = config.layer_norm_eps
+
+        self.self_attn = SamHQAttention(config, downsample_rate=1)
+        self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+        self.cross_attn_token_to_image = SamHQAttention(config, downsample_rate=attention_downsample_rate)
+        self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+        self.mlp = SamHQMLPBlock(config)
+        self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+        self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+        self.cross_attn_image_to_token = SamHQAttention(config, downsample_rate=attention_downsample_rate)
+
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(
+        self,
+        queries: Tensor,
+        keys: Tensor,
+        query_point_embedding: Tensor,
+        key_point_embedding: Tensor,
+        attention_similarity: Tensor,
+        output_attentions: bool = False,
+    ):
+        # Self attention block
+        if self.skip_first_layer_pe:
+            queries = self.self_attn(query=queries, key=queries, value=queries)
+        else:
+            query = queries + query_point_embedding
+            attn_out = self.self_attn(query=query, key=query, value=queries)
+            queries = queries + attn_out
+        queries = self.layer_norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        query = queries + query_point_embedding
+        key = keys + key_point_embedding
+
+        attn_out = self.cross_attn_token_to_image(
+            query=query, key=key, value=keys, attention_similarity=attention_similarity
+        )
+        queries = queries + attn_out
+
+        queries = self.layer_norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.layer_norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        query = queries + query_point_embedding
+        key = keys + key_point_embedding
+
+        attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
+        keys = keys + attn_out
+
+        keys = self.layer_norm4(keys)
+
+        outputs = (queries, keys)
+
+        if output_attentions:
+            outputs = outputs + (attn_out,)
+        else:
+            outputs = outputs + (None,)
+
+        return outputs
+
+
+# Copied from transformers.models.sam.modeling_sam.SamPositionalEmbedding with Sam->SamHQ
+class SamHQPositionalEmbedding(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.scale = config.hidden_size // 2
+        
self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats))) + + def forward(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +# Copied from transformers.models.sam.modeling_sam.SamTwoWayTransformer with Sam->SamHQ +class SamHQTwoWayTransformer(nn.Module): + def __init__(self, config: SamHQMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append(SamHQTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = SamHQAttention(config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_attentions = () + + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, attention_outputs = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (attention_outputs,) + + # Apply the final attenion layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys, all_attentions + +# Copied from transformers.models.sam.modeling_sam.SamFeedForward with Sam->SamHQ +class SamHQFeedForward(nn.Module): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False + ): + super().__init__() + 
self.num_layers = num_layers + self.activation = nn.ReLU() + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + +class SamHQMaskDecoder(nn.Module): + def __init__(self, config: SamHQMaskDecoderConfig): + super().__init__() + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + + self.transformer = SamHQTwoWayTransformer(config) + + self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [SamHQFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + + self.iou_prediction_head = SamHQFeedForward( + self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth + ) + + + self.hq_token = nn.Embedding(1, self.hidden_size) + self.hf_mlp = SamHQFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3) + self.num_mask_tokens = self.num_mask_tokens + 1 + + # Compress ViT features + self.compress_vit_conv1 = nn.ConvTranspose2d(config.vit_dim, self.hidden_size, kernel_size=2, stride=2) + self.compress_vit_norm = SamHQLayerNorm(self.hidden_size, data_format="channels_first") + self.compress_vit_conv2 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 8, kernel_size=2, stride=2) + self.activation = nn.GELU() + + # Embedding encoder + self.encoder_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.encoder_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.encoder_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + + # Embedding mask feature + self.mask_conv1 = nn.Conv2d(self.hidden_size // 8, self.hidden_size // 4, kernel_size=3, stride=1, padding=1) + self.mask_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.mask_conv2 = nn.Conv2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=3, stride=1, padding=1) + + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + hq_token_only: bool, + interm_embeddings: torch.Tensor, + output_attentions: Optional[bool] = None, + attention_similarity: torch.Tensor = None, + target_embedding: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict high-quality masks given 
image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + The embeddings from the image encoder. + image_positional_embedding (`torch.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`torch.Tensor`): + The embeddings of the mask inputs. + multimask_output (bool): + Whether to return multiple masks or a single mask. + hq_token_only (bool): + Whether to use only the high-quality token output or combine with SAM output. + interm_embeddings (`torch.Tensor`): + Intermediate embeddings from the vision encoder for feature fusion. + output_attentions (bool, *optional*): + Whether or not to return the attentions tensors of all attention layers. + attention_similarity (`torch.Tensor`, *optional*): + Optional tensor for attention similarity computation. + target_embedding (`torch.Tensor`, *optional*): + Optional target embedding for transformer processing. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple of tensors containing: + - A tensor of shape `(batch_size, num_prompts, num_masks, height, width)` containing the output masks. + - A tensor of shape `(batch_size, num_prompts, num_masks)` containing the iou predictions for each mask. + - (Optional) A tuple containing attention tensors if output_attentions is True. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + vit_features = interm_embeddings[0].permute(0, 3, 1, 2) + + embed_encode = self.encoder_conv1(image_embeddings) + embed_encode = self.activation(self.encoder_norm(embed_encode)) + embed_encode = self.encoder_conv2(embed_encode) + + com_vit_feat = self.compress_vit_conv1(vit_features) + com_vit_feat = self.activation(self.compress_vit_norm(com_vit_feat)) + com_vit_feat = self.compress_vit_conv2(com_vit_feat) + + hq_features = embed_encode + com_vit_feat + + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.sum().item() != 0: + tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + image_embeddings = image_embeddings.transpose(2, 3).reshape( + batch_size * point_batch_size, num_channels, height, width + ) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + upscaled_embedding_hq = self.mask_conv1(upscaled_embedding) + upscaled_embedding_hq = 
self.activation(self.mask_norm(upscaled_embedding_hq)) + upscaled_embedding_hq = self.mask_conv2(upscaled_embedding_hq) + + upscaled_embedding_hq = upscaled_embedding_hq + hq_features.repeat(batch_size, 1, 1, 1) + + + hyper_in_list = [] + for i in range(self.num_mask_tokens): + if i < self.num_mask_tokens - 1: + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + else: + current_mlp = self.hf_mlp + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + + + + hyper_in = torch.stack(hyper_in_list, dim=2) + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width) + upscaled_embedding_hq = upscaled_embedding_hq.reshape(batch_size, point_batch_size, num_channels, height * width) + + masks_sam = (hyper_in[:, :, :self.num_mask_tokens-1] @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width) + masks_hq = (hyper_in[:, :, self.num_mask_tokens-1:] @ upscaled_embedding_hq).reshape(batch_size, point_batch_size, -1, height, width) + masks = torch.cat([masks_sam, masks_hq], dim=2) + + + iou_pred = self.iou_prediction_head(iou_token_out) + + + + if multimask_output: + mask_slice = slice(1,self.num_mask_tokens-1) + iou_pred = iou_pred[:, :, mask_slice] + + + iou_pred, max_iou_idx = torch.max(iou_pred, dim=2) + iou_pred = iou_pred.unsqueeze(2) + + + masks_multi = masks[:, :, mask_slice, :, :] + batch_indices = torch.arange(masks_multi.size(0)).unsqueeze(1).expand(-1, masks_multi.size(1)) + point_indices = torch.arange(masks_multi.size(1)).unsqueeze(0).expand(masks_multi.size(0), -1) + masks_sam = masks_multi[batch_indices, point_indices, max_iou_idx].unsqueeze(2) + else: + mask_slice = slice(0, 1) + iou_pred = iou_pred[:, :, mask_slice] + masks_sam = masks[:, :, mask_slice, :, :] + masks_hq = masks[:, :, slice(self.num_mask_tokens-1, self.num_mask_tokens), :, :] + + + if hq_token_only: + masks = masks_hq + else: + masks = masks_sam + masks_hq + + + outputs = (masks, iou_pred) + if output_attentions: + outputs = outputs + (attentions,) + else: + outputs = outputs + (None,) + + return outputs + + +class SamHQPreTrainedModel(PreTrainedModel): + config_class = SamHQConfig + base_model_prefix = "sam_hq" + main_input_name = "pixel_values" + _no_split_modules = ["SamHQVisionAttention"] + + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d,nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module,nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +SAM_HQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SamHQConfig`]): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SAM_HQ_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamHQProcessor`]. See [`SamHQProcessor.__call__`] for + details. + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels: + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and botton right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + SAM-HQ model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decoder to generate masks and iou scores. 
For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + hq_token_only (`bool`, *optional*, defaults to `False`): + Whether to use only the HQ token path for mask generation. When False, combines both standard and HQ paths. + This is specific to SAM-HQ's architecture. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interm_embeddings (`List[torch.FloatTensor]`, *optional*): + Intermediate embeddings from vision encoder's non-windowed blocks, used by SAM-HQ for enhanced mask quality. + Required when providing pre-computed image_embeddings instead of pixel_values. 
+""" + +@add_start_docstrings( + "Segment Anything Model HQ (SAM-HQ) for generating masks,given an input image and", + " optional 2D location and bounding boxes.", + SAM_HQ_START_DOCSTRING, +) +class SamHQModel(SamHQPreTrainedModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + + + def __init__(self, config): + super().__init__(config) + self.shared_image_embedding = SamHQPositionalEmbedding(config.vision_config) + + self.vision_encoder = SamHQVisionEncoder(config.vision_config) + self.prompt_encoder = SamHQPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding) + self.mask_decoder = SamHQMaskDecoder(config.mask_decoder_config) + + self.post_init() + + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embeddings_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones((size, size), device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) + + + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + """ + vision_output = self.vision_encoder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + + if return_dict: + image_embeddings = vision_output.last_hidden_state + intermediate_embeddings = vision_output.intermediate_embeddings + else: + image_embeddings = vision_output[0] + intermediate_embeddings = vision_output[1] + + return image_embeddings, intermediate_embeddings + + + @torch.no_grad() + def get_prompt_embeddings( + self, + input_points: Optional[torch.FloatTensor]= None, + input_labels: Optional[torch.LongTensor]= None, + input_boxes: Optional[torch.FloatTensor]= None, + input_masks: Optional[torch.LongTensor]= None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. 
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + + @add_start_docstrings_to_model_forward(SAM_HQ_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + hq_token_only: bool = False, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interm_embeddings: Optional[List[torch.FloatTensor]] = None, + **kwargs, + ) -> List[Dict[str, torch.Tensor]]: + r""" + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoModel, AutoProcessor + + >>> model = AutoModel.from_pretrained("sushmanth/sam_hq_vit_b") + >>> processor = AutoProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + >>> input_points = [[[400, 650]]] # 2D location of a window on the car + >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt") + + >>> # Get high-quality segmentation mask + >>> outputs = model(**inputs) + + >>> # For high-quality mask only + >>> outputs = model(**inputs, hq_token_only=True) + + >>> # Postprocess masks + >>> masks = processor.post_process_masks( + ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] + ... ) + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`." + f" got {input_points.shape}." 
+ ) + + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_boxes must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`." + f" got {input_boxes.shape}." + ) + + # Add validation for point and box batch sizes + if input_points is not None and input_boxes is not None: + point_batch_size = input_points.shape[1] + box_batch_size = input_boxes.shape[1] + if point_batch_size != box_batch_size: + raise ValueError( + "You should provide as many bounding boxes as input points per box. Got {} and {}.".format( + point_batch_size, box_batch_size + ) + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if return_dict: + image_embeddings = vision_outputs.last_hidden_state + interm_embeddings = vision_outputs.intermediate_embeddings + if output_hidden_states: + vision_hidden_states = vision_outputs.hidden_states + if output_attentions: + vision_attentions = vision_outputs.attentions + else: + image_embeddings = vision_outputs[0] + interm_embeddings = vision_outputs[1] + if output_hidden_states: + vision_hidden_states = vision_outputs[2] + if output_attentions: + vision_attentions = vision_outputs[-1] + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + # Predict masks + low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=hq_token_only, + interm_embeddings=interm_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + + if not return_dict: + output = (iou_predictions, low_res_masks) + if output_hidden_states: + output = output + (vision_hidden_states,) + + if output_attentions: + output = output + (vision_attentions, mask_decoder_attentions) + return output + + return SamHQImageSegmentationOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + mask_decoder_attentions=mask_decoder_attentions, + ) \ No newline at end of file diff --git a/src/transformers/models/sam_hq/processing_samhq.py b/src/transformers/models/sam_hq/processing_samhq.py new file mode 100644 index 00000000000000..852fbe45bc7af1 --- /dev/null +++ b/src/transformers/models/sam_hq/processing_samhq.py @@ -0,0 +1,267 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for SAM-HQ
+"""
+
+
+from copy import deepcopy
+from typing import Optional, Union
+
+import numpy as np
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import TensorType, is_torch_available
+
+
+
+if is_torch_available():
+    import torch
+
+# Copied from transformers.models.sam.processing_sam.SamProcessor with Sam->SamHQ
+class SamHQProcessor(ProcessorMixin):
+    r"""
+    Constructs a SAM-HQ processor which wraps a SAM image processor and a 2D points & bounding boxes processor into a
+    single processor.
+
+    [`SamHQProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of
+    [`~SamImageProcessor.__call__`] for more information.
+
+    Args:
+        image_processor (`SamImageProcessor`):
+            An instance of [`SamImageProcessor`]. The image processor is a required input.
+    """
+    attributes = ["image_processor"]
+    image_processor_class = "SamImageProcessor"
+
+    def __init__(self, image_processor):
+        super().__init__(image_processor)
+        self.current_processor = self.image_processor
+        self.point_pad_value = -10
+        self.target_size = self.image_processor.size["longest_edge"]
+
+
+
+    def __call__(
+        self,
+        images=None,
+        segmentation_maps=None,
+        input_points=None,
+        input_labels=None,
+        input_boxes=None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        This method uses the [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares
+        2D points and bounding boxes for the model if they are provided.
+ """ + encoding_image_processor = self.image_processor( + images, + segmentation_maps=segmentation_maps, + return_tensors=return_tensors, + **kwargs, + ) + + + # pop arguments that are not used in the foward but used nevertheless + original_sizes = encoding_image_processor["original_sizes"] + + if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor + original_sizes = original_sizes.numpy() + + input_points, input_labels, input_boxes = self._check_and_preprocess_points( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + ) + + encoding_image_processor = self._normalize_and_convert( + encoding_image_processor, + original_sizes, + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + return_tensors=return_tensors, + ) + + return encoding_image_processor + + + + def _normalize_and_convert( + self, + encoding_image_processor, + original_sizes, + input_points=None, + input_labels=None, + input_boxes=None, + return_tensors="pt", + ): + if input_points is not None: + if len(original_sizes) != len(input_points): + input_points = [ + self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points + ] + else: + input_points = [ + self._normalize_coordinates(self.target_size, point, original_size) + for point, original_size in zip(input_points, original_sizes) + ] + # check that all arrays have the same shape + if not all(point.shape == input_points[0].shape for point in input_points): + if input_labels is not None: + input_points, input_labels = self._pad_points_and_labels(input_points, input_labels) + + input_points = np.array(input_points) + + if input_labels is not None: + input_labels = np.array(input_labels) + + if input_boxes is not None: + if len(original_sizes) != len(input_boxes): + input_boxes = [ + self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True) + for box in input_boxes + ] + else: + input_boxes = [ + self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True) + for box, original_size in zip(input_boxes, original_sizes) + ] + input_boxes = np.array(input_boxes) + + + + if input_boxes is not None: + if return_tensors == "pt": + input_boxes = torch.from_numpy(input_boxes) + input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes + encoding_image_processor.update({"input_boxes": input_boxes}) + if input_points is not None: + if return_tensors == "pt": + input_points = torch.from_numpy(input_points) + input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points + encoding_image_processor.update({"input_points": input_points}) + if input_labels is not None: + if return_tensors == "pt": + input_labels = torch.from_numpy(input_labels) + input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels + encoding_image_processor.update({"input_labels": input_labels}) + + return encoding_image_processor + + + def _pad_points_and_labels(self, input_points, input_labels): + r""" + The method pads the 2D points and labels to the maximum number of points in the batch. 
+ """ + expected_nb_points = max([point.shape[0] for point in input_points]) + processed_input_points = [] + for i, point in enumerate(input_points): + if point.shape[0] != expected_nb_points: + point = np.concatenate( + [point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0 + ) + input_labels[i] = np.append(input_labels[i], [self.point_pad_value]) + processed_input_points.append(point) + input_points = processed_input_points + return input_points, input_labels + + def _normalize_coordinates( + self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False + ) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size) + coords = deepcopy(coords).astype(float) + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + def _check_and_preprocess_points( + self, + input_points=None, + input_labels=None, + input_boxes=None, + ): + r""" + Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they + are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`, + it is converted to a `numpy.ndarray` and then to a `list`. + """ + if input_points is not None: + if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor + input_points = input_points.numpy().tolist() + + if not isinstance(input_points, list) or not isinstance(input_points[0], list): + raise ValueError("Input points must be a list of list of floating points.") + input_points = [np.array(input_point) for input_point in input_points] + else: + input_points = None + + if input_labels is not None: + if hasattr(input_labels, "numpy"): + input_labels = input_labels.numpy().tolist() + + if not isinstance(input_labels, list) or not isinstance(input_labels[0], list): + raise ValueError("Input labels must be a list of list integers.") + input_labels = [np.array(label) for label in input_labels] + else: + input_labels = None + + if input_boxes is not None: + if hasattr(input_boxes, "numpy"): + input_boxes = input_boxes.numpy().tolist() + + if ( + not isinstance(input_boxes, list) + or not isinstance(input_boxes[0], list) + or not isinstance(input_boxes[0][0], list) + ): + raise ValueError("Input boxes must be a list of list of list of floating points.") + input_boxes = [np.array(box).astype(np.float32) for box in input_boxes] + else: + input_boxes = None + + return input_points, input_labels, input_boxes + + @property + def model_input_names(self): + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(image_processor_input_names)) + + def post_process_masks(self, *args, **kwargs): + return self.image_processor.post_process_masks(*args, **kwargs) + + + + + + + + + \ No newline at end of file diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1238f058783c18..53b3920462d9cb 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8166,6 +8166,19 @@ class SamPreTrainedModel(metaclass=DummyObject): def __init__(self, *args, **kwargs): 
requires_backends(self, ["torch"]) +class SamHQModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + +class SamHQPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SeamlessM4TCodeHifiGan(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/fixtures/test_samples/COCO/000000039769.png b/tests/models/fixtures/test_samples/COCO/000000039769.png new file mode 100644 index 00000000000000..a3b5225fc3cef5 Binary files /dev/null and b/tests/models/fixtures/test_samples/COCO/000000039769.png differ diff --git a/tests/models/sam_hq/__init__.py b/tests/models/sam_hq/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/sam_hq/test_modeling_sam_hq.py b/tests/models/sam_hq/test_modeling_sam_hq.py new file mode 100644 index 00000000000000..557ce3a7a04cd7 --- /dev/null +++ b/tests/models/sam_hq/test_modeling_sam_hq.py @@ -0,0 +1,705 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch SAM-HQ model.""" + +import unittest + +import requests + +from transformers import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig, pipeline +from transformers.testing_utils import require_torch, slow, torch_device +from transformers.utils import is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + +if is_torch_available(): + import torch + from torch import nn + from transformers import SamHQModel, SamHQProcessor + +if is_vision_available(): + from PIL import Image + +class SamHQPromptEncoderTester: + def __init__( + self, + hidden_size=32, + input_image_size=24, + patch_size=2, + mask_input_channels=4, + num_point_embeddings=4, + hidden_act="gelu", + ): + self.hidden_size = hidden_size + self.input_image_size = input_image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + + def get_config(self): + return SamHQPromptEncoderConfig( + image_size=self.input_image_size, + patch_size=self.patch_size, + mask_input_channels=self.mask_input_channels, + hidden_size=self.hidden_size, + num_point_embeddings=self.num_point_embeddings, + hidden_act=self.hidden_act, + ) + + def prepare_config_and_inputs(self): + dummy_points = floats_tensor([self.batch_size, 3, 2]) + config = self.get_config() + return config, dummy_points + +class SamHQMaskDecoderTester: + def __init__( + self, + hidden_size=32, + hidden_act="relu", + mlp_dim=64, + num_hidden_layers=2, + num_attention_heads=4, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + 
iou_head_hidden_dim=32, + layer_norm_eps=1e-6, + vit_dim=384, # Added for SAM HQ + ): + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downsample_rate = attention_downsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + self.vit_dim = vit_dim + + def get_config(self): + return SamHQMaskDecoderConfig( + hidden_size=self.hidden_size, + hidden_act=self.hidden_act, + mlp_dim=self.mlp_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + attention_downsample_rate=self.attention_downsample_rate, + num_multimask_outputs=self.num_multimask_outputs, + iou_head_depth=self.iou_head_depth, + iou_head_hidden_dim=self.iou_head_hidden_dim, + layer_norm_eps=self.layer_norm_eps, + vit_dim=self.vit_dim, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + dummy_inputs = { + "image_embedding": floats_tensor([self.batch_size, self.hidden_size]), + } + return config, dummy_inputs + +class SamHQModelTester: + def __init__( + self, + parent, + hidden_size=36, + intermediate_size=72, + projection_dim=62, + output_channels=32, + num_hidden_layers=2, + num_attention_heads=4, + num_channels=3, + image_size=24, + patch_size=2, + hidden_act="gelu", + layer_norm_eps=1e-06, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + rel_pos_zero_init=False, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=16, + mlp_dim=None, + batch_size=2, + vit_dim=384, # Added for SAM HQ + ): + self.parent = parent + self.image_size = image_size + self.patch_size = patch_size + self.output_channels = output_channels + self.num_channels = num_channels + self.hidden_size = hidden_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.rel_pos_zero_init = rel_pos_zero_init + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = mlp_dim + self.batch_size = batch_size + self.vit_dim = vit_dim + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + self.prompt_encoder_tester = SamHQPromptEncoderTester() + self.mask_decoder_tester = SamHQMaskDecoderTester() + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + return config, pixel_values + + def get_config(self): + vision_config = SamHQVisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + projection_dim=self.projection_dim, + 
num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + initializer_factor=self.initializer_factor, + output_channels=self.output_channels, + qkv_bias=self.qkv_bias, + mlp_ratio=self.mlp_ratio, + use_abs_pos=self.use_abs_pos, + use_rel_pos=self.use_rel_pos, + rel_pos_zero_init=self.rel_pos_zero_init, + window_size=self.window_size, + global_attn_indexes=self.global_attn_indexes, + num_pos_feats=self.num_pos_feats, + mlp_dim=self.mlp_dim, + ) + + prompt_encoder_config = self.prompt_encoder_tester.get_config() + mask_decoder_config = self.mask_decoder_tester.get_config() + + return SamHQConfig( + vision_config=vision_config, + prompt_encoder_config=prompt_encoder_config, + mask_decoder_config=mask_decoder_config, + ) + + def create_and_check_model(self, config, pixel_values): + model = SamHQModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3)) + self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3)) + + def create_and_check_get_image_features(self, config, pixel_values): + model = SamHQModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + image_embeddings, interm_embeddings = model.get_image_embeddings(pixel_values) + self.parent.assertEqual(image_embeddings.shape[-2:], (12, 12)) + self.parent.assertIsNotNone(interm_embeddings) + + def create_and_check_get_image_hidden_states(self, config, pixel_values): + model = SamHQModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=True, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(outputs[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(outputs[1][0].shape, expected_hidden_states_shape) + + with torch.no_grad(): + outputs = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=False, + ) + + # after computing the convolutional features + self.parent.assertEqual(len(outputs[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(outputs[1][0].shape, expected_hidden_states_shape) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + +@require_torch +class SamHQModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SAM-HQ's vision encoder does not use input_ids, inputs_embeds, + attention_mask and seq_length. 
+ """ + + all_model_classes = (SamHQModel,) if is_torch_available() else () + pipeline_model_mapping = ( + {"feature-extraction": SamHQModel, "mask-generation": SamHQModel} if is_torch_available() else {} + ) + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_torchscript = False + + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + def setUp(self): + self.model_tester = SamHQModelTester(self) + self.vision_config_tester = ConfigTester(self, config_class=SamHQVisionConfig, has_text_modality=False) + self.prompt_encoder_config_tester = ConfigTester( + self, + config_class=SamHQPromptEncoderConfig, + has_text_modality=False, + num_attention_heads=12, + num_hidden_layers=2, + ) + self.mask_decoder_config_tester = ConfigTester( + self, config_class=SamHQMaskDecoderConfig, has_text_modality=False + ) + + def test_config(self): + self.vision_config_tester.run_common_tests() + self.prompt_encoder_config_tester.run_common_tests() + self.mask_decoder_config_tester.run_common_tests() + + @unittest.skip(reason="SAM-HQ's vision encoder does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_get_image_features(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_features(*config_and_inputs) + + def test_image_hidden_states(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + expected_vision_attention_shape = ( + self.model_tester.batch_size * self.model_tester.num_attention_heads, + 196, + 196, + ) + expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + vision_attentions = outputs.vision_attentions + 
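# The config-driven pass should yield the same number of per-layer attention tensors as the explicit output_attentions=True call above; the shapes are asserted below.
+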
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + self.assertListEqual( + list(vision_attentions[0].shape[-4:]), + list(expected_vision_attention_shape), + ) + + self.assertListEqual( + list(mask_decoder_attentions[0].shape[-4:]), + list(expected_mask_decoder_attention_shape), + ) + + @unittest.skip(reason="SamHQModel does not support training") + def test_training(self): + pass + + @unittest.skip(reason="SamHQModel does not support training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="SamHQModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="SamHQModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="SamHQModel does not support training") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests") + def test_hidden_states_output(self): + pass + + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): + # Use a slightly higher default tol to make the tests non-flaky + super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes) + + @slow + def test_model_from_pretrained(self): + model = SamHQModel.from_pretrained("Uminosachi/sam-hq-vit-huge") + self.assertIsNotNone(model) + +def prepare_image(): + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + +def prepare_dog_img(): + img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + +@slow +class SamHQModelIntegrationTest(unittest.TestCase): + def test_inference_mask_generation_one_point_one_bb_zero(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_boxes = [[[620, 900, 1000, 1255]]] + input_points = [[[820, 1080]]] + labels = [[0]] + + inputs = processor( + images=raw_image, + input_boxes=input_boxes, + input_points=input_points, + input_labels=labels, + return_tensors="pt", + ).to(torch_device) + + # Test standard path + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4)) + self.assertIsNotNone(outputs_hq.pred_masks) + + def 
test_inference_mask_generation_one_point(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650]]] + input_labels = [[1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4)) + self.assertIsNotNone(outputs_hq.pred_masks) + + # With no label + input_points = [[[400, 650]]] + + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4)) + self.assertIsNotNone(outputs_hq.pred_masks) + + def test_inference_mask_generation_two_points(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + # Test both paths + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4)) + self.assertIsNotNone(outputs_hq.pred_masks) + + # no labels + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4)) + self.assertIsNotNone(outputs_hq.pred_masks) + + def test_inference_mask_generation_two_points_batched(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]], [[400, 650]]] + input_labels = [[1, 1], [1]] + + inputs = processor( + images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + # Test both paths + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4)) + self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4)) + self.assertIsNotNone(outputs_hq.pred_masks) + + def test_inference_mask_generation_one_box(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_boxes = [[[75, 275, 1725, 850]]] + + inputs = processor(images=raw_image, 
input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + # Test both paths + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4)) + self.assertIsNotNone(outputs_hq.pred_masks) + + def test_inference_mask_generation_batched_image_one_point(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + raw_dog_image = prepare_dog_img() + + input_points = [[[820, 1080]], [[220, 470]]] + + inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to( + torch_device + ) + + # Test both paths + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + scores_batched = outputs.iou_scores.squeeze() + self.assertIsNotNone(outputs_hq.pred_masks) + + input_points = [[[220, 470]]] + inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + scores_single = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4)) + self.assertIsNotNone(outputs_hq.pred_masks) + + def test_inference_mask_generation_two_points_point_batch(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu() # fmt: skip + + input_points = input_points.unsqueeze(0) + + inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + # Test both paths + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 2, 3)) + torch.testing.assert_close( + iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4 + ) + self.assertIsNotNone(outputs_hq.pred_masks) + + def test_inference_mask_generation_three_boxes_point_batch(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + # fmt: off + input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu() + EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522], + [0.5996, 0.7661, 0.7937], + [0.5996, 0.7661, 0.7937]]]) + # fmt: on + input_boxes = input_boxes.unsqueeze(0) + + inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + # Test both paths + with torch.no_grad(): + outputs = model(**inputs) + outputs_hq = model(**inputs, hq_token_only=True) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 3, 3)) + torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) + self.assertIsNotNone(outputs_hq.pred_masks) + + def test_dummy_pipeline_generation(self): + generator = pipeline("mask-generation", model="sushmanth/sam_hq_vit_b", device=torch_device) + 
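# Smoke test: only verifies that both the standard and hq_token_only pipeline calls run end to end (expected scores are asserted in the integration tests above).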
+        raw_image = prepare_image()
+
+        # Test standard pipeline
+        _ = generator(raw_image, points_per_batch=64)
+
+        # Test HQ-only pipeline
+        _ = generator(raw_image, points_per_batch=64, hq_token_only=True)
diff --git a/utils/check_copies.py b/utils/check_copies.py
index ed6a4f68b577f0..762a8d440562a4 100644
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -1072,6 +1072,7 @@ def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> Tup
     "OpenAI GPT": "GPT",
     "Perceiver": "Perceiver IO",
     "SAM": "Segment Anything",
+    "SAM_HQ": "Segment Anything High Quality",
     "ViT": "Vision Transformer (ViT)",
 }
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 10be5cdcd26230..bd0b6942195f71 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -212,6 +212,7 @@
     "JukeboxVQVAE",
     "JukeboxPrior",
     "SamModel",
+    "SamHQModel",
     "DPTForDepthEstimation",
     "DecisionTransformerGPT2Model",
     "GLPNForDepthEstimation",