diff --git a/docs/source/en/index.md b/docs/source/en/index.md index abbbcfe7414d12..912bbad1d2d5ea 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -275,7 +275,7 @@ Flax), PyTorch, and/or TensorFlow. | [StableLm](model_doc/stablelm) | ✅ | ❌ | ❌ | | [Starcoder2](model_doc/starcoder2) | ✅ | ❌ | ❌ | | [SuperPoint](model_doc/superpoint) | ✅ | ❌ | ❌ | -| [SwiftFormer](model_doc/swiftformer) | ✅ | ❌ | ❌ | +| [SwiftFormer](model_doc/swiftformer) | ✅ | ✅ | ❌ | | [Swin Transformer](model_doc/swin) | ✅ | ✅ | ❌ | | [Swin Transformer V2](model_doc/swinv2) | ✅ | ❌ | ❌ | | [Swin2SR](model_doc/swin2sr) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/swiftformer.md b/docs/source/en/model_doc/swiftformer.md index 30c6941f0f46da..319c79fce4fbec 100644 --- a/docs/source/en/model_doc/swiftformer.md +++ b/docs/source/en/model_doc/swiftformer.md @@ -26,7 +26,7 @@ The abstract from the paper is the following: *Self-attention has become a defacto choice for capturing global context in various vision applications. However, its quadratic computational complexity with respect to image resolution limits its use in real-time applications, especially for deployment on resource-constrained mobile devices. Although hybrid approaches have been proposed to combine the advantages of convolutions and self-attention for a better speed-accuracy trade-off, the expensive matrix multiplication operations in self-attention remain a bottleneck. In this work, we introduce a novel efficient additive attention mechanism that effectively replaces the quadratic matrix multiplication operations with linear element-wise multiplications. Our design shows that the key-value interaction can be replaced with a linear layer without sacrificing any accuracy. Unlike previous state-of-the-art methods, our efficient formulation of self-attention enables its usage at all stages of the network. Using our proposed efficient additive attention, we build a series of models called "SwiftFormer" which achieves state-of-the-art performance in terms of both accuracy and mobile inference speed. Our small variant achieves 78.5% top-1 ImageNet-1K accuracy with only 0.8 ms latency on iPhone 14, which is more accurate and 2x faster compared to MobileViT-v2.* -This model was contributed by [shehan97](https://huggingface.co/shehan97). +This model was contributed by [shehan97](https://huggingface.co/shehan97). The TensorFlow version was contributed by [joaocmd](https://huggingface.co/joaocmd). The original code can be found [here](https://github.com/Amshaker/SwiftFormer). ## SwiftFormerConfig @@ -42,3 +42,13 @@ The original code can be found [here](https://github.com/Amshaker/SwiftFormer). 
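The efficient additive attention described in the abstract never forms an N×N attention matrix: each token gets a scalar score, the scores pool the queries into a single global query vector, and the query–key interaction is element-wise. Below is a shapes-only sketch of that computation, following the module added in this PR; the random matrices are stand-ins for the learned `Dense` projections and the `w_g` vector, not the real weights.

```python
import tensorflow as tf

# Shapes-only sketch of SwiftFormer's efficient additive attention: linear in the
# number of tokens N, no N x N attention matrix. Random tensors stand in for the
# learned projections (to_query, to_key, proj, final) and the w_g vector.
batch, num_tokens, dim = 2, 49, 64
x = tf.random.normal((batch, num_tokens, dim))

w_query = tf.random.normal((dim, dim))
w_key = tf.random.normal((dim, dim))
w_g = tf.random.normal((dim, 1))
w_proj = tf.random.normal((dim, dim))
w_final = tf.random.normal((dim, dim))

query = tf.math.l2_normalize(x @ w_query, axis=-1)  # (B, N, d)
key = tf.math.l2_normalize(x @ w_key, axis=-1)      # (B, N, d)

# Per-token scalar scores, then one pooled "global query" vector per sample.
scores = tf.nn.softmax((query @ w_g) * dim**-0.5, axis=-1)  # (B, N, 1)
global_query = tf.reduce_sum(scores * query, axis=1)        # (B, d)
global_query = tf.tile(global_query[:, None, :], (1, num_tokens, 1))

# Element-wise query-key interaction replaces the quadratic QK^T matmul.
out = (global_query * key) @ w_proj + query  # (B, N, d)
out = out @ w_final
print(out.shape)  # (2, 49, 64)
```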
[[autodoc]] SwiftFormerForImageClassification - forward + +## TFSwiftFormerModel + +[[autodoc]] TFSwiftFormerModel + - call + +## TFSwiftFormerForImageClassification + +[[autodoc]] TFSwiftFormerForImageClassification + - call diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e81e718b2b26d8..c07e3d8f1b7f8f 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -4517,6 +4517,14 @@ "TFSpeech2TextPreTrainedModel", ] ) + _import_structure["models.swiftformer"].extend( + [ + "TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFSwiftFormerForImageClassification", + "TFSwiftFormerModel", + "TFSwiftFormerPreTrainedModel", + ] + ) _import_structure["models.swin"].extend( [ "TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -8901,6 +8909,12 @@ TFSpeech2TextModel, TFSpeech2TextPreTrainedModel, ) + from .models.swiftformer import ( + TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + TFSwiftFormerForImageClassification, + TFSwiftFormerModel, + TFSwiftFormerPreTrainedModel, + ) from .models.swin import ( TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, TFSwinForImageClassification, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index deed743162e477..a3df614b9b7922 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -81,6 +81,7 @@ ("sam", "TFSamModel"), ("segformer", "TFSegformerModel"), ("speech_to_text", "TFSpeech2TextModel"), + ("swiftformer", "TFSwiftFormerModel"), ("swin", "TFSwinModel"), ("t5", "TFT5Model"), ("tapas", "TFTapasModel"), @@ -213,6 +214,7 @@ ("regnet", "TFRegNetForImageClassification"), ("resnet", "TFResNetForImageClassification"), ("segformer", "TFSegformerForImageClassification"), + ("swiftformer", "TFSwiftFormerForImageClassification"), ("swin", "TFSwinForImageClassification"), ("vit", "TFViTForImageClassification"), ] diff --git a/src/transformers/models/swiftformer/__init__.py b/src/transformers/models/swiftformer/__init__.py index ddba2b806fd168..b324ea174d551b 100644 --- a/src/transformers/models/swiftformer/__init__.py +++ b/src/transformers/models/swiftformer/__init__.py @@ -16,6 +16,7 @@ from ...utils import ( OptionalDependencyNotAvailable, _LazyModule, + is_tf_available, is_torch_available, ) @@ -41,6 +42,19 @@ "SwiftFormerPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_swiftformer"] = [ + "TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFSwiftFormerForImageClassification", + "TFSwiftFormerModel", + "TFSwiftFormerPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_swiftformer import ( SWIFTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -60,6 +74,18 @@ SwiftFormerModel, SwiftFormerPreTrainedModel, ) + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_swiftformer import ( + TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + TFSwiftFormerForImageClassification, + TFSwiftFormerModel, + TFSwiftFormerPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/swiftformer/configuration_swiftformer.py b/src/transformers/models/swiftformer/configuration_swiftformer.py index 3c7a9eebbd9101..3789c72d421fb3 100644 --- a/src/transformers/models/swiftformer/configuration_swiftformer.py +++ b/src/transformers/models/swiftformer/configuration_swiftformer.py @@ -42,6 +42,8 @@ 
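With the new classes exported and registered in the TF auto-mappings above, inference looks the same as for the PyTorch model. A minimal sketch, with the checkpoint name taken from this PR's docstrings; pass `from_pt=True` to `from_pretrained` if the hub checkpoint only stores PyTorch weights.

```python
import requests
import tensorflow as tf
from PIL import Image

from transformers import AutoImageProcessor, TFSwiftFormerForImageClassification

checkpoint = "MBZUAI/swiftformer-xs"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = TFSwiftFormerForImageClassification.from_pretrained(checkpoint)  # add from_pt=True if no TF weights

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="tf")
logits = model(**inputs).logits  # (1, 1000) for the ImageNet-1k head

predicted_class = int(tf.math.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted_class])
```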
class SwiftFormerConfig(PretrainedConfig): Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image num_channels (`int`, *optional*, defaults to 3): The number of input channels depths (`List[int]`, *optional*, defaults to `[3, 3, 6, 4]`): @@ -62,6 +64,10 @@ class SwiftFormerConfig(PretrainedConfig): Padding in downsampling layers. drop_path_rate (`float`, *optional*, defaults to 0.0): Rate at which to increase dropout probability in DropPath. + drop_mlp_rate (`float`, *optional*, defaults to 0.0): + Dropout rate for the MLP component of SwiftFormer. + drop_conv_encoder_rate (`float`, *optional*, defaults to 0.0): + Dropout rate for the ConvEncoder component of SwiftFormer. use_layer_scale (`bool`, *optional*, defaults to `True`): Whether to scale outputs from token mixers. layer_scale_init_value (`float`, *optional*, defaults to 1e-05): @@ -89,6 +95,7 @@ class SwiftFormerConfig(PretrainedConfig): def __init__( self, + image_size=224, num_channels=3, depths=[3, 3, 6, 4], embed_dims=[48, 56, 112, 220], @@ -99,12 +106,15 @@ def __init__( down_stride=2, down_pad=1, drop_path_rate=0.0, + drop_mlp_rate=0.0, + drop_conv_encoder_rate=0.0, use_layer_scale=True, layer_scale_init_value=1e-5, batch_norm_eps=1e-5, **kwargs, ): super().__init__(**kwargs) + self.image_size = image_size self.num_channels = num_channels self.depths = depths self.embed_dims = embed_dims @@ -115,6 +125,8 @@ def __init__( self.down_stride = down_stride self.down_pad = down_pad self.drop_path_rate = drop_path_rate + self.drop_mlp_rate = drop_mlp_rate + self.drop_conv_encoder_rate = drop_conv_encoder_rate self.use_layer_scale = use_layer_scale self.layer_scale_init_value = layer_scale_init_value self.batch_norm_eps = batch_norm_eps diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index 0455a31641db37..970874423a3e3c 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -103,13 +103,12 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals return output -# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swiftformer class SwiftFormerDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob: Optional[float] = None) -> None: + def __init__(self, config: SwiftFormerConfig) -> None: super().__init__() - self.drop_prob = drop_prob + self.drop_prob = config.drop_path_rate def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return drop_path(hidden_states, self.drop_prob, self.training) @@ -169,7 +168,7 @@ def __init__(self, config: SwiftFormerConfig, dim: int): self.point_wise_conv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1) self.act = nn.GELU() self.point_wise_conv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1) - self.drop_path = nn.Identity() + self.drop_path = nn.Dropout(p=config.drop_conv_encoder_rate) self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True) def forward(self, x): @@ -200,7 +199,7 @@ def __init__(self, config: SwiftFormerConfig, in_features: int): act_layer = ACT2CLS[config.hidden_act] self.act = act_layer() self.fc2 = nn.Conv2d(hidden_features, in_features, 1) - self.drop = nn.Dropout(p=0.0) + self.drop = nn.Dropout(p=config.drop_mlp_rate) def forward(self, x): x = self.norm1(x) @@ -302,7 +301,7 @@ def 
__init__(self, config: SwiftFormerConfig, dim: int, drop_path: float = 0.0) self.local_representation = SwiftFormerLocalRepresentation(config, dim=dim) self.attn = SwiftFormerEfficientAdditiveAttention(config, dim=dim) self.linear = SwiftFormerMlp(config, in_features=dim) - self.drop_path = SwiftFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.drop_path = SwiftFormerDropPath(config) if drop_path > 0.0 else nn.Identity() self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale_1 = nn.Parameter( @@ -315,21 +314,13 @@ def __init__(self, config: SwiftFormerConfig, dim: int, drop_path: float = 0.0) def forward(self, x): x = self.local_representation(x) batch_size, channels, height, width = x.shape + res = self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)) + res = res.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) if self.use_layer_scale: - x = x + self.drop_path( - self.layer_scale_1 - * self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)) - .reshape(batch_size, height, width, channels) - .permute(0, 3, 1, 2) - ) + x = x + self.drop_path(self.layer_scale_1 * res) x = x + self.drop_path(self.layer_scale_2 * self.linear(x)) - else: - x = x + self.drop_path( - self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)) - .reshape(batch_size, height, width, channels) - .permute(0, 3, 1, 2) - ) + x = x + self.drop_path(res) x = x + self.drop_path(self.linear(x)) return x diff --git a/src/transformers/models/swiftformer/modeling_tf_swiftformer.py b/src/transformers/models/swiftformer/modeling_tf_swiftformer.py new file mode 100644 index 00000000000000..ce8bf2452559c9 --- /dev/null +++ b/src/transformers/models/swiftformer/modeling_tf_swiftformer.py @@ -0,0 +1,870 @@ +# coding=utf-8 +# Copyright 2024 MBZUAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
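The PyTorch changes above thread the two new config fields, `drop_conv_encoder_rate` and `drop_mlp_rate`, into `SwiftFormerConvEncoder` and `SwiftFormerMlp` (both previously hard-coded to zero dropout), and `SwiftFormerDropPath` now reads `drop_path_rate` from the config. A quick sketch of setting them; the values below are purely illustrative, not recommendations.

```python
from transformers import SwiftFormerConfig, SwiftFormerForImageClassification

# Illustrative regularization settings; all three rates default to 0.0.
config = SwiftFormerConfig(
    image_size=224,               # new field, default 224
    drop_path_rate=0.1,           # stochastic depth in the encoder blocks
    drop_conv_encoder_rate=0.05,  # dropout inside SwiftFormerConvEncoder
    drop_mlp_rate=0.05,           # dropout inside SwiftFormerMlp
)
model = SwiftFormerForImageClassification(config)
```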
+""" TensorFlow SwiftFormer model.""" + + +import collections.abc +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithNoAttention, + TFImageClassifierOutputWithNoAttention, +) +from ...modeling_tf_utils import TFPreTrainedModel, keras, keras_serializable, unpack_inputs +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_swiftformer import SwiftFormerConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "SwiftFormerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "MBZUAI/swiftformer-xs" +_EXPECTED_OUTPUT_SHAPE = [1, 220, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "MBZUAI/swiftformer-xs" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "MBZUAI/swiftformer-xs", + # See all SwiftFormer models at https://huggingface.co/models?filter=swiftformer +] + + +class TFSwiftFormerPatchEmbeddingSequential(keras.layers.Layer): + """ + The sequential component of the patch embedding layer. + + Input: tensor of shape `[batch_size, in_channels, height, width]` + + Output: tensor of shape `[batch_size, out_channels, height/4, width/4]` + """ + + def __init__(self, config: SwiftFormerConfig, **kwargs): + super().__init__(**kwargs) + self.out_chs = config.embed_dims[0] + + self.zero_padding = keras.layers.ZeroPadding2D(padding=(1, 1)) + self.conv1 = keras.layers.Conv2D(self.out_chs // 2, kernel_size=3, strides=2, name="0") + self.batch_norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="1") + self.conv2 = keras.layers.Conv2D(self.out_chs, kernel_size=3, strides=2, name="3") + self.batch_norm2 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="4") + self.config = config + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + x = self.zero_padding(x) + x = self.conv1(x) + x = self.batch_norm1(x, training=training) + x = get_tf_activation("relu")(x) + x = self.zero_padding(x) + x = self.conv2(x) + x = self.batch_norm2(x, training=training) + x = get_tf_activation("relu")(x) + return x + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "conv1", None) is not None: + with tf.name_scope(self.conv1.name): + self.conv1.build(self.config.num_channels) + if getattr(self, "batch_norm1", None) is not None: + with tf.name_scope(self.batch_norm1.name): + self.batch_norm1.build((None, None, None, self.out_chs // 2)) + if getattr(self, "conv2", None) is not None: + with tf.name_scope(self.conv2.name): + self.conv2.build((None, None, None, self.out_chs // 2)) + if getattr(self, "batch_norm2", None) is not None: + with tf.name_scope(self.batch_norm2.name): + self.batch_norm2.build((None, None, None, self.out_chs)) + self.built = True + + +class TFSwiftFormerPatchEmbedding(keras.layers.Layer): + """ + Patch Embedding Layer constructed of two 2D convolutional layers. 
+ + Input: tensor of shape `[batch_size, in_channels, height, width]` + + Output: tensor of shape `[batch_size, out_channels, height/4, width/4]` + """ + + def __init__(self, config: SwiftFormerConfig, **kwargs): + super().__init__(**kwargs) + self.patch_embedding = TFSwiftFormerPatchEmbeddingSequential(config, name="patch_embedding") + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + return self.patch_embedding(x, training=training) + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "patch_embedding", None) is not None: + with tf.name_scope(self.patch_embedding.name): + self.patch_embedding.build(None) + self.built = True + + +class TFSwiftFormerDropPath(keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, config: SwiftFormerConfig, **kwargs) -> None: + super().__init__(**kwargs) + raise NotImplementedError("Drop path is not implemented in TF port") + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + raise NotImplementedError("Drop path is not implemented in TF port") + + +class TFSwiftFormerEmbeddings(keras.layers.Layer): + """ + Embeddings layer consisting of a single 2D convolutional and batch normalization layer. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height/stride, width/stride]` + """ + + def __init__(self, config: SwiftFormerConfig, index: int, **kwargs): + super().__init__(**kwargs) + + patch_size = config.down_patch_size + stride = config.down_stride + padding = config.down_pad + embed_dims = config.embed_dims + + self.in_chans = embed_dims[index] + self.embed_dim = embed_dims[index + 1] + + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride) + padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding) + + self.pad = keras.layers.ZeroPadding2D(padding=padding) + self.proj = keras.layers.Conv2D(self.embed_dim, kernel_size=patch_size, strides=stride, name="proj") + self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + x = self.pad(x) + x = self.proj(x) + x = self.norm(x, training=training) + return x + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "proj", None) is not None: + with tf.name_scope(self.proj.name): + self.proj.build(self.in_chans) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build((None, None, None, self.embed_dim)) + self.built = True + + +class TFSwiftFormerConvEncoder(keras.layers.Layer): + """ + `SwiftFormerConvEncoder` with 3*3 and 1*1 convolutions. 
+ + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs): + super().__init__(**kwargs) + hidden_dim = int(config.mlp_ratio * dim) + + self.dim = dim + self.pad = keras.layers.ZeroPadding2D(padding=(1, 1)) + self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv") + self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") + self.point_wise_conv1 = keras.layers.Conv2D(hidden_dim, kernel_size=1, name="point_wise_conv1") + self.act = get_tf_activation("gelu") + self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2") + self.drop_path = keras.layers.Dropout(name="drop_path", rate=config.drop_conv_encoder_rate) + self.hidden_dim = int(config.mlp_ratio * self.dim) + + def build(self, input_shape=None): + if self.built: + return + self.layer_scale = self.add_weight( + name="layer_scale", + shape=self.dim, + initializer="ones", + trainable=True, + ) + + if getattr(self, "depth_wise_conv", None) is not None: + with tf.name_scope(self.depth_wise_conv.name): + self.depth_wise_conv.build(self.dim) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build((None, None, None, self.dim)) + if getattr(self, "point_wise_conv1", None) is not None: + with tf.name_scope(self.point_wise_conv1.name): + self.point_wise_conv1.build(self.dim) + if getattr(self, "point_wise_conv2", None) is not None: + with tf.name_scope(self.point_wise_conv2.name): + self.point_wise_conv2.build(self.hidden_dim) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + self.built = True + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + input = x + x = self.pad(x) + x = self.depth_wise_conv(x) + x = self.norm(x, training=training) + x = self.point_wise_conv1(x) + x = self.act(x) + x = self.point_wise_conv2(x) + x = input + self.drop_path(self.layer_scale * x) + return x + + +class TFSwiftFormerMlp(keras.layers.Layer): + """ + MLP layer with 1*1 convolutions. 
+ + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, in_features: int, **kwargs): + super().__init__(**kwargs) + + hidden_features = int(in_features * config.mlp_ratio) + self.norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm1") + self.fc1 = keras.layers.Conv2D(hidden_features, 1, name="fc1") + act_layer = get_tf_activation(config.hidden_act) + self.act = act_layer + self.fc2 = keras.layers.Conv2D(in_features, 1, name="fc2") + self.drop = keras.layers.Dropout(rate=config.drop_mlp_rate) + self.hidden_features = hidden_features + self.in_features = in_features + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + x = self.norm1(x, training=training) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x, training=training) + x = self.fc2(x) + x = self.drop(x, training=training) + return x + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "norm1", None) is not None: + with tf.name_scope(self.norm1.name): + self.norm1.build((None, None, None, self.in_features)) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build((None, None, None, self.in_features)) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build((None, None, None, self.hidden_features)) + self.built = True + + +class TFSwiftFormerEfficientAdditiveAttention(keras.layers.Layer): + """ + Efficient Additive Attention module for SwiftFormer. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int = 512, **kwargs): + super().__init__(**kwargs) + + self.dim = dim + + self.to_query = keras.layers.Dense(dim, name="to_query") + self.to_key = keras.layers.Dense(dim, name="to_key") + + self.scale_factor = dim**-0.5 + self.proj = keras.layers.Dense(dim, name="proj") + self.final = keras.layers.Dense(dim, name="final") + + def build(self, input_shape=None): + if self.built: + return + self.w_g = self.add_weight( + name="w_g", + shape=(self.dim, 1), + initializer=keras.initializers.RandomNormal(mean=0, stddev=1), + trainable=True, + ) + + if getattr(self, "to_query", None) is not None: + with tf.name_scope(self.to_query.name): + self.to_query.build(self.dim) + if getattr(self, "to_key", None) is not None: + with tf.name_scope(self.to_key.name): + self.to_key.build(self.dim) + if getattr(self, "proj", None) is not None: + with tf.name_scope(self.proj.name): + self.proj.build(self.dim) + if getattr(self, "final", None) is not None: + with tf.name_scope(self.final.name): + self.final.build(self.dim) + self.built = True + + def call(self, x: tf.Tensor) -> tf.Tensor: + query = self.to_query(x) + key = self.to_key(x) + + query = tf.math.l2_normalize(query, dim=-1) + key = tf.math.l2_normalize(key, dim=-1) + + query_weight = query @ self.w_g + scaled_query_weight = query_weight * self.scale_factor + scaled_query_weight = tf.nn.softmax(scaled_query_weight, axis=-1) + + global_queries = tf.math.reduce_sum(scaled_query_weight * query, axis=1) + global_queries = tf.tile(tf.expand_dims(global_queries, 1), (1, key.shape[1], 1)) + + out = self.proj(global_queries * key) + query + out = self.final(out) + + return out + + +class TFSwiftFormerLocalRepresentation(keras.layers.Layer): + """ + Local 
Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs): + super().__init__(**kwargs) + + self.dim = dim + + self.pad = keras.layers.ZeroPadding2D(padding=(1, 1)) + self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv") + self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") + self.point_wise_conv1 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv1") + self.act = get_tf_activation("gelu") + self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2") + self.drop_path = keras.layers.Identity(name="drop_path") + + def build(self, input_shape=None): + if self.built: + return + self.layer_scale = self.add_weight( + name="layer_scale", + shape=(self.dim), + initializer="ones", + trainable=True, + ) + if getattr(self, "depth_wise_conv", None) is not None: + with tf.name_scope(self.depth_wise_conv.name): + self.depth_wise_conv.build((None, None, None, self.dim)) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build((None, None, None, self.dim)) + if getattr(self, "point_wise_conv1", None) is not None: + with tf.name_scope(self.point_wise_conv1.name): + self.point_wise_conv1.build(self.dim) + if getattr(self, "point_wise_conv2", None) is not None: + with tf.name_scope(self.point_wise_conv2.name): + self.point_wise_conv2.build(self.dim) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + self.built = True + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + input = x + x = self.pad(x) + x = self.depth_wise_conv(x) + x = self.norm(x, training=training) + x = self.point_wise_conv1(x) + x = self.act(x) + x = self.point_wise_conv2(x) + x = input + self.drop_path(self.layer_scale * x, training=training) + return x + + +class TFSwiftFormerEncoderBlock(keras.layers.Layer): + """ + SwiftFormer Encoder Block for SwiftFormer. It consists of (1) Local representation module, (2) + SwiftFormerEfficientAdditiveAttention, and (3) MLP block. 
+ + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels,height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int, drop_path: float = 0.0, **kwargs): + super().__init__(**kwargs) + + layer_scale_init_value = config.layer_scale_init_value + use_layer_scale = config.use_layer_scale + + self.local_representation = TFSwiftFormerLocalRepresentation(config, dim=dim, name="local_representation") + self.attn = TFSwiftFormerEfficientAdditiveAttention(config, dim=dim, name="attn") + self.linear = TFSwiftFormerMlp(config, in_features=dim, name="linear") + self.drop_path = TFSwiftFormerDropPath(config) if drop_path > 0.0 else keras.layers.Identity() + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.dim = dim + self.layer_scale_init_value = layer_scale_init_value + + def build(self, input_shape=None): + if self.built: + return + self.layer_scale_1 = self.add_weight( + name="layer_scale_1", + shape=self.dim, + initializer=keras.initializers.constant(self.layer_scale_init_value), + trainable=True, + ) + self.layer_scale_2 = self.add_weight( + name="layer_scale_2", + shape=self.dim, + initializer=keras.initializers.constant(self.layer_scale_init_value), + trainable=True, + ) + + if getattr(self, "local_representation", None) is not None: + with tf.name_scope(self.local_representation.name): + self.local_representation.build(None) + if getattr(self, "attn", None) is not None: + with tf.name_scope(self.attn.name): + self.attn.build(None) + if getattr(self, "linear", None) is not None: + with tf.name_scope(self.linear.name): + self.linear.build(None) + self.built = True + + def call(self, x: tf.Tensor, training: bool = False): + x = self.local_representation(x, training=training) + batch_size, height, width, channels = x.shape + + res = tf.reshape(x, [-1, height * width, channels]) + res = self.attn(res) + res = tf.reshape(res, [-1, height, width, channels]) + if self.use_layer_scale: + x = x + self.drop_path(self.layer_scale_1 * res, training=training) + x = x + self.drop_path(self.layer_scale_2 * self.linear(x), training=training) + else: + x = x + self.drop_path(res, training=training) + x = x + self.drop_path(self.linear(x), training=training) + return x + + +class TFSwiftFormerStage(keras.layers.Layer): + """ + A Swiftformer stage consisting of a series of `SwiftFormerConvEncoder` blocks and a final + `SwiftFormerEncoderBlock`. 
+ + Input: tensor in shape `[batch_size, channels, height, width]` + + Output: tensor in shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, index: int, **kwargs) -> None: + super().__init__(**kwargs) + + layer_depths = config.depths + dim = config.embed_dims[index] + depth = layer_depths[index] + + self.blocks = [] + for block_idx in range(depth): + block_dpr = config.drop_path_rate * (block_idx + sum(layer_depths[:index])) / (sum(layer_depths) - 1) + + if depth - block_idx <= 1: + self.blocks.append( + TFSwiftFormerEncoderBlock(config, dim=dim, drop_path=block_dpr, name=f"blocks_._{block_idx}") + ) + else: + self.blocks.append(TFSwiftFormerConvEncoder(config, dim=dim, name=f"blocks_._{block_idx}")) + + def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor: + for i, block in enumerate(self.blocks): + input = block(input, training=training) + return input + + def build(self, input_shape=None): + for layer in self.blocks: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSwiftFormerEncoder(keras.layers.Layer): + def __init__(self, config: SwiftFormerConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + + embed_dims = config.embed_dims + downsamples = config.downsamples + layer_depths = config.depths + + # Transformer model + self.network = [] + name_i = 0 + for i in range(len(layer_depths)): + stage = TFSwiftFormerStage(config, index=i, name=f"network_._{name_i}") + self.network.append(stage) + name_i += 1 + if i >= len(layer_depths) - 1: + break + if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + self.network.append(TFSwiftFormerEmbeddings(config, index=i, name=f"network_._{name_i}")) + name_i += 1 + + self.gradient_checkpointing = False + + def call( + self, + hidden_states: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tuple, TFBaseModelOutputWithNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_hidden_states = (hidden_states,) if output_hidden_states else None + + for i, block in enumerate(self.network): + hidden_states = block(hidden_states, training=training) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) + if all_hidden_states: + all_hidden_states = tuple(tf.transpose(s, perm=[0, 3, 1, 2]) for s in all_hidden_states) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return TFBaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + def build(self, input_shape=None): + for layer in self.network: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSwiftFormerPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SwiftFormerConfig + base_model_prefix = "swiftformer" + main_input_name = "pixel_values" + + +TFSWIFTFORMER_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TF 2.0 models accepts two formats as inputs: + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional arguments. + This second option is useful when using [`keras.Model.fit`] method which currently requires having all the + tensors in the first argument of the model call function: `model(inputs)`. + If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the + first positional argument : + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + + + Parameters: + config ([`SwiftFormerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TFSWIFTFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + training (`bool`, *optional*, defaults to `False`): + Whether or not to run the model in training mode. +""" + + +@keras_serializable +class TFSwiftFormerMainLayer(keras.layers.Layer): + config_class = SwiftFormerConfig + + def __init__(self, config: SwiftFormerConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.patch_embed = TFSwiftFormerPatchEmbedding(config, name="patch_embed") + self.encoder = TFSwiftFormerEncoder(config, name="encoder") + + @unpack_inputs + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple, TFBaseModelOutputWithNoAttention]: + r""" """ + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TF 2.0 image layers can't use NCHW format when running on CPU. + # We transpose to NHWC format and then transpose back after the full forward pass. 
+ # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) + pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1]) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.patch_embed(pixel_values, training=training) + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return tuple(v for v in encoder_outputs if v is not None) + + return TFBaseModelOutputWithNoAttention( + last_hidden_state=encoder_outputs.last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "patch_embed", None) is not None: + with tf.name_scope(self.patch_embed.name): + self.patch_embed.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + self.built = True + + +@add_start_docstrings( + "The bare TFSwiftFormer Model transformer outputting raw hidden-states without any specific head on top.", + TFSWIFTFORMER_START_DOCSTRING, +) +class TFSwiftFormerModel(TFSwiftFormerPreTrainedModel): + def __init__(self, config: SwiftFormerConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer") + + @unpack_inputs + @add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithNoAttention, Tuple[tf.Tensor]]: + outputs = self.swiftformer( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "swiftformer", None) is not None: + with tf.name_scope(self.swiftformer.name): + self.swiftformer.build(None) + self.built = True + + +@add_start_docstrings( + """ + TFSwiftFormer Model transformer with an image classification head on top (e.g. for ImageNet). 
+ """, + TFSWIFTFORMER_START_DOCSTRING, +) +class TFSwiftFormerForImageClassification(TFSwiftFormerPreTrainedModel): + def __init__(self, config: SwiftFormerConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + + self.num_labels = config.num_labels + self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer") + + # Classifier head + self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") + self.head = ( + keras.layers.Dense(self.num_labels, name="head") + if self.num_labels > 0 + else keras.layers.Identity(name="head") + ) + self.dist_head = ( + keras.layers.Dense(self.num_labels, name="dist_head") + if self.num_labels > 0 + else keras.layers.Identity(name="dist_head") + ) + + def hf_compute_loss(self, labels, logits): + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = keras.losses.MSE + if self.num_labels == 1: + loss = loss_fct(labels.squeeze(), logits.squeeze()) + else: + loss = loss_fct(labels, logits) + elif self.config.problem_type == "single_label_classification": + loss_fct = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction=keras.losses.Reduction.NONE + ) + loss = loss_fct(labels, logits) + elif self.config.problem_type == "multi_label_classification": + loss_fct = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, + reduction=keras.losses.Reduction.NONE, + ) + loss = loss_fct(labels, logits) + else: + loss = None + + return loss + + @unpack_inputs + @add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tuple, TFImageClassifierOutputWithNoAttention]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # run base model + outputs = self.swiftformer( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs.last_hidden_state if return_dict else outputs[0] + sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1]) + + # run classification head + sequence_output = self.norm(sequence_output, training=training) + sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2]) + _, num_channels, height, width = sequence_output.shape + sequence_output = tf.reshape(sequence_output, [-1, num_channels, height * width]) + sequence_output = tf.reduce_mean(sequence_output, axis=-1) + cls_out = self.head(sequence_output) + distillation_out = self.dist_head(sequence_output) + logits = (cls_out + distillation_out) / 2 + + # calculate loss + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "swiftformer", None) is not None: + with tf.name_scope(self.swiftformer.name): + self.swiftformer.build(None) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build((None, None, None, self.config.embed_dims[-1])) + if getattr(self, "head", None) is not None: + with tf.name_scope(self.head.name): + self.head.build(self.config.embed_dims[-1]) + if getattr(self, "dist_head", None) is not None: + with tf.name_scope(self.dist_head.name): + self.dist_head.build(self.config.embed_dims[-1]) + self.built = True diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 5441883b85a463..e6f75d1f8f0e72 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -2554,6 +2554,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFSwiftFormerForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSwiftFormerModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSwiftFormerPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/swiftformer/test_modeling_tf_swiftformer.py b/tests/models/swiftformer/test_modeling_tf_swiftformer.py new file mode 100644 index 00000000000000..1d30abed31fda4 --- /dev/null +++ b/tests/models/swiftformer/test_modeling_tf_swiftformer.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the TensorFlow SwiftFormer model. """ + + +import inspect +import unittest + +from transformers import SwiftFormerConfig +from transformers.testing_utils import ( + require_tf, + require_vision, + slow, +) +from transformers.utils import cached_property, is_tf_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_tf_available(): + import tensorflow as tf + + from transformers import TFSwiftFormerForImageClassification, TFSwiftFormerModel + from transformers.modeling_tf_utils import keras + from transformers.models.swiftformer.modeling_tf_swiftformer import TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import ViTImageProcessor + + +class TFSwiftFormerModelTester: + def __init__( + self, + parent, + batch_size=1, + num_channels=3, + is_training=True, + use_labels=True, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + image_size=224, + num_labels=2, + layer_depths=[3, 3, 6, 4], + embed_dims=[48, 56, 112, 220], + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.num_labels = num_labels + self.image_size = image_size + self.layer_depths = layer_depths + self.embed_dims = embed_dims + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return SwiftFormerConfig( + depths=self.layer_depths, + embed_dims=self.embed_dims, + mlp_ratio=4, + downsamples=[True, True, True, True], + hidden_act="gelu", + num_labels=self.num_labels, + down_patch_size=3, + down_stride=2, + down_pad=1, + drop_rate=0.0, + drop_path_rate=0.0, + use_layer_scale=True, + layer_scale_init_value=1e-5, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = TFSwiftFormerModel(config=config) + result = model(pixel_values) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.embed_dims[-1], 7, 7)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.num_labels + model = TFSwiftFormerForImageClassification(config) + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + model = TFSwiftFormerForImageClassification(config) + + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + 
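For context on the shape asserted by `create_and_check_model` below: the patch embedding downsamples by 4 and the three stage-transition embeddings each downsample by 2, so a 224×224 input yields a final feature map of `(batch, embed_dims[-1], 7, 7)`. A standalone sketch of the same check with a randomly initialized model (assuming the default `downsamples=[True, True, True, True]`):

```python
import tensorflow as tf
from transformers import SwiftFormerConfig, TFSwiftFormerModel

# Randomly initialized model, just to illustrate the output geometry.
config = SwiftFormerConfig(depths=[3, 3, 6, 4], embed_dims=[48, 56, 112, 220])
model = TFSwiftFormerModel(config)

pixel_values = tf.random.normal((1, 3, 224, 224))  # NCHW, as in the inputs docstring
outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # (1, 220, 7, 7): embed_dims[-1] channels, 224 / 32 = 7
```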
+ def prepare_config_and_inputs_for_common(self): + (config, pixel_values, labels) = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_tf +class TFSwiftFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SwiftFormer does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (TFSwiftFormerModel, TFSwiftFormerForImageClassification) if is_tf_available() else () + + pipeline_model_mapping = ( + {"feature-extraction": TFSwiftFormerModel, "image-classification": TFSwiftFormerForImageClassification} + if is_tf_available() + else {} + ) + + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = False + test_onnx = False + + def setUp(self): + self.model_tester = TFSwiftFormerModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=SwiftFormerConfig, + has_text_modality=False, + hidden_size=37, + num_attention_heads=12, + num_hidden_layers=12, + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="TFSwiftFormer does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, keras.layers.Dense)) + + # Copied from transformers.tests.models.deit.test_modeling_tf_deit.py + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.call) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = TFSwiftFormerModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @unittest.skip(reason="TFSwiftFormer does not output attentions") + def test_attention_outputs(self): + pass + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_stages = 8 + self.assertEqual(len(hidden_states), expected_num_stages) + + # SwiftFormer's feature maps are of shape (batch_size, embed_dims, height, width) + # with the width and height being successively divided by 2, after every 2 blocks + for i in range(len(hidden_states)): + self.assertEqual( + hidden_states[i].shape, + tf.TensorShape( + [ + self.model_tester.batch_size, + self.model_tester.embed_dims[i // 2], + 
(self.model_tester.image_size // 4) // 2 ** (i // 2), + (self.model_tester.image_size // 4) // 2 ** (i // 2), + ] + ), + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_tf +@require_vision +class TFSwiftFormerModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return ViTImageProcessor.from_pretrained("MBZUAI/swiftformer-xs") if is_vision_available() else None + + @slow + def test_inference_image_classification_head(self): + model = TFSwiftFormerForImageClassification.from_pretrained("MBZUAI/swiftformer-xs") + + feature_extractor = self.default_image_processor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="tf") + + # forward pass + outputs = model(**inputs) + + # verify the logits + expected_shape = tf.TensorShape((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = tf.constant([[-2.1703e00, 2.1107e00, -2.0811e00]]) + tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4) diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index a58d08eccaf305..04572d132b9dd1 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -697,6 +697,8 @@ "TFSegformerModel", "TFSpeech2TextForConditionalGeneration", "TFSpeech2TextModel", + "TFSwiftFormerForImageClassification", + "TFSwiftFormerModel", "TFSwinForImageClassification", "TFSwinForMaskedImageModeling", "TFSwinModel",
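Beyond the slow integration test above, a typical sanity check for a TF port is to run the same image through both frameworks and compare logits. A sketch, assuming the hub checkpoint stores PyTorch weights (hence `from_pt=True`); the agreement threshold is a judgment call, not part of this PR.

```python
import numpy as np
import requests
from PIL import Image

from transformers import (
    SwiftFormerForImageClassification,
    TFSwiftFormerForImageClassification,
    ViTImageProcessor,
)

checkpoint = "MBZUAI/swiftformer-xs"
processor = ViTImageProcessor.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

pt_model = SwiftFormerForImageClassification.from_pretrained(checkpoint)
tf_model = TFSwiftFormerForImageClassification.from_pretrained(checkpoint, from_pt=True)

pt_logits = pt_model(**processor(images=image, return_tensors="pt")).logits.detach().numpy()
tf_logits = tf_model(**processor(images=image, return_tensors="tf")).logits.numpy()

# The two implementations should agree closely, e.g. to within ~1e-3.
print(np.abs(pt_logits - tf_logits).max())
```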