From 481a95781404e48b1c80940be17e8279dec82fe8 Mon Sep 17 00:00:00 2001
From: Abhiroop Tejomay
<38148843+the-neural-networker@users.noreply.github.com>
Date: Fri, 17 May 2024 13:38:46 -0400
Subject: [PATCH] Enable dynamic resolution input for Swin Transformer and
variants (#30656)
* add interpolation of positional encoding support to swin
* add style changes
* use default image processor and make size a dictionary
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* remove logits testing
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Refactor image size validation logic when interpolation is disabled
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* remove asserts in modeling
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* add dynamic resolution input support to swinv2
* change size to ensure interpolation encoding path is triggered
* set interpolate_pos_encoding default value to False
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* add dynamic resolution input to donut swin
* add dynamic resolution input to maskformer swin
---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
.../models/donut/modeling_donut_swin.py | 63 ++++++++++++++--
.../maskformer/modeling_maskformer_swin.py | 58 ++++++++++++--
src/transformers/models/swin/modeling_swin.py | 75 +++++++++++++++++--
.../models/swinv2/modeling_swinv2.py | 75 +++++++++++++++++--
tests/models/swin/test_modeling_swin.py | 20 +++++
tests/models/swinv2/test_modeling_swinv2.py | 20 +++++
6 files changed, 291 insertions(+), 20 deletions(-)
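As a quick orientation before the diffs, here is a minimal usage sketch of the new argument. The checkpoint is the one exercised in the tests below; the random tensor merely stands in for a real preprocessed image:

```python
import torch
from transformers import SwinModel

# Checkpoint from the tests below; any Swin checkpoint works after this patch.
model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").eval()

# A 481x481 input no longer has to match the 224x224 pretraining resolution:
# the flag lifts the fixed-size check and interpolates absolute position
# embeddings when the config uses them.
pixel_values = torch.randn(1, 3, 481, 481)  # stand-in for a preprocessed image
with torch.no_grad():
    outputs = model(pixel_values, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 256, 768])
```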
diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index e48e1ddfe14cb2..c1b27cd1806cd0 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -166,10 +166,48 @@ def __init__(self, config, use_mask_token=False):
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on
+ higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ h0 = height // self.config.patch_size
+ w0 = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+ mode="bicubic",
+ align_corners=False,
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
def forward(
- self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
+ self,
+ pixel_values: Optional[torch.FloatTensor],
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor]:
- embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+ _, num_channels, height, width = pixel_values.shape
+ embeddings, output_dimensions = self.patch_embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+ )
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()
@@ -180,7 +218,10 @@ def forward(
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None:
- embeddings = embeddings + self.position_embeddings
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
@@ -219,7 +260,9 @@ def maybe_pad(self, pixel_values, height, width):
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
- def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
+ ) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
@@ -227,6 +270,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens
)
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
+ if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
@@ -849,6 +897,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -899,6 +949,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, DonutSwinModelOutput]:
r"""
@@ -921,7 +972,9 @@ def forward(
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
- embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+ embedding_output, input_dimensions = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
encoder_outputs = self.encoder(
embedding_output,
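The `interpolate_pos_encoding` method added above follows the DINO recipe: split off the class position, reshape the patch positions to a square grid, resize bicubically, and flatten back. A self-contained sketch of the same recipe (the function name and the 14x14/20x20 grid sizes are illustrative; the patch itself derives the target grid as `height // patch_size`):

```python
import math

import torch
import torch.nn as nn


def interpolate_patch_positions(pos_embed: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    # pos_embed: (1, 1 + N, dim) with a leading class token; N is a perfect square.
    class_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = patch_pos.shape[-1]
    grid = int(math.sqrt(patch_pos.shape[1]))
    # (1, N, dim) -> (1, dim, grid, grid) so 2D interpolation can be applied
    patch_pos = patch_pos.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
    # the +0.1 guards against floating point errors, see facebookresearch/dino#8
    patch_pos = nn.functional.interpolate(
        patch_pos,
        scale_factor=((new_h + 0.1) / grid, (new_w + 0.1) / grid),
        mode="bicubic",
        align_corners=False,
    )
    patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos, patch_pos), dim=1)


# interpolate a 14x14 pretraining grid to 20x20
pos = torch.randn(1, 1 + 14 * 14, 96)
print(interpolate_patch_positions(pos, 20, 20).shape)  # torch.Size([1, 401, 96])
```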
diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py
index 1c358c88de4e7f..fc9c642adc8124 100644
--- a/src/transformers/models/maskformer/modeling_maskformer_swin.py
+++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -163,12 +163,50 @@ def __init__(self, config):
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, pixel_values):
- embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on
+ higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ h0 = height // self.config.patch_size
+ w0 = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+ mode="bicubic",
+ align_corners=False,
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values, interpolate_pos_encoding):
+ _, num_channels, height, width = pixel_values.shape
+ embeddings, output_dimensions = self.patch_embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+ )
embeddings = self.norm(embeddings)
if self.position_embeddings is not None:
- embeddings = embeddings + self.position_embeddings
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
@@ -207,7 +245,9 @@ def maybe_pad(self, pixel_values, height, width):
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
- def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
+ ) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
@@ -215,6 +255,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens
)
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
+ if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
@@ -780,6 +825,7 @@ def forward(
head_mask=None,
output_attentions=None,
output_hidden_states=None,
+ interpolate_pos_encoding=False,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -798,7 +844,9 @@ def forward(
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
- embedding_output, input_dimensions = self.embeddings(pixel_values)
+ embedding_output, input_dimensions = self.embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+ )
encoder_outputs = self.encoder(
embedding_output,
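Note the contract the validation above establishes: at non-native resolutions the embeddings raise unless the caller opts in, and with `interpolate_pos_encoding=True` the check is skipped entirely. The same logic in isolation (the function name and the 224x224 default are illustrative):

```python
def check_input_size(height, width, image_size=(224, 224), interpolate_pos_encoding=False):
    # mirrors the check added above: only enforced when interpolation is disabled
    if not interpolate_pos_encoding and (height, width) != tuple(image_size):
        raise ValueError(
            f"Input image size ({height}*{width}) doesn't match model ({image_size[0]}*{image_size[1]})."
        )


check_input_size(224, 224)                                 # native resolution: fine
check_input_size(481, 481, interpolate_pos_encoding=True)  # opted in: check skipped
# check_input_size(481, 481)                               # would raise ValueError
```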
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index cb0eff88abc26f..2a6363c8e69b7f 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -252,10 +252,48 @@ def __init__(self, config, use_mask_token=False):
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on
+ higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ h0 = height // self.config.patch_size
+ w0 = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+ mode="bicubic",
+ align_corners=False,
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
def forward(
- self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
+ self,
+ pixel_values: Optional[torch.FloatTensor],
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor]:
- embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+ _, num_channels, height, width = pixel_values.shape
+ embeddings, output_dimensions = self.patch_embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+ )
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()
@@ -266,7 +304,10 @@ def forward(
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None:
- embeddings = embeddings + self.position_embeddings
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
@@ -304,7 +345,9 @@ def maybe_pad(self, pixel_values, height, width):
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
- def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
+ ) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
@@ -312,6 +355,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens
)
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
+ if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
@@ -924,6 +972,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -981,6 +1031,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinModelOutput]:
r"""
@@ -1003,7 +1054,9 @@ def forward(
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
- embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+ embedding_output, input_dimensions = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
encoder_outputs = self.encoder(
embedding_output,
@@ -1074,6 +1127,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinMaskedImageModelingOutput]:
r"""
@@ -1113,6 +1167,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
@@ -1156,6 +1211,14 @@ def forward(
"""
Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
the [CLS] token) e.g. for ImageNet.
+
+ <Tip>
+
+ Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+ position embeddings to the higher resolution.
+
+ </Tip>
""",
SWIN_START_DOCSTRING,
)
@@ -1188,6 +1251,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinImageClassifierOutput]:
r"""
@@ -1203,6 +1267,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
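The docstring note above describes the intended workflow: keep the pretrained weights and fine-tune at a larger resolution. A hedged sketch of one training step (the checkpoint matches the tests; `num_labels`, the 384x384 size, and the random batch are illustrative):

```python
import torch
from transformers import SwinForImageClassification

model = SwinForImageClassification.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",
    num_labels=10,                 # illustrative downstream label count
    ignore_mismatched_sizes=True,  # replaces the 1000-way ImageNet head
)

pixel_values = torch.randn(2, 3, 384, 384)  # larger than the 224x224 pretraining size
labels = torch.randint(0, 10, (2,))
outputs = model(pixel_values, labels=labels, interpolate_pos_encoding=True)
outputs.loss.backward()  # then step an optimizer as in any fine-tuning loop
```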
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index 213d60a386dcc8..ac8ec197e599d1 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -295,10 +295,48 @@ def __init__(self, config, use_mask_token=False):
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on
+ higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ h0 = height // self.config.patch_size
+ w0 = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+ mode="bicubic",
+ align_corners=False,
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
def forward(
- self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
+ self,
+ pixel_values: Optional[torch.FloatTensor],
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor]:
- embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+ _, num_channels, height, width = pixel_values.shape
+ embeddings, output_dimensions = self.patch_embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+ )
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()
@@ -309,7 +347,10 @@ def forward(
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None:
- embeddings = embeddings + self.position_embeddings
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
@@ -348,7 +389,9 @@ def maybe_pad(self, pixel_values, height, width):
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
- def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
+ ) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
@@ -356,6 +399,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens
)
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
+ if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
@@ -979,6 +1027,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -1031,6 +1081,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Swinv2ModelOutput]:
r"""
@@ -1053,7 +1104,9 @@ def forward(
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
- embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+ embedding_output, input_dimensions = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
encoder_outputs = self.encoder(
embedding_output,
@@ -1126,6 +1179,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
r"""
@@ -1165,6 +1219,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
@@ -1208,6 +1263,14 @@ def forward(
"""
Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
of the [CLS] token) e.g. for ImageNet.
+
+ <Tip>
+
+ Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+ position embeddings to the higher resolution.
+
+ </Tip>
""",
SWINV2_START_DOCSTRING,
)
@@ -1241,6 +1304,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Swinv2ImageClassifierOutput]:
r"""
@@ -1256,6 +1320,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
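For reference, the `(1, 256, 768)` expectation in the tests below is pure Swin geometry: for swin-tiny (patch size 4, embed dim 96, three patch-merging stages) a 481x481 input is padded to 484, giving a 121x121 patch grid; each merge pads odd grids to even and halves them (121 -> 61 -> 31 -> 16), leaving 16 * 16 = 256 tokens of dimension 96 * 2^3 = 768. A quick check:

```python
patch_size, embed_dim, num_merges = 4, 96, 3  # swin-tiny geometry

side = 481 + (-481) % patch_size   # padded to 484, a multiple of patch_size
grid = side // patch_size          # 121x121 patch grid
for _ in range(num_merges):
    grid = (grid + grid % 2) // 2  # pad odd grids to even, then halve
print(grid * grid, embed_dim * 2**num_merges)  # 256 768
```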
diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py
index 9220784e23029a..699171722d0db9 100644
--- a/tests/models/swin/test_modeling_swin.py
+++ b/tests/models/swin/test_modeling_swin.py
@@ -493,6 +493,26 @@ def test_inference_image_classification_head(self):
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+ @slow
+ def test_inference_interpolate_pos_encoding(self):
+ # Swin models have an `interpolate_pos_encoding` argument in their forward method,
+ # which interpolates the pre-trained position embeddings so the model can be
+ # used at higher resolutions.
+ model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device)
+
+ image_processor = self.default_image_processor
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
+ pixel_values = inputs.pixel_values.to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values, interpolate_pos_encoding=True)
+
+ # verify the shape of the last hidden state
+ expected_shape = torch.Size((1, 256, 768))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
@require_torch
class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin):
diff --git a/tests/models/swinv2/test_modeling_swinv2.py b/tests/models/swinv2/test_modeling_swinv2.py
index b8f97ee7c23bc6..7a948d1282c1b6 100644
--- a/tests/models/swinv2/test_modeling_swinv2.py
+++ b/tests/models/swinv2/test_modeling_swinv2.py
@@ -485,6 +485,26 @@ def test_inference_image_classification_head(self):
expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+ @slow
+ def test_inference_interpolate_pos_encoding(self):
+ # Swinv2 models have an `interpolate_pos_encoding` argument in their forward method,
+ # which interpolates the pre-trained position embeddings so the model can be
+ # used at higher resolutions.
+ model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to(torch_device)
+
+ image_processor = self.default_image_processor
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
+ pixel_values = inputs.pixel_values.to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values, interpolate_pos_encoding=True)
+
+ # verify the shape of the last hidden state
+ expected_shape = torch.Size((1, 256, 768))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
@require_torch
class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):
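Both new checks are marked `@slow`, so they are skipped by default; something like `RUN_SLOW=1 pytest tests/models/swin/test_modeling_swin.py tests/models/swinv2/test_modeling_swinv2.py -k interpolate_pos_encoding` runs just the two of them (the `-k` filter is illustrative).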