fix(Mask2Former): torch export
Signed-off-by: Phillip Kuznetsov <[email protected]>
philkuz committed Oct 24, 2024
1 parent f51ac9e commit c188015
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions src/transformers/models/mask2former/modeling_mask2former.py
@@ -927,6 +927,7 @@ def forward(
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
+ spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
@@ -936,7 +937,8 @@ def forward(

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
- if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+ total_elements = sum(height * width for height, width in spatial_shapes_list)
+ if total_elements != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
@@ -970,7 +972,7 @@ def forward(
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

- output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+ output = multi_scale_deformable_attention(value, spatial_shapes_list, sampling_locations, attention_weights)
output = self.output_proj(output)

return output, attention_weights
@@ -1002,6 +1004,7 @@ def forward(
position_embeddings: torch.Tensor = None,
reference_points=None,
spatial_shapes=None,
+ spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
@@ -1017,6 +1020,8 @@ def forward(
Reference points.
spatial_shapes (`torch.LongTensor`, *optional*):
Spatial shapes of the backbone feature maps.
+ spatial_shapes_list (`list` of `tuple`):
+ Spatial shapes of the backbone feature maps as a list of tuples.
level_start_index (`torch.LongTensor`, *optional*):
Level start index.
output_attentions (`bool`, *optional*):
@@ -1034,6 +1039,7 @@ def forward(
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -1123,6 +1129,7 @@ def forward(
attention_mask=None,
position_embeddings=None,
spatial_shapes=None,
+ spatial_shapes_list=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
@@ -1142,6 +1149,8 @@ def forward(
Position embeddings that are added to the queries and keys in each self-attention layer.
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
+ spatial_shapes_list (`list` of `tuple`):
+ Spatial shapes of each feature map as a list of tuples.
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
Starting index of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
@@ -1162,7 +1171,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

hidden_states = inputs_embeds
- reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
+ reference_points = self.get_reference_points(spatial_shapes_list, valid_ratios, device=inputs_embeds.device)

all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -1177,6 +1186,7 @@ def forward(
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -1302,16 +1312,18 @@ def forward(
]

# Prepare encoder inputs (by flattening)
- spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
+ spatial_shapes_list = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1)
- spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device)
+ spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=input_embeds_flat.device)
masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1)

position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings]
level_pos_embed_flat = [x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings)]
level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1)

- level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ level_start_index = [0]
+ for h, w in spatial_shapes_list[:-1]:
+     level_start_index.append(level_start_index[-1] + h * w)
valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1)

# Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
@@ -1321,6 +1333,7 @@ def forward(
attention_mask=masks_flat,
position_embeddings=level_pos_embed_flat,
spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
output_attentions=output_attentions,
@@ -1338,11 +1351,11 @@ def forward(
else:
split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]

- encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
+ encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)

# Compute final features
outputs = [
- x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1])
+ x.transpose(1, 2).view(batch_size, -1, spatial_shapes_list[i][0], spatial_shapes_list[i][1])
for i, x in enumerate(encoder_output)
]

@@ -1876,7 +1889,9 @@ def forward(
else:
level_index = idx % self.num_feature_levels

- attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
+ where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
+ # Multiply the attention mask instead of indexing to avoid issue in torch.export.
+ attention_mask = attention_mask * where.unsqueeze(-1)

layer_outputs = decoder_layer(
hidden_states,
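
The core of the fix is that shape bookkeeping now happens on a plain Python list of `(height, width)` tuples (`spatial_shapes_list`) rather than on a `torch.Tensor`, so `torch.export` no longer has to trace data-dependent tensor ops such as `.prod(1).cumsum(0)` or the `.item()` calls that fed `torch.split`. A minimal sketch of the two routes, not part of the commit; the feature-map sizes are made up, and only `spatial_shapes_list` and `level_start_index` mirror names from the diff above:

import torch

spatial_shapes_list = [(32, 32), (16, 16), (8, 8)]  # (height, width) per feature level

# Tensor route (before the fix): per-level offsets via prod/cumsum on a tensor.
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long)
start_from_tensor = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

# List route (after the fix): the same offsets kept as plain Python ints.
level_start_index = [0]
for height, width in spatial_shapes_list[:-1]:
    level_start_index.append(level_start_index[-1] + height * width)

assert start_from_tensor.tolist() == level_start_index  # both are [0, 1024, 1280]

Keeping the sizes as Python ints is also what lets the later `torch.split(last_hidden_state, split_sizes, dim=1)` call in the diff take its split sizes directly, without reading them back from tensors with `.item()`.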

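The last hunk swaps boolean-index assignment for a multiply for a similar reason: writing through `torch.where(...)` touches a data-dependent number of rows, which `torch.export` typically cannot trace, while multiplying by a per-row keep mask keeps every shape static. A small self-contained sketch of the pattern; the helper name and toy mask below are illustrative, not part of the model:

import torch

def reset_fully_masked_rows(attention_mask: torch.Tensor) -> torch.Tensor:
    # Rows in which every position is masked are reset to all-False,
    # without writing through a data-dependent index.
    keep = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
    return attention_mask * keep.unsqueeze(-1)

attention_mask = torch.tensor([[True, True, True], [True, False, True]])
print(reset_fully_masked_rows(attention_mask))
# tensor([[False, False, False],
#         [ True, False,  True]])

# The export-unfriendly version this replaces:
# attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False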