Sdpa dino v2 (#33403)
* add sdpa to dinov2

* fixup

* add dinov2 to sdpa doc

* update doc order

* [run-slow] dinov2

* common to eager

* [run-slow] dinov2

* update attn implementation in common

* update test_modeling_dinov2 to have mask_ratio, num_masks and mask_length similar to vit

* [run-slow] dinov2

---------

Co-authored-by: Avishai Elmakies <[email protected]>
avishaiElmakies and Avishai Elmakies authored Sep 21, 2024
1 parent e71bf70 commit 78b2929
Showing 3 changed files with 55 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/perf_infer_gpu_one.md
@@ -217,6 +217,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
@@ -275,7 +276,6 @@ For now, Transformers supports SDPA inference and training for the following arc
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)


<Tip>

FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
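With Dinov2 added to this list, the SDPA backend can be requested for DINOv2 the same way as for the other architectures. A minimal usage sketch, assuming the `facebook/dinov2-base` checkpoint is available:

```python
import torch
from transformers import AutoModel

# Load DINOv2 with PyTorch's scaled_dot_product_attention backend.
# The facebook/dinov2-base checkpoint is assumed here; any DINOv2 checkpoint works the same way.
model = AutoModel.from_pretrained("facebook/dinov2-base", attn_implementation="sdpa")
model.eval()

# dinov2-base uses 14x14 patches, so a 224x224 image yields 16*16 = 256 patch tokens + 1 [CLS] token.
pixel_values = torch.rand(1, 3, 224, 224)

with torch.inference_mode():
    outputs = model(pixel_values=pixel_values)

print(outputs.last_hidden_state.shape)  # torch.Size([1, 257, 768]) for dinov2-base
```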
48 changes: 47 additions & 1 deletion src/transformers/models/dinov2/modeling_dinov2.py
@@ -231,6 +231,38 @@ def forward(
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Dinov2
class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None


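The new `Dinov2SdpaSelfAttention.forward` above hands the `(batch, num_heads, seq_len, head_dim)` tensors produced by `transpose_for_scores` directly to `torch.nn.functional.scaled_dot_product_attention`, passing `head_mask` as the attention mask and returning `None` in place of attention probabilities. The fused call is numerically equivalent to the eager `softmax(QK^T / sqrt(d)) V`; a standalone sketch with illustrative base-model sizes (12 heads of dimension 64, not taken from this diff):

```python
import math
import torch
import torch.nn.functional as F

# Shapes follow transpose_for_scores: (batch, num_heads, seq_len, head_dim).
batch, num_heads, seq_len, head_dim = 2, 12, 257, 64
q, k, v = (torch.randn(batch, num_heads, seq_len, head_dim) for _ in range(3))

# Fused path, as in Dinov2SdpaSelfAttention (no mask, non-causal, dropout disabled at eval time).
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# Eager path: explicit softmax(Q K^T / sqrt(d)) V.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
eager_out = scores.softmax(dim=-1) @ v

print(torch.allclose(sdpa_out, eager_out, atol=1e-5))  # True, up to numerical tolerance
```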
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
class Dinov2SelfOutput(nn.Module):
"""
@@ -290,6 +322,13 @@ def forward(
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Dinov2
class Dinov2SdpaAttention(Dinov2Attention):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)
        self.attention = Dinov2SdpaSelfAttention(config)


class Dinov2LayerScale(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
@@ -371,14 +410,20 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.weights_out(hidden)


DINOV2_ATTENTION_CLASSES = {
    "eager": Dinov2Attention,
    "sdpa": Dinov2SdpaAttention,
}


class Dinov2Layer(nn.Module):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = Dinov2Attention(config)
        self.attention = DINOV2_ATTENTION_CLASSES[config._attn_implementation](config)
        self.layer_scale1 = Dinov2LayerScale(config)
        self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()

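The `DINOV2_ATTENTION_CLASSES` mapping above lets each `Dinov2Layer` pick its attention module from `config._attn_implementation`, which the config machinery resolves to `"eager"` or `"sdpa"`. A small sketch of that dispatch, using deliberately tiny sizes so instantiation is cheap (the tiny config values are illustrative, not from this PR; it assumes a PyTorch version with SDPA support):

```python
from transformers import Dinov2Config, Dinov2Model

# Tiny illustrative config; attn_implementation is the same kwarg the updated test below passes.
config = Dinov2Config(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    attn_implementation="sdpa",
)
model = Dinov2Model(config)

# Each layer's attention module was chosen from DINOV2_ATTENTION_CLASSES.
print(type(model.encoder.layer[0].attention).__name__)  # Dinov2SdpaAttention
```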
@@ -485,6 +530,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Dinov2SwiGLUFFN"]
    _supports_sdpa = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
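Because `_supports_sdpa = True` is now set on `Dinov2PreTrainedModel`, a sufficiently recent PyTorch resolves to SDPA automatically when no implementation is requested, while `"eager"` can still be forced. A quick way to check which backend a loaded model ended up with (again assuming the `facebook/dinov2-base` checkpoint):

```python
from transformers import Dinov2Model

# With no explicit request, recent PyTorch versions resolve to SDPA because _supports_sdpa is True.
model_auto = Dinov2Model.from_pretrained("facebook/dinov2-base")
print(model_auto.config._attn_implementation)   # "sdpa" (falls back to "eager" on older PyTorch)

# The eager implementation remains available on demand.
model_eager = Dinov2Model.from_pretrained("facebook/dinov2-base", attn_implementation="eager")
print(model_eager.config._attn_implementation)  # "eager"
```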
7 changes: 7 additions & 0 deletions tests/models/dinov2/test_modeling_dinov2.py
@@ -65,6 +65,8 @@ def __init__(
        type_sequence_label_size=10,
        initializer_range=0.02,
        scope=None,
        attn_implementation="eager",
        mask_ratio=0.5,
    ):
        self.parent = parent
        self.batch_size = batch_size
@@ -83,10 +85,14 @@ def __init__(
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.scope = scope
        self.attn_implementation = attn_implementation
        self.mask_ratio = mask_ratio

        # in Dinov2, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1
        self.num_masks = int(self.mask_ratio * self.seq_length)
        self.mask_length = num_patches

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -113,6 +119,7 @@ def get_config(self):
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            is_decoder=False,
            initializer_range=self.initializer_range,
            attn_implementation=self.attn_implementation,
        )

    def create_and_check_model(self, config, pixel_values, labels):
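The tester now derives `num_masks` and `mask_length` from `mask_ratio`, mirroring the ViT tester, so the shared masked-image-modeling tests can build a patch mask of the right size. A sketch of that arithmetic with illustrative numbers (the tester's actual defaults are not shown in this diff):

```python
import torch

# Illustrative values; the real defaults live in Dinov2ModelTester / ViTModelTester.
image_size, patch_size, mask_ratio, batch_size = 30, 2, 0.5, 13

# Mirrors the lines added to Dinov2ModelTester.__init__ above.
num_patches = (image_size // patch_size) ** 2   # 225 patch tokens
seq_length = num_patches + 1                    # +1 for the [CLS] token -> 226
num_masks = int(mask_ratio * seq_length)        # 113
mask_length = num_patches                       # the mask covers patch tokens only

# The kind of boolean patch mask such tests feed as bool_masked_pos:
# exactly num_masks of the mask_length positions are set per example.
base = torch.zeros(mask_length, dtype=torch.bool)
base[:num_masks] = True
bool_masked_pos = torch.stack([base[torch.randperm(mask_length)] for _ in range(batch_size)])
print(bool_masked_pos.shape, bool_masked_pos.sum(dim=1)[0].item())  # torch.Size([13, 225]) 113
```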
