Support for SDPA for SAM models (#34110)
* feat: add support for sdpa and gradient checkpointing

* fix: ruff format

* fix: config sdpa

* fix: sdpa layer naming convention

* fix: update test_eager_matches_sdpa_inference to handle vision_hidden_states

* test: skip incompatible tests and fix loading issue with sdpa

- Updated tests to skip flash attention and dynamic compile cases.
- Minor adjustment to ensure correct loading of model with sdpa for dispatch test.

* style: apply Ruff formatting

* ruff fix again after rebase

* [run-slow] sam

* [run-slow] sam

* refactor: Address review comments and improve sub-config handling in SAM model tests

- Added attributes for sub_configs as per PR #34410.
- Enabled tests for configs, ensuring the composite model (SAM) has several sub-configs in the main config.
- Added class attribute _is_composite=True to the tester class.
- Added test_sdpa_can_dispatch_composite_models.

* [run-slow] sam

* style: ruff

* [run-slow] sam

* style: ruff again ...

* [run-slow] sam
MagnusS0 authored Dec 17, 2024
1 parent 747f361 commit 6eb00dd
Showing 4 changed files with 256 additions and 31 deletions.
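
In practical terms, the change lets SAM checkpoints be loaded with the SDPA attention backend (and, per the commit message, with gradient checkpointing). A minimal sketch, assuming the public facebook/sam-vit-base checkpoint and the standard attn_implementation argument to from_pretrained:

from transformers import SamModel

# Load SAM with the SDPA attention path added by this commit; "eager" keeps the original attention.
model = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa")

# Gradient checkpointing is also enabled for SAM by this commit.
model.gradient_checkpointing_enable()

# The requested implementation is propagated to the sub-configs (see the tests below).
print(model.config._attn_implementation)                 # "sdpa"
print(model.vision_encoder.config._attn_implementation)  # "sdpa"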
11 changes: 11 additions & 0 deletions src/transformers/models/sam/configuration_sam.py
@@ -46,6 +46,8 @@ class SamPromptEncoderConfig(PretrainedConfig):
The non-linear activation function in the encoder and pooler.
"""

base_config_key = "prompt_encoder_config"

def __init__(
self,
hidden_size=256,
@@ -102,6 +104,8 @@ class SamMaskDecoderConfig(PretrainedConfig):
"""

base_config_key = "mask_decoder_config"

def __init__(
self,
hidden_size=256,
@@ -181,6 +185,8 @@ class SamVisionConfig(PretrainedConfig):
hidden_size`.
"""

base_config_key = "vision_config"

def __init__(
self,
hidden_size=768,
@@ -278,6 +284,11 @@ class SamConfig(PretrainedConfig):
```"""

model_type = "sam"
sub_configs = {
"prompt_encoder_config": SamPromptEncoderConfig,
"mask_decoder_config": SamMaskDecoderConfig,
"vision_config": SamVisionConfig,
}

def __init__(
self,
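
The new sub_configs mapping and the per-class base_config_key attributes follow the composite-config convention introduced in PR #34410. A short sketch of what this is meant to enable, again assuming the public facebook/sam-vit-base checkpoint:

from transformers import SamConfig, SamVisionConfig

# The composite config exposes typed sub-configs.
config = SamConfig.from_pretrained("facebook/sam-vit-base")
print(type(config.vision_config).__name__)        # SamVisionConfig
print(type(config.mask_decoder_config).__name__)  # SamMaskDecoderConfig

# base_config_key tells a sub-config class which key of the composite config it maps to,
# so the sub-config can be loaded directly from the composite checkpoint.
print(SamVisionConfig.base_config_key)            # "vision_config"
vision_config = SamVisionConfig.from_pretrained("facebook/sam-vit-base")
print(vision_config.hidden_size)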
167 changes: 160 additions & 7 deletions src/transformers/models/sam/modeling_sam.py
@@ -246,6 +246,47 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarit
return out


class SamSdpaAttention(SamAttention):
"""
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values. Using SDPA instead of the default attention.
"""

def __init__(self, config, downsample_rate=None):
super().__init__(config, downsample_rate)

def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)

point_batch_size = query.shape[1]
# Separate into heads
query = self._separate_heads(query, self.num_attention_heads)
key = self._separate_heads(key, self.num_attention_heads)
value = self._separate_heads(value, self.num_attention_heads)

# Scaled dot product attention
attn_mask = None
if attention_similarity is not None:
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

# Get output
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)

return out


SAM_ATTENTION_CLASSES = {
"eager": SamAttention,
"sdpa": SamSdpaAttention,
}


class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
@@ -266,18 +307,21 @@ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps

self.self_attn = SamAttention(config, downsample_rate=1)
self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

self.mlp = SamMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)

self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe

def forward(
@@ -344,7 +388,7 @@ def __init__(self, config: SamMaskDecoderConfig):
for i in range(self.num_hidden_layers):
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))

self.final_attn_token_to_image = SamAttention(config)
self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

def forward(
@@ -431,7 +475,7 @@ def forward(self, hidden_states):
class SamMaskDecoder(nn.Module):
def __init__(self, config: SamMaskDecoderConfig):
super().__init__()

self.config = config
self.hidden_size = config.hidden_size

self.num_multimask_outputs = config.num_multimask_outputs
@@ -856,11 +900,118 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch
return outputs


class SamVisionSdpaAttention(SamVisionAttention):
"""
Multi-head Attention block with relative position embeddings.
Using SDPA instead of the default attention.
"""

def __init__(self, config, window_size):
super().__init__(config, window_size)

def add_decomposed_rel_pos(
self,
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
This method is reimplemented to follow the implementation in:
https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
This implementation is more memory efficient when using SDPA in the forward method.
Args:
query (Tensor): query tensor in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
rel_h (Tensor), rel_w (Tensor): decomposed relative position terms for the height and width axes, shaped for broadcasting into a single attention bias.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

batch_size, _, dim = query.shape
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)

return rel_h, rel_w

def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = (
self.qkv(hidden_states)
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
.permute(2, 0, 3, 1, 4)
)
# q, k, v with shape (B * nHead, H * W, C)
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

rel_h, rel_w, attn_bias = None, None, None  # attn_bias stays None when relative position embeddings are disabled
if self.use_rel_pos:
rel_h, rel_w = self.add_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)

query = query.view(batch_size, self.num_attention_heads, height * width, -1)
key = key.view(batch_size, self.num_attention_heads, height * width, -1)
value = value.view(batch_size, self.num_attention_heads, height * width, -1)

if self.use_rel_pos:
rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
attn_bias = (rel_h + rel_w).view(
batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
)
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)

attn_output = (
attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
.permute(0, 2, 3, 1, 4)
.reshape(batch_size, height, width, -1)
)

attn_output = self.proj(attn_output)

if output_attentions:
# For output_attentions, calculate the attention weights
attn_weights = (query @ key.transpose(-2, -1)) * self.scale
if attn_bias is not None:
attn_weights = attn_weights + attn_bias
attn_weights = F.softmax(attn_weights, dim=-1)
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)

return outputs


SAM_VISION_ATTENTION_CLASSES = {
"eager": SamVisionAttention,
"sdpa": SamVisionSdpaAttention,
}


class SamVisionLayer(nn.Module):
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = SamVisionAttention(config, window_size)
self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SamMLPBlock(config)
self.window_size = window_size
@@ -1071,6 +1222,8 @@ class SamPreTrainedModel(PreTrainedModel):
base_model_prefix = "sam"
main_input_name = "pixel_values"
_no_split_modules = ["SamVisionAttention"]
supports_gradient_checkpointing = True
_supports_sdpa = True

def _init_weights(self, module):
std = self.config.initializer_range
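
The reshaping in SamVisionSdpaAttention.add_decomposed_rel_pos exists so the two per-axis relative-position terms can be folded into a single additive attn_mask for scaled_dot_product_attention, instead of being added to a fully materialised attention matrix as in the eager path. A toy sketch of the broadcast (shapes only, not the model code; all names here are illustrative):

import torch
import torch.nn.functional as F

batch, heads, h, w, dim = 1, 2, 4, 4, 8
query = torch.randn(batch, heads, h * w, dim)
key = torch.randn(batch, heads, h * w, dim)
value = torch.randn(batch, heads, h * w, dim)

# Per-axis contributions with the shapes produced by add_decomposed_rel_pos
# (the height term broadcasts over key columns, the width term over key rows).
rel_h = torch.randn(batch, heads, h * w, h, 1)
rel_w = torch.randn(batch, heads, h * w, 1, w)

# Broadcast-add and flatten the key grid into one bias of shape (batch, heads, q_len, k_len).
attn_bias = (rel_h + rel_w).view(batch, heads, h * w, h * w)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
print(out.shape)  # torch.Size([1, 2, 16, 8])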
83 changes: 69 additions & 14 deletions tests/models/sam/test_modeling_sam.py
@@ -14,12 +14,13 @@
# limitations under the License.
"""Testing suite for the PyTorch SAM model."""

import tempfile
import unittest

import requests

from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
from transformers.testing_utils import cleanup, require_torch, slow, torch_device
from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
@@ -295,6 +296,7 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
test_torchscript = False
_is_composite = True

# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip(
@@ -311,22 +313,13 @@ def is_pipeline_test_to_skip(

def setUp(self):
self.model_tester = SamModelTester(self)
self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
self.prompt_encoder_config_tester = ConfigTester(
self,
config_class=SamPromptEncoderConfig,
has_text_modality=False,
num_attention_heads=12,
num_hidden_layers=2,
)
self.mask_decoder_config_tester = ConfigTester(
self, config_class=SamMaskDecoderConfig, has_text_modality=False
common_properties = ["initializer_range"]
self.config_tester = ConfigTester(
self, config_class=SamConfig, has_text_modality=False, common_properties=common_properties
)

def test_config(self):
self.vision_config_tester.run_common_tests()
self.prompt_encoder_config_tester.run_common_tests()
self.mask_decoder_config_tester.run_common_tests()
self.config_tester.run_common_tests()

@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
def test_inputs_embeds(self):
@@ -450,6 +443,68 @@ def test_model_from_pretrained(self):
model = SamModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM model can't be compiled dynamic yet")

@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
Tests that composite models dispatch correctly to SDPA or eager attention when requested at load time.
This test only looks at layer names, as SDPA layers are usually called "SdpaAttention".
In contrast to the test above, this one checks whether "config._attn_implementation" is set correctly after the
model is loaded, because we manually replicate the requested attn implementation on each sub-config when loading.
See https://github.com/huggingface/transformers/pull/32238 for more info.
The test tries to cover the most general cases of composite models (VLMs with vision and text configs). Any model
that has a different set of sub-configs has to overwrite this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")

if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa = model_sdpa.eval().to(torch_device)

model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)

# Root model determines SDPA support
attn_impl = "sdpa" if model._supports_sdpa else "eager"

# Check config propagation to submodels that support it
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.vision_encoder.config._attn_implementation == attn_impl)
self.assertTrue(model_sdpa.mask_decoder.config._attn_implementation == attn_impl)

self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.mask_decoder.config._attn_implementation == "eager")

# Verify SDPA/eager layer presence
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break

if not has_sdpa and attn_impl == "sdpa":
raise ValueError("The SDPA model should have SDPA attention layers")

for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")


def prepare_image():
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
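
Beyond the dispatch test above, a quick end-to-end sanity check one might run locally to compare the two attention paths (a sketch, not part of the test suite; assumes network access to the public facebook/sam-vit-base checkpoint):

import torch
from transformers import SamModel

pixel_values = torch.randn(1, 3, 1024, 1024)        # SAM expects 1024x1024 inputs
input_points = torch.tensor([[[[450.0, 600.0]]]])   # (batch, point_batch, n_points, 2)
input_labels = torch.tensor([[[1]]])                 # 1 = foreground point

model_eager = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="eager").eval()
model_sdpa = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = model_eager(pixel_values=pixel_values, input_points=input_points, input_labels=input_labels)
    out_sdpa = model_sdpa(pixel_values=pixel_values, input_points=input_points, input_labels=input_labels)

# The two backends should agree up to small numerical differences.
print(torch.allclose(out_eager.pred_masks, out_sdpa.pred_masks, atol=1e-4))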
(Diff for the fourth changed file is not rendered in this view.)
