From 78b2929c0554b79e0489b451ce4ece14d265ead2 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies <36810152+avishaiElmakies@users.noreply.github.com> Date: Sat, 21 Sep 2024 03:58:00 +0300 Subject: [PATCH] Sdpa dino v2 (#33403) * add sdpa to dinov2 * fixup * add dinov2 to sdpa doc * update doc order * [run-slow] dinov2 * common to eager * [run-slow] dinov2 * update attn implementation in common * update test_modeling_dinov2 to have mask_ration, num_masks and mask_length similar to vit * [run-slow] dinov2 --------- Co-authored-by: Avishai Elmakies --- docs/source/en/perf_infer_gpu_one.md | 2 +- .../models/dinov2/modeling_dinov2.py | 48 ++++++++++++++++++- tests/models/dinov2/test_modeling_dinov2.py | 7 +++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 7ecfe46dd22b01..193af845da659d 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -217,6 +217,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel) +* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2) * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) @@ -275,7 +276,6 @@ For now, Transformers supports SDPA inference and training for the following arc * [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel) * [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel) - FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models. diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 160c5ae69f394a..ebe322618a2d8f 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -231,6 +231,38 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Dinov2 +class Dinov2SdpaSelfAttention(Dinov2SelfAttention): + def __init__(self, config: Dinov2Config) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 class Dinov2SelfOutput(nn.Module): """ @@ -290,6 +322,13 @@ def forward( return outputs +# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Dinov2 +class Dinov2SdpaAttention(Dinov2Attention): + def __init__(self, config: Dinov2Config) -> None: + super().__init__(config) + self.attention = Dinov2SdpaSelfAttention(config) + + class Dinov2LayerScale(nn.Module): def __init__(self, config) -> None: super().__init__() @@ -371,6 +410,12 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: return self.weights_out(hidden) +DINOV2_ATTENTION_CLASSES = { + "eager": Dinov2Attention, + "sdpa": Dinov2SdpaAttention, +} + + class Dinov2Layer(nn.Module): """This corresponds to the Block class in the original implementation.""" @@ -378,7 +423,7 @@ def __init__(self, config: Dinov2Config) -> None: super().__init__() self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = Dinov2Attention(config) + self.attention = DINOV2_ATTENTION_CLASSES[config._attn_implementation](config) self.layer_scale1 = Dinov2LayerScale(config) self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() @@ -485,6 +530,7 @@ class Dinov2PreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["Dinov2SwiGLUFFN"] + _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/tests/models/dinov2/test_modeling_dinov2.py b/tests/models/dinov2/test_modeling_dinov2.py index 88be7061cf45bb..5caa3baec1a2ca 100644 --- a/tests/models/dinov2/test_modeling_dinov2.py +++ b/tests/models/dinov2/test_modeling_dinov2.py @@ -65,6 +65,8 @@ def __init__( type_sequence_label_size=10, initializer_range=0.02, scope=None, + attn_implementation="eager", + mask_ratio=0.5, ): self.parent = parent self.batch_size = batch_size @@ -83,10 +85,14 @@ def __init__( self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range self.scope = scope + self.attn_implementation = attn_implementation + self.mask_ratio = mask_ratio # in Dinov2, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 self.seq_length = num_patches + 1 + self.num_masks = int(self.mask_ratio * self.seq_length) + self.mask_length = num_patches def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -113,6 +119,7 @@ def get_config(self): attention_probs_dropout_prob=self.attention_probs_dropout_prob, is_decoder=False, initializer_range=self.initializer_range, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels):