Commit 7cdfdca: rebased files for modular dinov2

BernardZach committed Dec 5, 2024
1 parent b80e418 commit 7cdfdca

Showing 3 changed files with 117 additions and 88 deletions.
src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
@@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
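For context, generated files like this one are refreshed from the modular source rather than edited by hand; a sketch of the assumed converter invocation (the utility name and flag are assumptions, not confirmed by this diff):

```python
# Assumed regeneration workflow (run from the repository root):
#   python utils/modular_model_converter.py \
#       --files_to_parse src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
# The command rewrites the generated configuration/modeling files from the modular definition.
```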
# coding=utf-8
# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
#
@@ -20,18 +20,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from transformers import PretrainedConfig

from ...utils import (
logging,
)
from ...configuration_utils import PretrainedConfig
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


logger = logging.get_logger(__name__)


class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an
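A minimal usage sketch for this configuration class, assuming the standard Transformers config/model pattern that the truncated docstring presumably continues with:

```python
from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel

# Initializing a configuration with default values
configuration = Dinov2WithRegistersConfig()

# Initializing a randomly weighted model from that configuration
model = Dinov2WithRegistersModel(configuration)

# Accessing the model configuration
configuration = model.config
```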
src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
@@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
#
@@ -19,23 +19,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
BackboneOutput,
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
@@ -51,6 +44,12 @@

logger = logging.get_logger(__name__)

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base"

# General docstring
_CONFIG_FOR_DOC = "Dinov2WithRegistersConfig"


class Dinov2WithRegistersPatchEmbeddings(nn.Module):
"""
@@ -85,14 +84,6 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


# General docstring
_CONFIG_FOR_DOC = "Dinov2WithRegistersConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base"
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]


class Dinov2WithRegistersEmbeddings(nn.Module):
"""
Construct the CLS token, mask token, register tokens, position and patch embeddings.
@@ -103,11 +94,7 @@ def __init__(self, config: Dinov2WithRegistersConfig) -> None:

self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
self.register_tokens = (
nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
if config.num_register_tokens
else None
)
self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
@@ -144,7 +131,8 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
mode="bicubic",
align_corners=False,
antialias=self.config.interpolate_antialias,
).to(dtype=target_dtype)
)
patch_pos_embed = patch_pos_embed.to(dtype=target_dtype)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
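The two-step dtype handling above (interpolate, then cast back) can be sanity-checked in isolation; a toy sketch with assumed shapes, not the model's real configuration:

```python
import torch
import torch.nn.functional as F

# Position embeddings for a hypothetical 16x16 patch grid with hidden size 32
dim = 32
patch_pos_embed = torch.randn(1, 16, 16, dim).permute(0, 3, 1, 2)  # -> (1, dim, 16, 16)

target_dtype = patch_pos_embed.dtype
# Upcast for the bicubic resize, then cast back, mirroring the diff above
resized = F.interpolate(patch_pos_embed.float(), size=(20, 20), mode="bicubic", align_corners=False)
resized = resized.to(dtype=target_dtype)
resized = resized.permute(0, 2, 3, 1).view(1, -1, dim)  # -> (1, 400, dim)
```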
@@ -167,22 +155,15 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
# add positional encoding to each token
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

if self.register_tokens is not None:
embeddings = torch.cat(
(
embeddings[:, :1],
self.register_tokens.expand(embeddings.shape[0], -1, -1),
embeddings[:, 1:],
),
dim=1,
)
# add register tokens
embeddings = torch.cat((embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1)

embeddings = self.dropout(embeddings)

return embeddings
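The concatenation above places the register tokens between the CLS token and the patch tokens; a self-contained shape check with hypothetical sizes:

```python
import torch

batch, hidden, num_registers, num_patches = 2, 8, 4, 9
cls_and_patches = torch.randn(batch, 1 + num_patches, hidden)  # CLS + patch tokens, already pos-encoded
register_tokens = torch.zeros(1, num_registers, hidden)

out = torch.cat(
    (cls_and_patches[:, :1], register_tokens.expand(batch, -1, -1), cls_and_patches[:, 1:]),
    dim=1,
)
assert out.shape == (batch, 1 + num_registers + num_patches, hidden)  # [CLS, registers, patches]
```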


# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2WithRegisters
class Dinov2WithRegistersSelfAttention(nn.Module):
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__()
@@ -243,7 +224,47 @@ def forward(
return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2WithRegisters
class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention):
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions
)

mixed_query_layer = self.query(hidden_states)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

return context_layer, None
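The SDPA path above should be numerically equivalent to the manual implementation when no head mask or dropout is active; a standalone check (toy shapes, not tied to this model):

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 12, 257, 64)  # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

manual = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(q.size(-1)), dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
torch.testing.assert_close(manual, fused, atol=1e-5, rtol=1e-5)
```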


class Dinov2WithRegistersSelfOutput(nn.Module):
"""
The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the
@@ -262,7 +283,6 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2WithRegisters
class Dinov2WithRegistersAttention(nn.Module):
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__()
@@ -302,6 +322,12 @@ def forward(
return outputs


class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention):
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__(config)
self.attention = Dinov2WithRegistersSdpaSelfAttention(config)


class Dinov2WithRegistersLayerScale(nn.Module):
def __init__(self, config) -> None:
super().__init__()
@@ -380,14 +406,20 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
return self.weights_out(hidden)


DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = {
"eager": Dinov2WithRegistersAttention,
"sdpa": Dinov2WithRegistersSdpaAttention,
}
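A toy illustration of the dispatch this mapping enables (using the names defined in this file; the config object here is a stand-in):

```python
class _DummyConfig:
    _attn_implementation = "sdpa"  # populated by the loading machinery; "eager" selects the manual path

attention_cls = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[_DummyConfig._attn_implementation]
assert attention_cls is Dinov2WithRegistersSdpaAttention
```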


class Dinov2WithRegistersLayer(nn.Module):
"""This corresponds to the Block class in the original implementation."""

def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__()

self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = Dinov2WithRegistersAttention(config)
self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[config._attn_implementation](config)
self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
self.drop_path = (
Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -433,7 +465,6 @@ def forward(
return outputs


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2WithRegisters
class Dinov2WithRegistersEncoder(nn.Module):
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__()
@@ -485,10 +516,6 @@ def forward(
)


# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer"


class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -500,6 +527,7 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"]
_supports_sdpa = True
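With `_supports_sdpa` set, either attention path can be requested at load time; a sketch assuming the checkpoint named in `_CHECKPOINT_FOR_DOC` above is available on the Hub:

```python
from transformers import AutoModel

# Explicitly pick the fused or the manual attention implementation
model = AutoModel.from_pretrained("facebook/dinov2_with_registers-base", attn_implementation="sdpa")
```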

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
@@ -528,7 +556,7 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
).to(module.cls_token.dtype)


_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]


DINOV2_WITH_REGISTERS_START_DOCSTRING = r"""
@@ -542,28 +570,6 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`BitImageProcessor.preprocess`] for details.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -677,6 +683,33 @@ def forward(
)


# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`BitImageProcessor.preprocess`] for details.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
"""
Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state
@@ -787,14 +820,14 @@ class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)

self.num_register_tokens = config.num_register_tokens
self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
self.embeddings = Dinov2WithRegistersEmbeddings(config)
self.encoder = Dinov2WithRegistersEncoder(config)

self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

self.num_register_tokens = config.num_register_tokens

# Initialize weights and apply final processing
self.post_init()

@@ -814,6 +847,10 @@ def forward(
Returns:

Examples:
```python
>>> from transformers import AutoImageProcessor, AutoBackbone
src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
@@ -20,15 +20,15 @@
import torch.utils.checkpoint
from torch import nn

from transformers import PretrainedConfig
from transformers.models.dinov2.modeling_dinov2 import (
from ..dinov2.modeling_dinov2 import (
Dinov2Backbone,
Dinov2Encoder,
Dinov2ForImageClassification,
Dinov2Model,
Dinov2PatchEmbeddings,
)

from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import BackboneOutput
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
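These imports feed the modular pattern: the `WithRegisters` classes subclass their plain DINOv2 counterparts and override only what register tokens change. A sketch of that pattern (assumed shape of the file body, not its actual contents):

```python
# Reuse DINOv2 building blocks wholesale where nothing changes...
class Dinov2WithRegistersPatchEmbeddings(Dinov2PatchEmbeddings):
    pass

# ...and override only register-token-specific behaviour elsewhere
class Dinov2WithRegistersEncoder(Dinov2Encoder):
    pass
```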
