Remove interpolate_pos_encoding
NielsRogge committed Oct 25, 2024
1 parent 1d06379 commit 91816dd
Showing 2 changed files with 27 additions and 101 deletions.
85 changes: 24 additions & 61 deletions src/transformers/models/clipseg/modeling_clipseg.py
@@ -33,7 +33,6 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig

@@ -164,62 +163,40 @@ def __init__(self, config: CLIPSegVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
position_embedding = self.position_embedding.weight.unsqueeze(0)
num_positions = position_embedding.shape[1] - 1

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embedding(self.position_ids)

class_pos_embed = position_embedding[:, :1]
patch_pos_embed = position_embedding[:, 1:]

dim = embeddings.shape[-1]

new_height = height // self.patch_size
new_width = width // self.patch_size
def interpolate_position_embeddings(self, new_size):
if len(new_size) != 2:
raise ValueError("new_size should consist of 2 values")

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
num_patches_one_direction = int(self.num_patches**0.5)
# we interpolate the position embeddings in 2D
a = self.position_embedding.weight[1:].T.view(
1, self.config.hidden_size, num_patches_one_direction, num_patches_one_direction
)
b = (
nn.functional.interpolate(a, new_size, mode="bicubic", align_corners=False)
.squeeze(0)
.view(self.config.hidden_size, new_size[0] * new_size[1])
.T
)
result = torch.cat([self.position_embedding.weight[:1], b])

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
return result

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

if embeddings.shape[1] != self.num_positions:
new_shape = int(math.sqrt(embeddings.shape[1] - 1))
embeddings = embeddings + self.interpolate_position_embeddings((new_shape, new_shape))
embeddings = embeddings.to(embeddings.dtype)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)

return embeddings
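
The restored code path above drops the interpolate_pos_encoding flag: interpolate_position_embeddings reshapes the learned patch position embeddings into their square pre-training grid, resizes them with bicubic interpolation, and forward now calls it automatically whenever the patch count no longer matches the stored position table. A minimal standalone sketch of the same idea, with hypothetical names, assuming a square pre-training grid and a (1 + num_patches, hidden_size) embedding table whose row 0 is the class token:

import torch
import torch.nn as nn

def interpolate_square_pos_embed(pos_embed: torch.Tensor, new_size: tuple) -> torch.Tensor:
    # pos_embed: (1 + num_patches, hidden_size); row 0 is the class-token embedding
    class_embed, patch_embed = pos_embed[:1], pos_embed[1:]
    num_patches, dim = patch_embed.shape
    grid = int(num_patches**0.5)  # assumes the pre-training patch grid was square
    # (num_patches, dim) -> (1, dim, grid, grid) so 2D bicubic interpolation applies
    patch_embed = patch_embed.T.view(1, dim, grid, grid)
    patch_embed = nn.functional.interpolate(patch_embed, size=new_size, mode="bicubic", align_corners=False)
    # flatten back to (new_h * new_w, dim) and prepend the untouched class-token row
    patch_embed = patch_embed.squeeze(0).view(dim, new_size[0] * new_size[1]).T
    return torch.cat([class_embed, patch_embed])

As in the DINO-style recipe the removed docstring pointed to, only the patch grid is resized; the class-token row is passed through unchanged.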


@@ -535,8 +512,6 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -574,8 +549,6 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -845,14 +818,12 @@ def __init__(self, config: CLIPSegVisionConfig):

@add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -867,7 +838,7 @@ def forward(
if pixel_values is None:
raise ValueError("You have to specify pixel_values")

hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)

encoder_outputs = self.encoder(
@@ -912,7 +883,6 @@ def forward(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
@@ -941,7 +911,6 @@ def forward(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1035,7 +1004,6 @@ def get_image_features(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
@@ -1071,7 +1039,6 @@ def get_image_features(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1091,7 +1058,6 @@ def forward(
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CLIPSegOutput]:
r"""
@@ -1129,7 +1095,6 @@ def forward(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1397,7 +1362,6 @@ def forward(
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CLIPSegOutput]:
r"""
@@ -1437,7 +1401,6 @@ def forward(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
pooled_output = self.clip.visual_projection(vision_outputs[1])
43 changes: 3 additions & 40 deletions tests/models/clipseg/test_modeling_clipseg.py
@@ -796,15 +796,15 @@ def test_inference_image_segmentation(self):

# forward pass
with torch.no_grad():
outputs = model(**inputs, interpolate_pos_encoding=True)
outputs = model(**inputs)

# verify the predicted masks
self.assertEqual(
outputs.logits.shape,
torch.Size((3, 352, 352)),
)
expected_masks_slice = torch.tensor(
[[-7.4613, -7.4785, -7.3627], [-7.3268, -7.0898, -7.1333], [-6.9838, -6.7900, -6.8913]]
[[-7.4613, -7.4785, -7.3628], [-7.3268, -7.0899, -7.1333], [-6.9838, -6.7900, -6.8913]]
).to(torch_device)

self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3))
@@ -813,41 +813,4 @@ def test_inference_image_segmentation(self):
expected_conditional = torch.tensor([0.5601, -0.0314, 0.1980]).to(torch_device)
expected_pooled_output = torch.tensor([0.5036, -0.2681, -0.2644]).to(torch_device)
self.assertTrue(torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3))
self.assertTrue(torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3))

@slow
def test_inference_interpolate_pos_encoding(self):
# ViT models have an `interpolate_pos_encoding` argument in their forward method,
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions. The DINO model by Facebook AI leverages this
# to visualize self-attention on higher resolution images.
model = CLIPSegModel.from_pretrained("openai/clip-vit-base-patch32").to(torch_device)

processor = CLIPSegProcessor.from_pretrained(
"openai/clip-vit-base-patch32", size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180}
)

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)

# interpolate_pos_encoding=False should raise a ValueError
with self.assertRaises(ValueError, msg="doesn't match model"):
with torch.no_grad():
model(**inputs, interpolate_pos_encoding=False)

# forward pass
with torch.no_grad():
outputs = model(**inputs, interpolate_pos_encoding=True)

# verify the logits
expected_shape = torch.Size((1, 26, 768))

self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)

expected_slice = torch.tensor(
[[-0.1538, 0.0322, -0.3235], [0.2893, 0.1135, -0.5708], [0.0461, 0.1540, -0.6018]]
).to(torch_device)

self.assertTrue(
torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)
)
self.assertTrue(torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3))
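
With the flag removed, the dedicated interpolate_pos_encoding test above is dropped as well: callers simply hand the model whatever resolution the processor produces, and the embeddings layer resizes the position table on its own when the patch count differs from pre-training. A rough usage sketch of the new calling convention (the checkpoint name and prompt are illustrative, not taken from this diff):

import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# illustrative checkpoint; any CLIPSeg checkpoint served by the updated modeling code should behave the same
checkpoint = "CIDAS/clipseg-rd64-refined"
model = CLIPSegForImageSegmentation.from_pretrained(checkpoint)
processor = CLIPSegProcessor.from_pretrained(checkpoint)

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = processor(text=["a cat"], images=image, return_tensors="pt")

with torch.no_grad():
    # no interpolate_pos_encoding argument anymore; if the processor's image size differs
    # from the pre-training resolution, the position embeddings are interpolated internally
    outputs = model(**inputs)
print(outputs.logits.shape)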
