diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 971ab3f15f74a5..c8c2f662eb42af 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -124,7 +124,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [DialoGPT](model_doc/dialogpt) | ✅ | ✅ | ✅ |
 | [DiNAT](model_doc/dinat) | ✅ | ❌ | ❌ |
 | [DINOv2](model_doc/dinov2) | ✅ | ❌ | ✅ |
-| [Dinov2WithRegisters](model_doc/dinov2_with_registers) | ✅ | ❌ | ❌ |
+| [Dinov2WithRegisters](model_doc/dinov2_with_registers) | ✅ | ❌ | ❌ |
 | [DistilBERT](model_doc/distilbert) | ✅ | ✅ | ✅ |
 | [DiT](model_doc/dit) | ✅ | ❌ | ✅ |
 | [DonutSwin](model_doc/donut) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index ab5e1c47a448f3..4096f1ada9d85b 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -231,6 +231,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
 * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
 * [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
+* [Dinov2_with_registers](https://huggingface.co/docs/transformers/en/model_doc/dinov2_with_registers)
 * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
 * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
 * [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel)
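The new SDPA entry can be exercised directly when loading the model. A minimal sketch; the checkpoint id mirrors the `_CHECKPOINT_FOR_DOC` constant added further below and is an assumption about the Hub, not something this diff verifies:

```python
# Opt into the SDPA attention path added by this PR.
# The checkpoint id follows `_CHECKPOINT_FOR_DOC` from the new modeling file
# and may differ from the actual Hub repository name.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "facebook/dinov2_with_registers-base",
    attn_implementation="sdpa",  # dispatches to Dinov2WithRegistersSdpaSelfAttention
    torch_dtype=torch.float16,
)
```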
diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py
index 97f76c07e968ee..85e1722aae324d 100644
--- a/src/transformers/commands/add_new_model_like.py
+++ b/src/transformers/commands/add_new_model_like.py
@@ -766,7 +766,7 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
            image_processor_class = image_processor_classes[0]  # we take the slow image processor class.
        else:
            image_processor_class = image_processor_classes
-
+
    feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None)
    processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None)
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index f7248b0f1e0d62..c8459efedf10e8 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -653,6 +653,7 @@
         ),
         ("dinat", "DinatForImageClassification"),
         ("dinov2", "Dinov2ForImageClassification"),
+        ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
         (
             "efficientformer",
             (
@@ -660,7 +661,6 @@
                 "EfficientFormerForImageClassificationWithTeacher",
             ),
         ),
-        ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
         ("efficientnet", "EfficientNetForImageClassification"),
         ("focalnet", "FocalNetForImageClassification"),
         ("hiera", "HieraForImageClassification"),
diff --git a/src/transformers/models/dinov2_with_registers/__init__.py b/src/transformers/models/dinov2_with_registers/__init__.py
index 59d0e0109c9a7c..e2260dda9c6106 100644
--- a/src/transformers/models/dinov2_with_registers/__init__.py
+++ b/src/transformers/models/dinov2_with_registers/__init__.py
@@ -54,4 +54,4 @@
 else:
     import sys
 
-    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
\ No newline at end of file
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
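With the image-classification mapping entry in place, the auto classes resolve the new architecture. A sketch; the checkpoint id reuses the `_IMAGE_CLASS_CHECKPOINT` constant from the modeling file and is not verified against the Hub here:

```python
# Resolve the new model through the auto API; the checkpoint id is illustrative.
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
    "facebook/dinov2_with_registers-small-imagenet1k-1-layer"
)
print(type(model).__name__)  # Dinov2WithRegistersForImageClassification
```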
diff --git a/src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
index 0a8cedd5c879e0..8c17f335c3bb47 100644
--- a/src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
+++ b/src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
@@ -1,9 +1,9 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from .
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
-# diff.py file directly.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
 #
@@ -19,19 +19,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
-from transformers import PretrainedConfig
-
-from ...utils import (
-    logging,
-)
+from ...configuration_utils import PretrainedConfig
 from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
-logger = logging.get_logger(__name__)
-
-
 class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an
@@ -169,4 +160,4 @@ def __init__(
             out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
         )
         self.apply_layernorm = apply_layernorm
-        self.reshape_hidden_states = reshape_hidden_states
\ No newline at end of file
+        self.reshape_hidden_states = reshape_hidden_states
diff --git a/src/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py b/src/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py
index 96753b9fa1ce32..526aef1725473c 100644
--- a/src/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py
+++ b/src/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py
@@ -288,4 +288,4 @@ def convert_dinov2_with_registers_checkpoint(model_name, pytorch_dump_folder_pat
     )
 
     args = parser.parse_args()
-    convert_dinov2_with_registers_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
\ No newline at end of file
+    convert_dinov2_with_registers_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
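The configuration now pulls `PretrainedConfig` through the library-internal path. Its distinguishing field is `num_register_tokens`; a minimal sketch, assuming the new classes are exported at the top level as the `__init__.py` lazy-module wiring suggests (the default of 4 registers follows the paper, not this diff):

```python
# Build a randomly initialized model from the new config.
from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel

config = Dinov2WithRegistersConfig(num_register_tokens=4)
model = Dinov2WithRegistersModel(config)
```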
diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
index bd9a4d76f71d5c..7578ce1490d3a1 100644
--- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
+++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
@@ -1,9 +1,9 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from .
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the diff. If any change should be done, please apply the change to the
-# diff.py file directly.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
 #
@@ -19,23 +19,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import collections.abc
 import math
 from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_outputs import (
-    BackboneOutput,
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    ImageClassifierOutput,
-)
+from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
@@ -51,6 +44,12 @@
 logger = logging.get_logger(__name__)
 
 
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base"
+
+# General docstring
+_CONFIG_FOR_DOC = "Dinov2WithRegistersConfig"
+
 
 class Dinov2WithRegistersPatchEmbeddings(nn.Module):
     """
@@ -85,14 +84,6 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings
 
 
-# General docstring
-_CONFIG_FOR_DOC = "Dinov2WithRegistersConfig"
-
-# Base docstring
-_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base"
-_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
-
-
 class Dinov2WithRegistersEmbeddings(nn.Module):
     """
     Construct the CLS token, mask token, register tokens, position and patch embeddings.
@@ -103,11 +94,7 @@ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
 
         self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
-        self.register_tokens = (
-            nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
-            if config.num_register_tokens
-            else None
-        )
+        self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
         self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
@@ -144,7 +131,8 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
             mode="bicubic",
             align_corners=False,
             antialias=self.config.interpolate_antialias,
-        ).to(dtype=target_dtype)
+        )
+        patch_pos_embed = patch_pos_embed.to(dtype=target_dtype)
         if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
             raise ValueError("Width or height does not match with the interpolated position embeddings")
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
@@ -167,22 +155,16 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Te
         # add positional encoding to each token
         embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
 
-        if self.register_tokens is not None:
-            embeddings = torch.cat(
-                (
-                    embeddings[:, :1],
-                    self.register_tokens.expand(embeddings.shape[0], -1, -1),
-                    embeddings[:, 1:],
-                ),
-                dim=1,
-            )
+        # add register tokens: [CLS] | registers | patches
+        embeddings = torch.cat(
+            (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
+        )
 
         embeddings = self.dropout(embeddings)
 
         return embeddings
 
 
-# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2WithRegisters
 class Dinov2WithRegistersSelfAttention(nn.Module):
     def __init__(self, config: Dinov2WithRegistersConfig) -> None:
         super().__init__()
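Note that the register-token concatenation has to stay a single `torch.cat`: splitting it into two sequential calls that reassign `embeddings` would make the second `embeddings[:, 1:]` slice the already-prepended registers and silently drop the patch tokens. A toy shape check of the intended layout (all sizes are illustrative, not from the diff):

```python
# Verify the token layout [CLS | register tokens | patch tokens].
import torch

batch, num_registers, num_patches, dim = 2, 4, 256, 768
embeddings = torch.randn(batch, 1 + num_patches, dim)  # CLS + patch tokens
register_tokens = torch.zeros(1, num_registers, dim)

out = torch.cat(
    (embeddings[:, :1], register_tokens.expand(batch, -1, -1), embeddings[:, 1:]), dim=1
)
assert out.shape == (batch, 1 + num_registers + num_patches, dim)
```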
@@ -243,7 +224,47 @@ def forward(
         return outputs
 
 
-# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2WithRegisters
+class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention):
+    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+        super().__init__(config)
+        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions
+            )
+
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            self.attention_probs_dropout_prob if self.training else 0.0,
+            is_causal=False,
+            scale=None,
+        )
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, None
+
+
 class Dinov2WithRegistersSelfOutput(nn.Module):
     """
     The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the
@@ -262,7 +283,6 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
         return hidden_states
 
 
-# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2WithRegisters
 class Dinov2WithRegistersAttention(nn.Module):
     def __init__(self, config: Dinov2WithRegistersConfig) -> None:
         super().__init__()
@@ -302,6 +322,12 @@ def forward(
         return outputs
 
 
+class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention):
+    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+        super().__init__(config)
+        self.attention = Dinov2WithRegistersSdpaSelfAttention(config)
+
+
 class Dinov2WithRegistersLayerScale(nn.Module):
     def __init__(self, config) -> None:
         super().__init__()
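The SDPA class replaces the eager softmax-attention math with the fused kernel but computes the same result. A minimal sketch of that equivalence (shapes are illustrative; at eval time dropout is 0 as in the class above):

```python
# Fused SDPA matches eager softmax(QK^T / sqrt(d)) @ V for non-causal attention.
import torch
import torch.nn.functional as F

q = torch.randn(1, 12, 261, 64)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
manual = scores.softmax(dim=-1) @ v
torch.testing.assert_close(fused, manual, atol=1e-5, rtol=1e-5)
```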
@@ -321,7 +347,6 @@ def __init__(self, drop_prob: Optional[float] = None) -> None:
 def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
     """
     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-
     Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
     however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
@@ -380,6 +405,12 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
         return self.weights_out(hidden)
 
 
+DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = {
+    "eager": Dinov2WithRegistersAttention,
+    "sdpa": Dinov2WithRegistersSdpaAttention,
+}
+
+
 class Dinov2WithRegistersLayer(nn.Module):
     """This corresponds to the Block class in the original implementation."""
 
@@ -387,7 +418,7 @@ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
         super().__init__()
 
         self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = Dinov2WithRegistersAttention(config)
+        self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[config._attn_implementation](config)
         self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
         self.drop_path = (
             Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -433,7 +464,6 @@ def forward(
         return outputs
 
 
-# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2WithRegisters
 class Dinov2WithRegistersEncoder(nn.Module):
     def __init__(self, config: Dinov2WithRegistersConfig) -> None:
         super().__init__()
@@ -485,10 +515,6 @@ def forward(
         )
 
 
-# Image classification docstring
-_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer"
-
-
 class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -500,6 +526,7 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"]
+    _supports_sdpa = True
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
@@ -528,7 +555,7 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No
         ).to(module.cls_token.dtype)
 
 
-_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
 
 
 DINOV2_WITH_REGISTERS_START_DOCSTRING = r"""
@@ -542,28 +569,6 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No
     configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 """
 
-DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r"""
-    Args:
-        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
-            [`BitImageProcessor.preprocess`] for details.
-
-        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
 DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r"""
     Args:
         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
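The `DINOV2_WITH_REGISTERS_ATTENTION_CLASSES` dict plus `_supports_sdpa = True` is what lets `attn_implementation` select the backend at load time. A toy illustration of the dispatch pattern (not the library code):

```python
# The layer looks its attention class up by config._attn_implementation
# instead of hardcoding the eager implementation.
class EagerAttention:
    pass


class SdpaAttention(EagerAttention):
    pass


ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}
attn_implementation = "sdpa"  # resolved by PreTrainedModel when `_supports_sdpa = True`
attention = ATTENTION_CLASSES[attn_implementation]()
assert isinstance(attention, SdpaAttention)
```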
-""" - DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): @@ -677,6 +682,33 @@ def forward( ) +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + +DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + @add_start_docstrings( """ Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state @@ -787,14 +819,14 @@ class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMi def __init__(self, config): super().__init__(config) super()._init_backbone(config) - - self.num_register_tokens = config.num_register_tokens self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] self.embeddings = Dinov2WithRegistersEmbeddings(config) self.encoder = Dinov2WithRegistersEncoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.num_register_tokens = config.num_register_tokens + # Initialize weights and apply final processing self.post_init() @@ -814,6 +846,10 @@ def forward( Returns: Examples: + Returns: + + Examples: + ```python >>> from transformers import AutoImageProcessor, AutoBackbone @@ -876,4 +912,4 @@ def forward( feature_maps=feature_maps, hidden_states=outputs.hidden_states if output_hidden_states else None, attentions=outputs.attentions if output_attentions else None, - ) \ No newline at end of file + ) diff --git a/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py index be59410a886f7e..12450ea348cf7a 100644 --- a/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py @@ -20,15 +20,14 @@ import torch.utils.checkpoint from torch import nn -from transformers import PretrainedConfig -from transformers.models.dinov2.modeling_dinov2 import ( +from ....transformers.models.dinov2.modeling_dinov2 import ( Dinov2Backbone, Dinov2Encoder, Dinov2ForImageClassification, Dinov2Model, Dinov2PatchEmbeddings, ) - +from ...configuration_utils import PretrainedConfig from ...modeling_outputs import BackboneOutput from ...utils import logging from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices @@ -364,4 +363,4 @@ def forward( 
@@ -364,4 +363,4 @@ def forward(
             feature_maps=feature_maps,
             hidden_states=outputs.hidden_states if output_hidden_states else None,
             attentions=outputs.attentions if output_attentions else None,
-        )
\ No newline at end of file
+        )
diff --git a/tests/models/dinov2_with_registers/test_modeling_dinov2_with_registers.py b/tests/models/dinov2_with_registers/test_modeling_dinov2_with_registers.py
index bf27ceac5fcfc2..9c42555af68a2b 100644
--- a/tests/models/dinov2_with_registers/test_modeling_dinov2_with_registers.py
+++ b/tests/models/dinov2_with_registers/test_modeling_dinov2_with_registers.py
@@ -359,4 +359,4 @@ class Dinov2WithRegistersBackboneTest(unittest.TestCase, BackboneTesterMixin):
     has_attentions = False
 
     def setUp(self):
-        self.model_tester = Dinov2WithRegistersModelTester(self)
\ No newline at end of file
+        self.model_tester = Dinov2WithRegistersModelTester(self)
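An end-to-end smoke test of the new model; a sketch assuming the `_CHECKPOINT_FOR_DOC` id exists on the Hub (the actual repository name may differ):

```python
# Run one image through the model; the sequence is [CLS] + register tokens + patch tokens.
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/dinov2_with_registers-base")
model = AutoModel.from_pretrained("facebook/dinov2_with_registers-base")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```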