Samhq model addition #35147

Draft · wants to merge 27 commits into base: main

Commits (27)
a284572
added the configuration for sam_hq
sushmanthreddy Dec 8, 2024
2a3caa2
added the modelling for sam_hq
sushmanthreddy Dec 8, 2024
92f291a
added the sam hq mask decoder with hq features
sushmanthreddy Dec 8, 2024
395a5a5
added the code for the samhq
sushmanthreddy Dec 8, 2024
091da86
added the code for the samhq
sushmanthreddy Dec 8, 2024
2fa1ac4
added the code for the samhq
sushmanthreddy Dec 8, 2024
7453aad
Delete src/transformers/models/sam_hq/modelling_sam_hq.py
sushmanthreddy Dec 8, 2024
419575a
added the code for the samhq
sushmanthreddy Dec 8, 2024
138c1f3
added the code for the samhq
sushmanthreddy Dec 8, 2024
0579d8a
added the changes for the modelling
sushmanthreddy Dec 12, 2024
c7c0e81
added the code for sam hq for image processing
sushmanthreddy Dec 12, 2024
2c8e5d9
added code for the sam hq model
sushmanthreddy Dec 14, 2024
2d836c4
added the required changes
sushmanthreddy Dec 17, 2024
dc8b7e3
added the changes
sushmanthreddy Dec 17, 2024
6c454bc
added the key mappings for the sam hq
sushmanthreddy Dec 18, 2024
d109255
adding the working code of samhq
sushmanthreddy Dec 19, 2024
682cf0b
added the required files
sushmanthreddy Dec 19, 2024
f2564ef
adding the pt object
sushmanthreddy Dec 19, 2024
896eb7c
added the push to hub account
sushmanthreddy Dec 19, 2024
0869e0a
added the args for the sam mask decoder
sushmanthreddy Dec 20, 2024
f8b8c30
added the args for the sam hq vision config
sushmanthreddy Dec 20, 2024
7edc5a5
added some more documentation
sushmanthreddy Dec 20, 2024
4304188
removed the unnecessary spaces
sushmanthreddy Dec 20, 2024
395956b
all required changes
sushmanthreddy Dec 20, 2024
f3d63ba
removed the image processor
sushmanthreddy Dec 20, 2024
89aaeae
added the required file
sushmanthreddy Dec 20, 2024
c27c3b7
added the changes for the checkcopies
sushmanthreddy Dec 21, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -902,6 +902,8 @@
title: Qwen2VL
- local: model_doc/sam
title: Segment Anything
- local: model_doc/sam_hq
title: Segment Anything High Quality
- local: model_doc/siglip
title: SigLIP
- local: model_doc/speech-encoder-decoder
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -290,6 +290,7 @@ Flax), PyTorch, and/or TensorFlow.
| [RT-DETR-ResNet](model_doc/rt_detr_resnet) | ✅ | ❌ | ❌ |
| [RWKV](model_doc/rwkv) | ✅ | ❌ | ❌ |
| [SAM](model_doc/sam) | ✅ | ✅ | ❌ |
| [SAM-HQ](model_doc/sam_hq) | ✅ | ❌ | ❌ |
| [SeamlessM4T](model_doc/seamless_m4t) | ✅ | ❌ | ❌ |
| [SeamlessM4Tv2](model_doc/seamless_m4t_v2) | ✅ | ❌ | ❌ |
| [SegFormer](model_doc/segformer) | ✅ | ✅ | ❌ |
122 changes: 122 additions & 0 deletions docs/source/en/model_doc/sam_hq.md
@@ -0,0 +1,122 @@
# SAM-HQ

## Overview

SAM-HQ (High-Quality Segment Anything Model) was proposed in [Segment Anything in High Quality](https://arxiv.org/pdf/2306.01567.pdf) by Lei Ke, Mingqiao Ye, Martin Danelljan, Yifan Liu, Yu-Wing Tai, Chi-Keung Tang, Fisher Yu.

The model is an enhancement to the original SAM model that produces significantly higher quality segmentation masks while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability.

![example image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-output.png)

The abstract from the paper is the following:

*The recent Segment Anything Model (SAM) represents a big leap in scaling up segmentation models, allowing for powerful zero-shot capabilities and flexible prompting. Despite being trained with 1.1 billion masks, SAM's mask prediction quality falls short in many cases, particularly when dealing with objects that have intricate structures. We propose HQ-SAM, equipping SAM with the ability to accurately segment any object, while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability. Our careful design reuses and preserves the pre-trained model weights of SAM, while only introducing minimal additional parameters and computation. We design a learnable High-Quality Output Token, which is injected into SAM's mask decoder and is responsible for predicting the high-quality mask. Instead of only applying it on mask-decoder features, we first fuse them with early and final ViT features for improved mask details. To train our introduced learnable parameters, we compose a dataset of 44K fine-grained masks from several sources. HQ-SAM is only trained on the introduced dataset of 44k masks, which takes only 4 hours on 8 GPUs.*

Tips:

- SAM-HQ produces higher quality masks than the original SAM model, particularly for objects with intricate structures and fine details
- The model predicts binary masks with more accurate boundaries and better handling of thin structures
- Like SAM, the model performs best when prompted with 2D input points and/or input bounding boxes
- You can prompt multiple points for the same image and predict a single high-quality mask (see the multi-point example below)
- The model maintains SAM's zero-shot generalization capabilities
- SAM-HQ only adds ~0.5% additional parameters compared to SAM
- Fine-tuning the model is not supported yet

This model was contributed by [sushmanth](https://huggingface.co/sushmanth).
The original code can be found [here](https://github.com/SysCV/SAM-HQ).

Below is an example of how to run mask generation given an image and a 2D point:

```python
import torch
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]] # 2D location of a window in the image

inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```
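
Continuing from the snippet above, you can keep only the mask with the highest predicted IoU score. This is a minimal sketch that assumes SAM-HQ follows SAM's output conventions, where `masks[0]` has shape `(point_batch_size, num_masks, height, width)` and `iou_scores` has shape `(batch_size, point_batch_size, num_masks)`:

```python
# Minimal sketch, assuming SAM-HQ follows SAM's output conventions.
best_idx = scores.squeeze().argmax().item()  # index of the highest-IoU mask
best_mask = masks[0][0, best_idx]            # (height, width) boolean mask for the first image
print(best_mask.shape)
```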

You can also pass your own segmentation maps alongside the input images to the processor, which will process them to be fed to the model:

```python
import torch
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("1")
input_points = [[[450, 600]]] # 2D location of a window in the image

inputs = processor(raw_image, input_points=input_points, segmentation_maps=segmentation_map, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```
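
As noted in the tips, several points can be prompted for the same image to refine a single mask. A minimal sketch, reusing `model`, `processor`, `raw_image`, and `device` from the example above (the point coordinates are illustrative):

```python
# Points grouped in the same innermost list are treated as joint prompts,
# so the model predicts masks conditioned on all of them at once.
input_points = [[[450, 600], [500, 620]]]  # two illustrative clicks on the car window
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
```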

## Key Differences from SAM

SAM-HQ introduces several key improvements over the original SAM model (the first two are sketched in code after this list):

1. High-Quality Output Token: A learnable token injected into SAM's mask decoder for higher quality mask prediction
2. Global-local Feature Fusion: Combines features from different stages of the model for improved mask details
3. Training Data: Uses a carefully curated dataset of 44K high-quality masks instead of SA-1B
4. Efficiency: Adds only 0.5% additional parameters while significantly improving mask quality
5. Zero-shot Capability: Maintains SAM's strong zero-shot performance while improving accuracy
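
The two architectural additions are small enough to sketch. The following is a conceptual illustration only, not the actual `transformers` implementation; the shapes, layer choices, and names are assumptions for exposition:

```python
import torch
import torch.nn as nn

hidden_dim = 256  # SAM's mask decoder dimension

# 1. High-Quality Output Token: one extra learnable token, injected into the
#    mask decoder alongside SAM's existing output tokens, that is decoded
#    into the high-quality mask.
hq_token = nn.Embedding(1, hidden_dim)
output_tokens = hq_token.weight  # (1, hidden_dim), concatenated with SAM's output tokens

# 2. Global-local feature fusion: early ViT features (fine local detail) are
#    fused with final ViT features (global semantics) to sharpen mask edges.
early_vit_features = torch.randn(1, hidden_dim, 64, 64)  # from an early ViT block
final_vit_features = torch.randn(1, hidden_dim, 64, 64)  # from the last ViT block
fuse = nn.Conv2d(2 * hidden_dim, hidden_dim, kernel_size=1)
hq_features = fuse(torch.cat([early_vit_features, final_vit_features], dim=1))
print(hq_features.shape)  # torch.Size([1, 256, 64, 64])
```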

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SAM-HQ:

- Demo notebook for using the model (coming soon)
- Paper implementation and code: [SAM-HQ GitHub Repository](https://github.com/SysCV/SAM-HQ)
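
Until the demo notebook is available, automatic mask generation over a whole image can be sketched with the `mask-generation` pipeline. This assumes SAM-HQ's pipeline integration mirrors SAM's, via the mask-generation auto-mapping added in this PR:

```python
from transformers import pipeline

# Assumption: the mask-generation pipeline resolves SamHQModel the same way
# it resolves SamModel for the original SAM checkpoints.
generator = pipeline("mask-generation", model="sushmanth/sam_hq_vit_b")
outputs = generator(
    "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png",
    points_per_batch=64,
)
print(len(outputs["masks"]))  # number of masks found in the image
```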

## SamHQConfig

[[autodoc]] SamHQConfig

## SamHQVisionConfig

[[autodoc]] SamHQVisionConfig

## SamHQMaskDecoderConfig

[[autodoc]] SamHQMaskDecoderConfig

## SamHQPromptEncoderConfig

[[autodoc]] SamHQPromptEncoderConfig

## SamHQProcessor

[[autodoc]] SamHQProcessor

## SamHQModel

[[autodoc]] SamHQModel
- forward
26 changes: 26 additions & 0 deletions src/transformers/__init__.py
@@ -723,6 +723,13 @@
"SamPromptEncoderConfig",
"SamVisionConfig",
],
"models.sam_hq": [
"SamHQConfig",
"SamHQMaskDecoderConfig",
"SamHQProcessor",
"SamHQVisionConfig",
"SamHQPromptEncoderConfig",
],
"models.seamless_m4t": [
"SeamlessM4TConfig",
"SeamlessM4TFeatureExtractor",
@@ -1233,6 +1240,7 @@
_import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
_import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"])
_import_structure["models.sam"].extend(["SamImageProcessor"])
_import_structure["models.sam"].extend(["SamImageProcessor"])
_import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
_import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
_import_structure["models.siglip"].append("SiglipImageProcessor")
@@ -3295,6 +3303,12 @@
"SamPreTrainedModel",
]
)
_import_structure["models.sam_hq"].extend(
[
"SamHQModel",
"SamHQPreTrainedModel",
]
)
_import_structure["models.seamless_m4t"].extend(
[
"SeamlessM4TCodeHifiGan",
@@ -5638,6 +5652,13 @@
SamPromptEncoderConfig,
SamVisionConfig,
)
from .models.sam_hq import (
SamHQConfig,
SamHQMaskDecoderConfig,
SamHQProcessor,
SamHQPromptEncoderConfig,
SamHQVisionConfig,
)
from .models.seamless_m4t import (
SeamlessM4TConfig,
SeamlessM4TFeatureExtractor,
@@ -6167,6 +6188,7 @@
from .models.qwen2_vl import Qwen2VLImageProcessor
from .models.rt_detr import RTDetrImageProcessor
from .models.sam import SamImageProcessor
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
from .models.seggpt import SegGptImageProcessor
from .models.siglip import SiglipImageProcessor
@@ -7827,6 +7849,10 @@
SamModel,
SamPreTrainedModel,
)
from .models.sam_hq import (
SamHQModel,
SamHQPreTrainedModel,
)
from .models.seamless_m4t import (
SeamlessM4TCodeHifiGan,
SeamlessM4TForSpeechToSpeech,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -221,6 +221,7 @@
rt_detr,
rwkv,
sam,
sam_hq,
seamless_m4t,
seamless_m4t_v2,
segformer,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -244,6 +244,7 @@
("rt_detr_resnet", "RTDetrResNetConfig"),
("rwkv", "RwkvConfig"),
("sam", "SamConfig"),
("sam_hq", "SamHQConfig"),
("seamless_m4t", "SeamlessM4TConfig"),
("seamless_m4t_v2", "SeamlessM4Tv2Config"),
("segformer", "SegformerConfig"),
@@ -561,6 +562,7 @@
("rt_detr_resnet", "RT-DETR-ResNet"),
("rwkv", "RWKV"),
("sam", "SAM"),
("sam_hq", "SAM-HQ"),
("seamless_m4t", "SeamlessM4T"),
("seamless_m4t_v2", "SeamlessM4Tv2"),
("segformer", "SegFormer"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -126,6 +126,7 @@
("resnet", ("ConvNextImageProcessor",)),
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
("sam", ("SamImageProcessor",)),
("sam_hq", ("SamImageProcessor",)),
("segformer", ("SegformerImageProcessor",)),
("seggpt", ("SegGptImageProcessor",)),
("siglip", ("SiglipImageProcessor",)),
7 changes: 7 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -226,6 +226,7 @@
("rt_detr", "RTDetrModel"),
("rwkv", "RwkvModel"),
("sam", "SamModel"),
("sam_hq", "SamHQModel"),
("seamless_m4t", "SeamlessM4TModel"),
("seamless_m4t_v2", "SeamlessM4Tv2Model"),
("segformer", "SegformerModel"),
@@ -1367,6 +1368,12 @@
]
)

MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam", "SamModel"),
("sam_hq", "SamHQModel"),
]
)


MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
[
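With `sam_hq` registered in the mappings above, the Auto classes can resolve SAM-HQ checkpoints without importing model-specific classes. A minimal sketch, assuming the checkpoint's `config.json` declares `model_type: "sam_hq"`:

```python
from transformers import AutoModel, AutoProcessor

# Resolved through the CONFIG_MAPPING / MODEL_MAPPING entries added in this PR.
model = AutoModel.from_pretrained("sushmanth/sam_hq_vit_b")
processor = AutoProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
```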
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -89,6 +89,7 @@
("qwen2_audio", "Qwen2AudioProcessor"),
("qwen2_vl", "Qwen2VLProcessor"),
("sam", "SamProcessor"),
("sam-hq", "SamHQProcessor"),
("seamless_m4t", "SeamlessM4TProcessor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
74 changes: 74 additions & 0 deletions src/transformers/models/sam_hq/__init__.py
@@ -0,0 +1,74 @@
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_vision_available,
)


_import_structure = {
    "configuration_sam_hq": [
        "SamHQConfig",
        "SamHQMaskDecoderConfig",
        "SamHQPromptEncoderConfig",
        "SamHQVisionConfig",
    ],
    "processing_samhq": ["SamHQProcessor"],
}


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_sam_hq"] = [
        "SamHQModel",
        "SamHQPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_sam_hq import (
        SamHQConfig,
        SamHQMaskDecoderConfig,
        SamHQPromptEncoderConfig,
        SamHQVisionConfig,
    )
    from .processing_samhq import SamHQProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_sam_hq import SamHQModel, SamHQPreTrainedModel

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # SAM-HQ reuses SAM's image processor rather than defining its own.
        from transformers.models.sam.image_processing_sam import SamImageProcessor

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)