From 8f38f58f3de5a35f9b8505e9b48985dce5470985 Mon Sep 17 00:00:00 2001 From: bastrob <50299842+bastrob@users.noreply.github.com> Date: Sat, 21 Dec 2024 09:51:09 +0100 Subject: [PATCH] owlvit/2 dynamic input resolution (#34764) * owlvit/2 dynamic input resolution. * adapt box grid to patch_dim_h patch_dim_w * fix ci * clarify variable naming * clarify variable naming.. * compute box_bias dynamically inside box_predictor * change style part of code * [run-slow] owlvit, owlv2 --- .../models/owlv2/modeling_owlv2.py | 182 ++++++++++++++---- .../models/owlvit/modeling_owlvit.py | 180 +++++++++++++---- tests/models/owlv2/test_modeling_owlv2.py | 138 +++++++++++++ tests/models/owlvit/test_modeling_owlvit.py | 138 +++++++++++++ 4 files changed, 565 insertions(+), 73 deletions(-) diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index d773396010a3cb..7b631a77fcdda3 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -33,6 +33,7 @@ is_vision_available, logging, replace_return_docstrings, + torch_int, ) from .configuration_owlv2 import Owlv2Config, Owlv2TextConfig, Owlv2VisionConfig @@ -274,6 +275,7 @@ def to_tuple(self) -> Tuple[Any]: class Owlv2VisionEmbeddings(nn.Module): def __init__(self, config: Owlv2VisionConfig): super().__init__() + self.patch_size = config.patch_size self.config = config self.embed_dim = config.hidden_size self.class_embedding = nn.Parameter(torch.randn(config.hidden_size)) @@ -291,15 +293,59 @@ def __init__(self, config: Owlv2VisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -610,6 +656,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -635,6 +683,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. return_base_image_embeds (`bool`, *optional*): Whether or not to return the base image embeddings. return_dict (`bool`, *optional*): @@ -657,6 +707,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the last hidden state. See `text_model_last_hidden_state` and `vision_model_last_hidden_state` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" @@ -673,6 +725,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -914,6 +968,7 @@ def forward( pixel_values: torch.FloatTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -929,7 +984,7 @@ def forward( expected_input_dtype = self.embeddings.patch_embedding.weight.dtype pixel_values = pixel_values.to(expected_input_dtype) - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layernorm(hidden_states) encoder_outputs = self.encoder( @@ -976,6 +1031,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -1002,6 +1058,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1084,6 +1141,7 @@ def get_image_features( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" @@ -1115,6 +1173,7 @@ def get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1133,6 +1192,7 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_base_image_embeds: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, Owlv2Output]: @@ -1165,6 +1225,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1295,21 +1356,23 @@ def __init__(self, config: Owlv2Config): self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps) self.sigmoid = nn.Sigmoid() - - self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size - self.box_bias = self.compute_box_bias(self.sqrt_num_patches) + self.config = config + self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width) @staticmethod # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates - def 
normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor: + def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor: # Create grid coordinates using torch - x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32) - y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32) + x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32) + y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32) xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy") - # Stack the coordinates and divide by num_patches + # Stack the coordinates and divide by their respective patch counts box_coordinates = torch.stack((xx, yy), dim=-1) - box_coordinates /= num_patches + box_coordinates[..., 0] /= num_patches_width + box_coordinates[..., 1] /= num_patches_height # Flatten (h, w, 2) -> (h*w, 2) box_coordinates = box_coordinates.view(-1, 2) @@ -1332,18 +1395,22 @@ def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.Float @lru_cache(maxsize=2) # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias - def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor: + def compute_box_bias( + self, num_patches_height: int, num_patches_width: int, feature_map: Optional[torch.FloatTensor] = None + ) -> torch.Tensor: if feature_map is not None: raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead") # The box center is biased to its position on the feature grid - box_coordinates = self.normalize_grid_corner_coordinates(num_patches) + box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width) box_coordinates = torch.clip(box_coordinates, 0.0, 1.0) # Unnormalize xy box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4) # The box size is biased to the patch size - box_size = torch.full_like(box_coord_bias, 1.0 / num_patches) + box_size = torch.full_like(box_coord_bias, 1.0) + box_size[..., 0] /= num_patches_width + box_size[..., 1] /= num_patches_height box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4) # Compute box bias @@ -1355,6 +1422,7 @@ def box_predictor( self, image_feats: torch.FloatTensor, feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: """ Args: @@ -1362,6 +1430,8 @@ def box_predictor( Features extracted from the image, returned by the `image_text_embedder` method. feature_map: A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method. + interpolate_pos_encoding: + Whether to interpolate the pre-trained position encodings. Returns: pred_boxes: List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary. 
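
For readers skimming the hunks above: the rectangular box-bias computation can be replayed as a self-contained sketch. This is not part of the patch; the function name `grid_box_bias` is illustrative and simply restates the math of the new `normalize_grid_corner_coordinates` / `compute_box_bias` for an arbitrary `num_patches_height` x `num_patches_width` grid. The key point is that x is normalized by the width patch count and y by the height patch count, so non-square inputs keep box centers aligned with the patch grid.

```python
import torch


def grid_box_bias(num_patches_height: int, num_patches_width: int) -> torch.Tensor:
    # Patch-center coordinates, normalized per axis so a rectangular grid still maps into [0, 1] x [0, 1]
    x = torch.arange(1, num_patches_width + 1, dtype=torch.float32)
    y = torch.arange(1, num_patches_height + 1, dtype=torch.float32)
    xx, yy = torch.meshgrid(x, y, indexing="xy")
    coords = torch.stack((xx, yy), dim=-1)  # (h, w, 2)
    coords[..., 0] /= num_patches_width
    coords[..., 1] /= num_patches_height
    coords = coords.view(-1, 2).clip(0.0, 1.0)  # (h * w, 2)

    # Inverse-sigmoid ("logit") bias for the box centers
    coord_bias = torch.log(coords + 1e-4) - torch.log1p(-coords + 1e-4)

    # Box sizes are biased towards one patch along each axis
    size = torch.full_like(coord_bias, 1.0)
    size[..., 0] /= num_patches_width
    size[..., 1] /= num_patches_height
    size_bias = torch.log(size + 1e-4) - torch.log1p(-size + 1e-4)

    return torch.cat([coord_bias, size_bias], dim=-1)  # (h * w, 4)


print(grid_box_bias(64, 79).shape)  # torch.Size([5056, 4])
```
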
@@ -1370,7 +1440,13 @@ def box_predictor( pred_boxes = self.box_head(image_feats) # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction - box_bias = self.box_bias.to(feature_map.device) + if interpolate_pos_encoding: + _, num_patches_height, num_patches_width, _ = feature_map.shape + box_bias = self.compute_box_bias(num_patches_height, num_patches_width) + else: + box_bias = self.box_bias + + box_bias = box_bias.to(feature_map.device) pred_boxes += box_bias pred_boxes = self.sigmoid(pred_boxes) return pred_boxes @@ -1403,6 +1479,7 @@ def image_text_embedder( attention_mask: torch.Tensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Tuple[torch.FloatTensor]: # Encode text and image outputs = self.owlv2( @@ -1411,9 +1488,18 @@ def image_text_embedder( attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True, ) + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + # Get image embeddings last_hidden_state = outputs.vision_model_output[0] image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state) @@ -1425,11 +1511,11 @@ def image_text_embedder( image_embeds = image_embeds[:, 1:, :] * class_token_out image_embeds = self.layer_norm(image_embeds) - # Resize to [batch_size, num_patches, num_patches, hidden_size] + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] new_size = ( image_embeds.shape[0], - self.sqrt_num_patches, - self.sqrt_num_patches, + num_patches_height, + num_patches_width, image_embeds.shape[-1], ) image_embeds = image_embeds.reshape(new_size) @@ -1443,9 +1529,20 @@ def image_embedder( pixel_values: torch.FloatTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Tuple[torch.FloatTensor]: # Get Owlv2Model vision embeddings (same as CLIP) - vision_outputs = self.owlv2.vision_model(pixel_values=pixel_values, return_dict=True) + vision_outputs = self.owlv2.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width # Apply post_layernorm to last_hidden_state, return non-projected output last_hidden_state = vision_outputs[0] @@ -1458,11 +1555,11 @@ def image_embedder( image_embeds = image_embeds[:, 1:, :] * class_token_out image_embeds = self.layer_norm(image_embeds) - # Resize to [batch_size, num_patches, num_patches, hidden_size] + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] new_size = ( image_embeds.shape[0], - self.sqrt_num_patches, - self.sqrt_num_patches, + num_patches_height, + num_patches_width, image_embeds.shape[-1], ) image_embeds = image_embeds.reshape(new_size) @@ -1471,10 +1568,13 @@ def image_embedder( # Copied from 
transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.embed_image_query def embed_image_query( - self, query_image_features: torch.FloatTensor, query_feature_map: torch.FloatTensor + self, + query_image_features: torch.FloatTensor, + query_feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: _, class_embeds = self.class_predictor(query_image_features) - pred_boxes = self.box_predictor(query_image_features, query_feature_map) + pred_boxes = self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding) pred_boxes_as_corners = center_to_corners_format(pred_boxes) # Loop over query images @@ -1519,6 +1619,7 @@ def image_guided_detection( query_pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Owlv2ImageGuidedObjectDetectionOutput: r""" @@ -1576,26 +1677,33 @@ def image_guided_detection( return_dict = return_dict if return_dict is not None else self.config.return_dict # Compute feature maps for the input and query images - query_feature_map = self.image_embedder(pixel_values=query_pixel_values)[0] + query_feature_map = self.image_embedder( + pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + )[0] feature_map, vision_outputs = self.image_embedder( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, ) - batch_size, num_patches, num_patches, hidden_dim = feature_map.shape - image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim)) + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) - batch_size, num_patches, num_patches, hidden_dim = query_feature_map.shape - query_image_feats = torch.reshape(query_feature_map, (batch_size, num_patches * num_patches, hidden_dim)) + batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape + query_image_feats = torch.reshape( + query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim) + ) # Get top class embedding and best box index for each query image in batch - query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(query_image_feats, query_feature_map) + query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query( + query_image_feats, query_feature_map, interpolate_pos_encoding + ) # Predict object classes [batch_size, num_patches, num_queries+1] (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds) # Predict object boxes - target_pred_boxes = self.box_predictor(image_feats, feature_map) + target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) if not return_dict: output = ( @@ -1630,6 +1738,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Owlv2ObjectDetectionOutput: r""" @@ -1683,14 +1792,15 @@ def forward( attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + 
interpolate_pos_encoding=interpolate_pos_encoding, ) # Text and vision model outputs text_outputs = outputs.text_model_output vision_outputs = outputs.vision_model_output - batch_size, num_patches, num_patches, hidden_dim = feature_map.shape - image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim)) + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim] max_text_queries = input_ids.shape[0] // batch_size @@ -1707,7 +1817,7 @@ def forward( objectness_logits = self.objectness_predictor(image_feats) # Predict object boxes - pred_boxes = self.box_predictor(image_feats, feature_map) + pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) if not return_dict: output = ( diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 7c3e124a207ff7..570d154a554c03 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -33,6 +33,7 @@ is_vision_available, logging, replace_return_docstrings, + torch_int, ) from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig @@ -268,6 +269,7 @@ def to_tuple(self) -> Tuple[Any]: class OwlViTVisionEmbeddings(nn.Module): def __init__(self, config: OwlViTVisionConfig): super().__init__() + self.patch_size = config.patch_size self.config = config self.embed_dim = config.hidden_size self.class_embedding = nn.Parameter(torch.randn(config.hidden_size)) @@ -285,15 +287,55 @@ def __init__(self, config: OwlViTVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -601,6 +643,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -626,6 +670,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -646,6 +692,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the last hidden state. See `text_model_last_hidden_state` and `vision_model_last_hidden_state` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" @@ -662,6 +710,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -899,6 +949,7 @@ def forward( pixel_values: torch.FloatTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -914,7 +965,7 @@ def forward( expected_input_dtype = self.embeddings.patch_embedding.weight.dtype pixel_values = pixel_values.to(expected_input_dtype) - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layernorm(hidden_states) encoder_outputs = self.encoder( @@ -960,6 +1011,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -986,6 +1038,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1067,6 +1120,7 @@ def get_image_features( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" @@ -1098,6 +1152,7 @@ def get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1116,6 +1171,7 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_base_image_embeds: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, OwlViTOutput]: @@ -1148,6 +1204,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1275,20 +1332,22 @@ def __init__(self, config: OwlViTConfig): self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps) self.sigmoid = nn.Sigmoid() - - self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size - self.box_bias = self.compute_box_bias(self.sqrt_num_patches) + self.config = config + self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width) @staticmethod - def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor: + def normalize_grid_corner_coordinates(num_patches_height: int, 
num_patches_width: int) -> torch.Tensor: # Create grid coordinates using torch - x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32) - y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32) + x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32) + y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32) xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy") - # Stack the coordinates and divide by num_patches + # Stack the coordinates and divide by their respective patch counts box_coordinates = torch.stack((xx, yy), dim=-1) - box_coordinates /= num_patches + box_coordinates[..., 0] /= num_patches_width + box_coordinates[..., 1] /= num_patches_height # Flatten (h, w, 2) -> (h*w, 2) box_coordinates = box_coordinates.view(-1, 2) @@ -1296,18 +1355,22 @@ def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor: return box_coordinates @lru_cache(maxsize=2) - def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor: + def compute_box_bias( + self, num_patches_height: int, num_patches_width: int, feature_map: Optional[torch.FloatTensor] = None + ) -> torch.Tensor: if feature_map is not None: raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead") # The box center is biased to its position on the feature grid - box_coordinates = self.normalize_grid_corner_coordinates(num_patches) + box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width) box_coordinates = torch.clip(box_coordinates, 0.0, 1.0) # Unnormalize xy box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4) # The box size is biased to the patch size - box_size = torch.full_like(box_coord_bias, 1.0 / num_patches) + box_size = torch.full_like(box_coord_bias, 1.0) + box_size[..., 0] /= num_patches_width + box_size[..., 1] /= num_patches_height box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4) # Compute box bias @@ -1318,6 +1381,7 @@ def box_predictor( self, image_feats: torch.FloatTensor, feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: """ Args: @@ -1325,6 +1389,8 @@ def box_predictor( Features extracted from the image, returned by the `image_text_embedder` method. feature_map: A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method. + interpolate_pos_encoding: + Whether to interpolate the pre-trained position encodings. Returns: pred_boxes: List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary. 
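
The other half of the change is the position-encoding interpolation that both vision towers now expose (copied from CLIP's `interpolate_pos_encoding`, see the `OwlViTVisionEmbeddings` / `Owlv2VisionEmbeddings` hunks above). The sketch below is a standalone restatement, not library code: `interpolate_patch_pos_embed` is a hypothetical helper, and `round(...)` stands in for the `torch_int` truncation used in the actual modeling files. The pretrained grid is assumed square, as in the released checkpoints; only the target grid may be rectangular.

```python
import torch
import torch.nn.functional as F


def interpolate_patch_pos_embed(
    pos_embed: torch.Tensor,  # (1, 1 + num_positions, dim): CLS + patch position embeddings
    new_height: int,          # target image height // patch_size
    new_width: int,           # target image width // patch_size
) -> torch.Tensor:
    cls_pos = pos_embed[:, :1]
    patch_pos = pos_embed[:, 1:]
    dim = pos_embed.shape[-1]

    # Pretrained patch embeddings live on a square grid of side sqrt(num_positions)
    side = round(patch_pos.shape[1] ** 0.5)
    patch_pos = patch_pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)

    # Bicubic resize to the new patch grid, matching the modeling code
    patch_pos = F.interpolate(patch_pos, size=(new_height, new_width), mode="bicubic", align_corners=False)

    patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((cls_pos, patch_pos), dim=1)


# e.g. a 60x60 pretrained grid (OWLv2 base, 960px / patch 16) resized for a 1264x1024 input
pos = torch.randn(1, 1 + 60 * 60, 768)
print(interpolate_patch_pos_embed(pos, 79, 64).shape)  # torch.Size([1, 5057, 768])
```
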
@@ -1333,7 +1399,13 @@ def box_predictor( pred_boxes = self.box_head(image_feats) # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction - box_bias = self.box_bias.to(feature_map.device) + if interpolate_pos_encoding: + _, num_patches_height, num_patches_width, _ = feature_map.shape + box_bias = self.compute_box_bias(num_patches_height, num_patches_width) + else: + box_bias = self.box_bias + + box_bias = box_bias.to(feature_map.device) pred_boxes += box_bias pred_boxes = self.sigmoid(pred_boxes) return pred_boxes @@ -1364,6 +1436,7 @@ def image_text_embedder( attention_mask: torch.Tensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Tuple[torch.FloatTensor]: # Encode text and image outputs = self.owlvit( @@ -1372,9 +1445,18 @@ def image_text_embedder( attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True, ) + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + # Get image embeddings last_hidden_state = outputs.vision_model_output[0] image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state) @@ -1386,11 +1468,11 @@ def image_text_embedder( image_embeds = image_embeds[:, 1:, :] * class_token_out image_embeds = self.layer_norm(image_embeds) - # Resize to [batch_size, num_patches, num_patches, hidden_size] + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] new_size = ( image_embeds.shape[0], - self.sqrt_num_patches, - self.sqrt_num_patches, + num_patches_height, + num_patches_width, image_embeds.shape[-1], ) image_embeds = image_embeds.reshape(new_size) @@ -1403,9 +1485,20 @@ def image_embedder( pixel_values: torch.FloatTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Tuple[torch.FloatTensor]: # Get OwlViTModel vision embeddings (same as CLIP) - vision_outputs = self.owlvit.vision_model(pixel_values=pixel_values, return_dict=True) + vision_outputs = self.owlvit.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width # Apply post_layernorm to last_hidden_state, return non-projected output last_hidden_state = vision_outputs[0] @@ -1418,11 +1511,11 @@ def image_embedder( image_embeds = image_embeds[:, 1:, :] * class_token_out image_embeds = self.layer_norm(image_embeds) - # Resize to [batch_size, num_patches, num_patches, hidden_size] + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] new_size = ( image_embeds.shape[0], - self.sqrt_num_patches, - self.sqrt_num_patches, + num_patches_height, + num_patches_width, image_embeds.shape[-1], ) image_embeds = image_embeds.reshape(new_size) @@ -1430,10 +1523,13 @@ def image_embedder( return (image_embeds, 
vision_outputs) def embed_image_query( - self, query_image_features: torch.FloatTensor, query_feature_map: torch.FloatTensor + self, + query_image_features: torch.FloatTensor, + query_feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: _, class_embeds = self.class_predictor(query_image_features) - pred_boxes = self.box_predictor(query_image_features, query_feature_map) + pred_boxes = self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding) pred_boxes_as_corners = center_to_corners_format(pred_boxes) # Loop over query images @@ -1478,6 +1574,7 @@ def image_guided_detection( query_pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> OwlViTImageGuidedObjectDetectionOutput: r""" @@ -1520,26 +1617,33 @@ def image_guided_detection( return_dict = return_dict if return_dict is not None else self.config.return_dict # Compute feature maps for the input and query images - query_feature_map = self.image_embedder(pixel_values=query_pixel_values)[0] + query_feature_map = self.image_embedder( + pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + )[0] feature_map, vision_outputs = self.image_embedder( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, ) - batch_size, num_patches, num_patches, hidden_dim = feature_map.shape - image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim)) + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) - batch_size, num_patches, num_patches, hidden_dim = query_feature_map.shape - query_image_feats = torch.reshape(query_feature_map, (batch_size, num_patches * num_patches, hidden_dim)) + batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape + query_image_feats = torch.reshape( + query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim) + ) # Get top class embedding and best box index for each query image in batch - query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(query_image_feats, query_feature_map) + query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query( + query_image_feats, query_feature_map, interpolate_pos_encoding + ) # Predict object classes [batch_size, num_patches, num_queries+1] (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds) # Predict object boxes - target_pred_boxes = self.box_predictor(image_feats, feature_map) + target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) if not return_dict: output = ( @@ -1574,6 +1678,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> OwlViTObjectDetectionOutput: r""" @@ -1625,14 +1730,15 @@ def forward( attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, ) # Text and vision model outputs text_outputs 
= outputs.text_model_output vision_outputs = outputs.vision_model_output - batch_size, num_patches, num_patches, hidden_dim = feature_map.shape - image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim)) + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim] max_text_queries = input_ids.shape[0] // batch_size @@ -1646,7 +1752,7 @@ def forward( (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask) # Predict object boxes - pred_boxes = self.box_predictor(image_feats, feature_map) + pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) if not return_dict: output = ( diff --git a/tests/models/owlv2/test_modeling_owlv2.py b/tests/models/owlv2/test_modeling_owlv2.py index df763aed48c749..b35f58e99a0402 100644 --- a/tests/models/owlv2/test_modeling_owlv2.py +++ b/tests/models/owlv2/test_modeling_owlv2.py @@ -828,6 +828,144 @@ def test_inference(self): expected_logits = torch.tensor([[-6.2229, -8.2601]], device=torch_device) self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) + @slow + def test_inference_interpolate_pos_encoding(self): + model_name = "google/owlv2-base-patch16" + model = Owlv2Model.from_pretrained(model_name).to(torch_device) + processor = OwlViTProcessor.from_pretrained(model_name) + processor.image_processor.size = {"height": 1024, "width": 1024} + + image = prepare_img() + inputs = processor( + text=[["a photo of a cat", "a photo of a dog"]], + images=image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + self.assertEqual( + outputs.logits_per_image.shape, + torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])), + ) + self.assertEqual( + outputs.logits_per_text.shape, + torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])), + ) + expected_logits = torch.tensor([[-6.2520, -8.2970]], device=torch_device) + self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) + expected_shape = torch.Size((1, 4097, 768)) + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + # Owlv2ForObjectDetection part. 
+ model = Owlv2ForObjectDetection.from_pretrained(model_name).to(torch_device) + processor.image_processor.size = {"height": 1024, "width": 1024} + + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + num_queries = int((inputs.pixel_values.shape[-1] / model.config.vision_config.patch_size) ** 2) + self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4))) + expected_slice_boxes = torch.tensor( + [[0.2407, 0.0553, 0.4636], [0.1082, 0.0494, 0.1861], [0.2459, 0.0527, 0.4398]] + ).to(torch_device) + self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)) + + model = Owlv2ForObjectDetection.from_pretrained(model_name).to(torch_device) + query_image = prepare_img() + inputs = processor( + images=image, + query_images=query_image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True) + + # No need to check the logits, we just check inference runs fine. + num_queries = int((inputs.pixel_values.shape[-1] / model.config.vision_config.patch_size) ** 2) + self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4))) + + # Deactivate interpolate_pos_encoding on same model, and use default image size. + # Verify the dynamic change caused by the activation/deactivation of interpolate_pos_encoding of variables: self.sqrt_num_patches, self.box_bias from (OwlViTForObjectDetection). + processor = OwlViTProcessor.from_pretrained(model_name) + + image = prepare_img() + inputs = processor( + text=[["a photo of a cat", "a photo of a dog"]], + images=image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=False) + + num_queries = int((inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) ** 2) + self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4))) + + expected_default_box_bias = torch.tensor( + [ + [-4.0717, -4.0717, -4.0717, -4.0717], + [-3.3644, -4.0717, -4.0717, -4.0717], + [-2.9425, -4.0717, -4.0717, -4.0717], + ] + ) + + self.assertTrue(torch.allclose(model.box_bias[:3, :4], expected_default_box_bias, atol=1e-4)) + + # Interpolate with any resolution size. 
+ processor.image_processor.size = {"height": 1264, "width": 1024} + + image = prepare_img() + inputs = processor( + text=[["a photo of a cat", "a photo of a dog"]], + images=image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + num_queries = int( + (inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size) + * (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) + ) + self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4))) + expected_slice_boxes = torch.tensor( + [[0.2438, 0.0945, 0.4675], [0.1361, 0.0431, 0.2406], [0.2465, 0.0428, 0.4429]] + ).to(torch_device) + self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)) + + query_image = prepare_img() + inputs = processor( + images=image, + query_images=query_image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True) + + # No need to check the logits, we just check inference runs fine. + num_queries = int( + (inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size) + * (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) + ) + self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4))) + @slow def test_inference_object_detection(self): model_name = "google/owlv2-base-patch16" diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py index e0599a50fb98b4..545fee0c4fe3af 100644 --- a/tests/models/owlvit/test_modeling_owlvit.py +++ b/tests/models/owlvit/test_modeling_owlvit.py @@ -821,6 +821,144 @@ def test_inference(self): expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device) self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) + @slow + def test_inference_interpolate_pos_encoding(self): + model_name = "google/owlvit-base-patch32" + model = OwlViTModel.from_pretrained(model_name).to(torch_device) + processor = OwlViTProcessor.from_pretrained(model_name) + processor.image_processor.size = {"height": 800, "width": 800} + + image = prepare_img() + inputs = processor( + text=[["a photo of a cat", "a photo of a dog"]], + images=image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + self.assertEqual( + outputs.logits_per_image.shape, + torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])), + ) + self.assertEqual( + outputs.logits_per_text.shape, + torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])), + ) + expected_logits = torch.tensor([[3.6278, 0.8861]], device=torch_device) + self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) + + expected_shape = torch.Size((1, 626, 768)) + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + # OwlViTForObjectDetection part. 
+ model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + num_queries = int((inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) ** 2) + self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4))) + + expected_slice_boxes = torch.tensor( + [[0.0680, 0.0422, 0.1347], [0.2071, 0.0450, 0.4146], [0.2000, 0.0418, 0.3476]] + ).to(torch_device) + self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)) + + model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device) + query_image = prepare_img() + inputs = processor( + images=image, + query_images=query_image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True) + + # No need to check the logits, we just check inference runs fine. + num_queries = int((inputs.pixel_values.shape[-1] / model.config.vision_config.patch_size) ** 2) + self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4))) + + # Deactivate interpolate_pos_encoding on same model, and use default image size. + # Verify the dynamic change caused by the activation/deactivation of interpolate_pos_encoding of variables: (self.sqrt_num_patch_h, self.sqrt_num_patch_w), self.box_bias from (OwlViTForObjectDetection). + processor = OwlViTProcessor.from_pretrained(model_name) + + image = prepare_img() + inputs = processor( + text=[["a photo of a cat", "a photo of a dog"]], + images=image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=False) + + num_queries = int((inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) ** 2) + self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4))) + + expected_default_box_bias = torch.tensor( + [ + [-3.1332, -3.1332, -3.1332, -3.1332], + [-2.3968, -3.1332, -3.1332, -3.1332], + [-1.9452, -3.1332, -3.1332, -3.1332], + ] + ) + self.assertTrue(torch.allclose(model.box_bias[:3, :4], expected_default_box_bias, atol=1e-4)) + + # Interpolate with any resolution size. + processor.image_processor.size = {"height": 1264, "width": 1024} + + image = prepare_img() + inputs = processor( + text=[["a photo of a cat", "a photo of a dog"]], + images=image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + num_queries = int( + (inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size) + * (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) + ) + self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4))) + expected_slice_boxes = torch.tensor( + [[0.0499, 0.0301, 0.0983], [0.2244, 0.0365, 0.4663], [0.1387, 0.0314, 0.1859]] + ).to(torch_device) + self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)) + + query_image = prepare_img() + inputs = processor( + images=image, + query_images=query_image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True) + + # No need to check the logits, we just check inference runs fine. 
+ num_queries = int( + (inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size) + * (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) + ) + self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4))) + @slow def test_inference_object_detection(self): model_name = "google/owlvit-base-patch32"
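
For completeness, a hedged end-to-end usage sketch of the new flag, mirroring the slow tests above. The checkpoint name, COCO sample image and the 1264x1024 resize are taken from those tests; the expected box count assumes the patch-32 OWL-ViT base checkpoint (1264 // 32 = 39 and 1024 // 32 = 32 patches per axis).

```python
import requests
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

# Feed a non-square, non-default resolution; the patch grid becomes 39 x 32
processor.image_processor.size = {"height": 1264, "width": 1024}

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(
    text=[["a photo of a cat", "a photo of a dog"]],
    images=image,
    max_length=16,
    padding="max_length",
    return_tensors="pt",
)

with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

# One box prediction per patch: 39 * 32 = 1248 queries of 4 box coordinates each
print(outputs.pred_boxes.shape)  # torch.Size([1, 1248, 4])
```

Without `interpolate_pos_encoding=True`, the model falls back to the precomputed `self.box_bias` for the default 768px (OWL-ViT) or 960px (OWLv2) resolution, which is exactly what the deactivation checks in the tests verify.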