From 481a95781404e48b1c80940be17e8279dec82fe8 Mon Sep 17 00:00:00 2001
From: Abhiroop Tejomay <38148843+the-neural-networker@users.noreply.github.com>
Date: Fri, 17 May 2024 13:38:46 -0400
Subject: [PATCH] Enable dynamic resolution input for Swin Transformer and
 variants (#30656)

* add interpolation of positional encoding support to swin

* add style changes

* use default image processor and make size a dictionary

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove logits testing

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor image size validation logic when interpolation is disabled

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove asserts in modeling

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add dynamic resolution input support to swinv2

* change size to ensure interpolation encoding path is triggered

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

* add dynamic resolution input to donut swin

* add dynamic resolution input to maskformer swin

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 .../models/donut/modeling_donut_swin.py       | 63 ++++++++++++++--
 .../maskformer/modeling_maskformer_swin.py    | 58 ++++++++++++--
 src/transformers/models/swin/modeling_swin.py | 75 +++++++++++++++++--
 .../models/swinv2/modeling_swinv2.py          | 75 +++++++++++++++++--
 tests/models/swin/test_modeling_swin.py       | 20 +++++
 tests/models/swinv2/test_modeling_swinv2.py   | 20 +++++
 6 files changed, 291 insertions(+), 20 deletions(-)

diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index e48e1ddfe14cb2..c1b27cd1806cd0 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -166,10 +166,48 @@ def __init__(self, config, use_mask_token=False):
         self.norm = nn.LayerNorm(config.embed_dim)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings so that the model can be used on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
     def forward(
-        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
+        self,
+        pixel_values: Optional[torch.FloatTensor],
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
-        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+        _, num_channels, height, width = pixel_values.shape
+        embeddings, output_dimensions = self.patch_embeddings(
+            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+        )
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
 
@@ -180,7 +218,10 @@ def forward(
             embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
 
         if self.position_embeddings is not None:
-            embeddings = embeddings + self.position_embeddings
+            if interpolate_pos_encoding:
+                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+            else:
+                embeddings = embeddings + self.position_embeddings
 
         embeddings = self.dropout(embeddings)
 
@@ -219,7 +260,9 @@ def maybe_pad(self, pixel_values, height, width):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(
+        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
+    ) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
@@ -227,6 +270,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens
             )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
+        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model"
+                f" ({self.image_size[0]}*{self.image_size[1]})."
+            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
@@ -849,6 +897,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -899,6 +949,7 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, DonutSwinModelOutput]:
         r"""
@@ -921,7 +972,9 @@ def forward(
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, len(self.config.depths))
 
-        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+        embedding_output, input_dimensions = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
 
         encoder_outputs = self.encoder(
             embedding_output,
diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py
index 1c358c88de4e7f..fc9c642adc8124 100644
--- a/src/transformers/models/maskformer/modeling_maskformer_swin.py
+++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -163,12 +163,50 @@ def __init__(self, config):
         self.norm = nn.LayerNorm(config.embed_dim)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, pixel_values):
-        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings so that the model can be used on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(self, pixel_values, interpolate_pos_encoding):
+        _, num_channels, height, width = pixel_values.shape
+        embeddings, output_dimensions = self.patch_embeddings(
+            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+        )
         embeddings = self.norm(embeddings)
 
         if self.position_embeddings is not None:
-            embeddings = embeddings + self.position_embeddings
+            if interpolate_pos_encoding:
+                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+            else:
+                embeddings = embeddings + self.position_embeddings
 
         embeddings = self.dropout(embeddings)
 
@@ -207,7 +245,9 @@ def maybe_pad(self, pixel_values, height, width):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(
+        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
+    ) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
@@ -215,6 +255,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens
         )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
+        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model"
+                f" ({self.image_size[0]}*{self.image_size[1]})."
+            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
@@ -780,6 +825,7 @@ def forward(
         head_mask=None,
         output_attentions=None,
         output_hidden_states=None,
+        interpolate_pos_encoding=False,
         return_dict=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -798,7 +844,9 @@ def forward(
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, len(self.config.depths))
 
-        embedding_output, input_dimensions = self.embeddings(pixel_values)
+        embedding_output, input_dimensions = self.embeddings(
+            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+        )
 
         encoder_outputs = self.encoder(
             embedding_output,
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index cb0eff88abc26f..2a6363c8e69b7f 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -252,10 +252,48 @@ def __init__(self, config, use_mask_token=False):
         self.norm = nn.LayerNorm(config.embed_dim)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings so that the model can be used on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
     def forward(
-        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
+        self,
+        pixel_values: Optional[torch.FloatTensor],
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
-        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+        _, num_channels, height, width = pixel_values.shape
+        embeddings, output_dimensions = self.patch_embeddings(
+            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+        )
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
 
@@ -266,7 +304,10 @@ def forward(
             embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
 
         if self.position_embeddings is not None:
-            embeddings = embeddings + self.position_embeddings
+            if interpolate_pos_encoding:
+                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+            else:
+                embeddings = embeddings + self.position_embeddings
 
         embeddings = self.dropout(embeddings)
 
@@ -304,7 +345,9 @@ def maybe_pad(self, pixel_values, height, width):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(
+        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
+    ) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
@@ -312,6 +355,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens
         )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
+        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model"
+                f" ({self.image_size[0]}*{self.image_size[1]})."
+            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
@@ -924,6 +972,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -981,6 +1031,7 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SwinModelOutput]:
         r"""
@@ -1003,7 +1054,9 @@ def forward(
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))
 
-        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+        embedding_output, input_dimensions = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
 
         encoder_outputs = self.encoder(
             embedding_output,
@@ -1074,6 +1127,7 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SwinMaskedImageModelingOutput]:
         r"""
@@ -1113,6 +1167,7 @@ def forward(
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
         )
 
@@ -1156,6 +1211,14 @@ def forward(
     """
     Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state
     of the [CLS] token) e.g. for ImageNet.
+
+    <Tip>
+
+    Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
+    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+    position embeddings to the higher resolution.
+
+    </Tip>
     """,
     SWIN_START_DOCSTRING,
 )
@@ -1188,6 +1251,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SwinImageClassifierOutput]:
         r"""
@@ -1203,6 +1267,7 @@ def forward(
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
         )
 
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index 213d60a386dcc8..ac8ec197e599d1 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -295,10 +295,48 @@ def __init__(self, config, use_mask_token=False):
         self.norm = nn.LayerNorm(config.embed_dim)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings so that the model can be used on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
     def forward(
-        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
+        self,
+        pixel_values: Optional[torch.FloatTensor],
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Tuple[torch.Tensor]:
-        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+        _, num_channels, height, width = pixel_values.shape
+        embeddings, output_dimensions = self.patch_embeddings(
+            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+        )
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
 
@@ -309,7 +347,10 @@ def forward(
             embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
 
         if self.position_embeddings is not None:
-            embeddings = embeddings + self.position_embeddings
+            if interpolate_pos_encoding:
+                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+            else:
+                embeddings = embeddings + self.position_embeddings
 
         embeddings = self.dropout(embeddings)
 
@@ -348,7 +389,9 @@ def maybe_pad(self, pixel_values, height, width):
             pixel_values = nn.functional.pad(pixel_values, pad_values)
         return pixel_values
 
-    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+    def forward(
+        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
+    ) -> Tuple[torch.Tensor, Tuple[int]]:
         _, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
@@ -356,6 +399,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens
         )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
+        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model"
+                f" ({self.image_size[0]}*{self.image_size[1]})."
+            )
         embeddings = self.projection(pixel_values)
         _, _, height, width = embeddings.shape
         output_dimensions = (height, width)
@@ -979,6 +1027,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -1031,6 +1081,7 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, Swinv2ModelOutput]:
         r"""
@@ -1053,7 +1104,9 @@ def forward(
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, len(self.config.depths))
 
-        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+        embedding_output, input_dimensions = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
 
         encoder_outputs = self.encoder(
             embedding_output,
@@ -1126,6 +1179,7 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
         r"""
@@ -1165,6 +1219,7 @@ def forward(
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
         )
 
@@ -1208,6 +1263,14 @@ def forward(
     """
     Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
     of the [CLS] token) e.g. for ImageNet.
+
+    <Tip>
+
+    Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
+    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+    position embeddings to the higher resolution.
+
+    </Tip>
     """,
     SWINV2_START_DOCSTRING,
 )
@@ -1241,6 +1304,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, Swinv2ImageClassifierOutput]:
         r"""
@@ -1256,6 +1320,7 @@ def forward(
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
         )
 
diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py
index 9220784e23029a..699171722d0db9 100644
--- a/tests/models/swin/test_modeling_swin.py
+++ b/tests/models/swin/test_modeling_swin.py
@@ -493,6 +493,26 @@ def test_inference_image_classification_head(self):
         expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
 
+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        # Swin models have an `interpolate_pos_encoding` argument in their forward method,
+        # which allows interpolating the pre-trained position embeddings in order to use
+        # the model on higher resolutions.
+        model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device)
+
+        image_processor = self.default_image_processor
+        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
+        pixel_values = inputs.pixel_values.to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(pixel_values, interpolate_pos_encoding=True)
+
+        # verify the output shape
+        expected_shape = torch.Size((1, 256, 768))
+        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
 
 @require_torch
 class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin):
diff --git a/tests/models/swinv2/test_modeling_swinv2.py b/tests/models/swinv2/test_modeling_swinv2.py
index b8f97ee7c23bc6..7a948d1282c1b6 100644
--- a/tests/models/swinv2/test_modeling_swinv2.py
+++ b/tests/models/swinv2/test_modeling_swinv2.py
@@ -485,6 +485,26 @@ def test_inference_image_classification_head(self):
         expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
 
+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        # Swinv2 models have an `interpolate_pos_encoding` argument in their forward method,
+        # which allows interpolating the pre-trained position embeddings in order to use
+        # the model on higher resolutions.
+        model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to(torch_device)
+
+        image_processor = self.default_image_processor
+        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
+        pixel_values = inputs.pixel_values.to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(pixel_values, interpolate_pos_encoding=True)
+
+        # verify the output shape
+        expected_shape = torch.Size((1, 256, 768))
+        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
 
 @require_torch
 class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):
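
For quick reference, here is a minimal end-to-end usage sketch of the new flag (not part of the patch). It mirrors the slow tests above: the checkpoint name, the 481x481 size, and the expected (1, 256, 768) hidden-state shape are taken from `test_inference_interpolate_pos_encoding`; the local image path is a placeholder.

import torch
from PIL import Image
from transformers import AutoImageProcessor, SwinModel

# Load the same checkpoint the test uses; it was pre-trained at 224x224.
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model.eval()

# "cats.png" is a placeholder path; any RGB image works.
image = Image.open("cats.png").convert("RGB")

# Resize to a resolution the model was not trained at (481x481, as in the tests above).
inputs = processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")

with torch.no_grad():
    # Without interpolate_pos_encoding=True this input size would raise a ValueError,
    # per the new validation added to the patch embedding forward in this patch.
    outputs = model(**inputs, interpolate_pos_encoding=True)

# The tests above expect torch.Size([1, 256, 768]) for this input size.
print(outputs.last_hidden_state.shape)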