diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 8ff7f1cd96a0d2..cf83f4dbce8967 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -33,7 +33,6 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, - torch_int, ) from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig @@ -164,62 +163,40 @@ def __init__(self, config: CLIPSegVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution - images. This method is also adapted to support torch.jit tracing. - - Adapted from: - - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 - """ - - num_patches = embeddings.shape[1] - 1 - position_embedding = self.position_embedding.weight.unsqueeze(0) - num_positions = position_embedding.shape[1] - 1 - - # always interpolate when tracing to ensure the exported model works for dynamic input shapes - if not torch.jit.is_tracing() and num_patches == num_positions and height == width: - return self.position_embedding(self.position_ids) - - class_pos_embed = position_embedding[:, :1] - patch_pos_embed = position_embedding[:, 1:] - - dim = embeddings.shape[-1] - - new_height = height // self.patch_size - new_width = width // self.patch_size + def interpolate_position_embeddings(self, new_size): + if len(new_size) != 2: + raise ValueError("new_size should consist of 2 values") - sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) - - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - size=(new_height, new_width), - mode="bicubic", - align_corners=False, + num_patches_one_direction = int(self.num_patches**0.5) + # we interpolate the position embeddings in 2D + a = self.position_embedding.weight[1:].T.view( + 1, self.config.hidden_size, num_patches_one_direction, num_patches_one_direction ) + b = ( + nn.functional.interpolate(a, new_size, mode="bicubic", align_corners=False) + .squeeze(0) + .view(self.config.hidden_size, new_size[0] * new_size[1]) + .T + ) + result = torch.cat([self.position_embedding.weight[:1], b]) - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - - return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + return result - def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: - batch_size, _, height, width = pixel_values.shape - if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})." - ) + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + if embeddings.shape[1] != self.num_positions: + new_shape = int(math.sqrt(embeddings.shape[1] - 1)) + embeddings = embeddings + self.interpolate_position_embeddings((new_shape, new_shape)) + embeddings = embeddings.to(embeddings.dtype) else: embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings @@ -535,8 +512,6 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -574,8 +549,6 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -845,14 +818,12 @@ def __init__(self, config: CLIPSegVisionConfig): @add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig) - # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -867,7 +838,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.embeddings(pixel_values) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -912,7 +883,6 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -941,7 +911,6 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1035,7 +1004,6 @@ def get_image_features( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" @@ -1071,7 +1039,6 @@ def get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1091,7 +1058,6 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, CLIPSegOutput]: r""" @@ -1129,7 +1095,6 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1397,7 +1362,6 @@ def forward( labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, CLIPSegOutput]: r""" @@ -1437,7 +1401,6 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=True, # we need the intermediate hidden states - interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) pooled_output = self.clip.visual_projection(vision_outputs[1]) diff --git a/tests/models/clipseg/test_modeling_clipseg.py b/tests/models/clipseg/test_modeling_clipseg.py index c5edf7cb757b30..9c8da9311da9f2 100644 --- a/tests/models/clipseg/test_modeling_clipseg.py +++ b/tests/models/clipseg/test_modeling_clipseg.py @@ -796,7 +796,7 @@ def test_inference_image_segmentation(self): # forward pass with torch.no_grad(): - outputs = model(**inputs, interpolate_pos_encoding=True) + outputs = model(**inputs) # verify the predicted masks self.assertEqual( @@ -804,7 +804,7 @@ def test_inference_image_segmentation(self): torch.Size((3, 352, 352)), ) expected_masks_slice = torch.tensor( - [[-7.4613, -7.4785, -7.3627], [-7.3268, -7.0898, -7.1333], [-6.9838, -6.7900, -6.8913]] + [[-7.4613, -7.4785, -7.3628], [-7.3268, -7.0899, -7.1333], [-6.9838, -6.7900, -6.8913]] ).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)) @@ -813,41 +813,4 @@ def test_inference_image_segmentation(self): expected_conditional = torch.tensor([0.5601, -0.0314, 0.1980]).to(torch_device) expected_pooled_output = torch.tensor([0.5036, -0.2681, -0.2644]).to(torch_device) self.assertTrue(torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)) - self.assertTrue(torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)) - - @slow - def test_inference_interpolate_pos_encoding(self): - # ViT models have an `interpolate_pos_encoding` argument in their forward method, - # allowing to interpolate the pre-trained position embeddings in order to use - # the model on higher resolutions. The DINO model by Facebook AI leverages this - # to visualize self-attention on higher resolution images. - model = CLIPSegModel.from_pretrained("openai/clip-vit-base-patch32").to(torch_device) - - processor = CLIPSegProcessor.from_pretrained( - "openai/clip-vit-base-patch32", size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180} - ) - - image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") - inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) - - # interpolate_pos_encodiung false should return value error - with self.assertRaises(ValueError, msg="doesn't match model"): - with torch.no_grad(): - model(**inputs, interpolate_pos_encoding=False) - - # forward pass - with torch.no_grad(): - outputs = model(**inputs, interpolate_pos_encoding=True) - - # verify the logits - expected_shape = torch.Size((1, 26, 768)) - - self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) - - expected_slice = torch.tensor( - [[-0.1538, 0.0322, -0.3235], [0.2893, 0.1135, -0.5708], [0.0461, 0.1540, -0.6018]] - ).to(torch_device) - - self.assertTrue( - torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) - ) + self.assertTrue(torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)) \ No newline at end of file