🚨🚨🚨 fix(Mask2Former): torch export 🚨🚨🚨 #34393

Merged: 9 commits, Nov 19, 2024
62 changes: 37 additions & 25 deletions src/transformers/models/mask2former/modeling_mask2former.py
@@ -926,7 +926,7 @@ def forward(
encoder_attention_mask=None,
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
Collaborator:
Changing the name is a breaking change; we should probably have a deprecation cycle, no?

Collaborator:
If not, we can add 🔴, as I think the motivation is strong enough.

Member:
Hmm, I suppose it is an internal module of the model; I'm not sure it is intended to be used elsewhere. Let me know if I'm wrong.

Contributor (author):
Isn't changing the logic, while keeping both spatial_shapes and spatial_shapes_list, already a breaking change?

Sure, there's some question of whether users can rely on internal modules of transformers models, but if a user doesn't pass a value for spatial_shapes_list, this code will fail: L939 will try to iterate over a None object.

BTW, it seems like the changes in #33600 already violate this contract (see this line). I followed that PR as a guide for what to change here.

It seems like the proper way forward is to add a 🔴 here.

Contributor (author):
What's the path forward, @ArthurZucker?

Contributor (author):
Bumping this again after the weekend. I'm not familiar with Hugging Face's policies on:

  1. What qualifies as a breaking change?
  2. What the release process is?

Could you provide a recommendation on where we should take this PR?

Collaborator:
Usually, even if a module is internal, people still end up using it 😅
This IS a breaking change, but it's acceptable IMO. Let's use some 🚨 on the PR title to make sure we communicate about it in the release!

level_start_index=None,
output_attentions: bool = False,
):
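
Regarding the deprecation-cycle question in the thread above: a minimal sketch, assuming the FutureWarning pattern used elsewhere in transformers, of how the renamed argument could be kept backward compatible for one release. The helper name _resolve_spatial_shapes is hypothetical and not part of this PR.

```python
# Hypothetical deprecation shim (not part of this PR) sketching how the legacy
# `spatial_shapes` tensor argument could keep working for one release while
# `spatial_shapes_list` becomes the canonical argument.
import warnings
from typing import List, Optional, Tuple

import torch


def _resolve_spatial_shapes(
    spatial_shapes: Optional[torch.Tensor] = None,
    spatial_shapes_list: Optional[List[Tuple[int, int]]] = None,
) -> List[Tuple[int, int]]:
    """Accept either the legacy tensor or the new list of (height, width) tuples."""
    if spatial_shapes_list is not None:
        return spatial_shapes_list
    if spatial_shapes is not None:
        warnings.warn(
            "`spatial_shapes` is deprecated in favor of `spatial_shapes_list` and will be "
            "removed in a future release; pass a list of (height, width) tuples instead.",
            FutureWarning,
        )
        # Convert the (num_feature_levels, 2) tensor to the list form the new code path expects.
        return [(int(height), int(width)) for height, width in spatial_shapes.tolist()]
    raise ValueError("Either `spatial_shapes` or `spatial_shapes_list` must be provided.")
```

Note that the tensor-to-list conversion itself is not export-friendly, so exported callers would still need to pass spatial_shapes_list directly; the shim would only ease the transition for eager users.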
@@ -936,7 +936,8 @@ def forward(

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
total_elements = sum(height * width for height, width in spatial_shapes_list)
if total_elements != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
@@ -957,7 +958,11 @@ def forward(
)
# batch_size, num_queries, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
offset_normalizer = torch.tensor(
[[shape[1], shape[0]] for shape in spatial_shapes_list],
dtype=torch.long,
device=reference_points.device,
)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
@@ -970,7 +975,7 @@ def forward(
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = multi_scale_deformable_attention(value, spatial_shapes_list, sampling_locations, attention_weights)
output = self.output_proj(output)

return output, attention_weights
@@ -1001,7 +1006,7 @@ def forward(
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
@@ -1015,8 +1020,8 @@ def forward(
Position embeddings, to be added to `hidden_states`.
reference_points (`torch.FloatTensor`, *optional*):
Reference points.
spatial_shapes (`torch.LongTensor`, *optional*):
Spatial shapes of the backbone feature maps.
spatial_shapes_list (`list` of `tuple`):
Spatial shapes of the backbone feature maps as a list of tuples.
level_start_index (`torch.LongTensor`, *optional*):
Level start index.
output_attentions (`bool`, *optional*):
@@ -1033,7 +1038,7 @@ def forward(
encoder_attention_mask=attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -1086,13 +1091,13 @@ def __init__(self, config: Mask2FormerConfig):
)

@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
def get_reference_points(spatial_shapes_list, valid_ratios, device):
"""
Get reference points for each feature map. Used in decoder.

Args:
spatial_shapes (`torch.LongTensor`):
Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`.
spatial_shapes_list (`list` of `tuple`):
Spatial shapes of the backbone feature maps as a list of tuples.
valid_ratios (`torch.FloatTensor`):
Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`.
device (`torch.device`):
@@ -1101,7 +1106,7 @@ def get_reference_points(spatial_shapes, valid_ratios, device):
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
"""
reference_points_list = []
for lvl, (height, width) in enumerate(spatial_shapes):
for lvl, (height, width) in enumerate(spatial_shapes_list):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
@@ -1122,7 +1127,7 @@ def forward(
inputs_embeds=None,
attention_mask=None,
position_embeddings=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
@@ -1140,8 +1145,8 @@ def forward(
[What are attention masks?](../glossary#attention-mask)
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Position embeddings that are added to the queries and keys in each self-attention layer.
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
spatial_shapes_list (`list` of `tuple`):
Spatial shapes of each feature map as a list of tuples.
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
Starting index of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
@@ -1162,7 +1167,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

hidden_states = inputs_embeds
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
reference_points = self.get_reference_points(spatial_shapes_list, valid_ratios, device=inputs_embeds.device)

all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -1176,7 +1181,7 @@ def forward(
attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -1302,9 +1307,9 @@ def forward(
]

# Prepare encoder inputs (by flattening)
spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
spatial_shapes_list = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device)
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=input_embeds_flat.device)
masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1)

position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings]
@@ -1320,7 +1325,7 @@ def forward(
inputs_embeds=input_embeds_flat,
attention_mask=masks_flat,
position_embeddings=level_pos_embed_flat,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
output_attentions=output_attentions,
@@ -1331,18 +1336,23 @@ def forward(
last_hidden_state = encoder_outputs.last_hidden_state
batch_size = last_hidden_state.shape[0]

# We compute level_start_index_list separately from the tensor version level_start_index
# to avoid iterating over a tensor which breaks torch.compile/export.
level_start_index_list = [0]
for height, width in spatial_shapes_list[:-1]:
level_start_index_list.append(level_start_index_list[-1] + height * width)
split_sizes = [None] * self.num_feature_levels
for i in range(self.num_feature_levels):
if i < self.num_feature_levels - 1:
split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
split_sizes[i] = level_start_index_list[i + 1] - level_start_index_list[i]
else:
split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i]

encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)

# Compute final features
outputs = [
x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1])
x.transpose(1, 2).view(batch_size, -1, spatial_shapes_list[i][0], spatial_shapes_list[i][1])
for i, x in enumerate(encoder_output)
]

@@ -1876,7 +1886,9 @@ def forward(
else:
level_index = idx % self.num_feature_levels

attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
# Multiply the attention mask instead of indexing to avoid issue in torch.export.
attention_mask = attention_mask * where.unsqueeze(-1)

layer_outputs = decoder_layer(
hidden_states,
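
Two of the rewrites above, computing level_start_index_list as a Python list and replacing the boolean-index assignment on the attention mask with a multiplication, exist purely to keep the graph traceable. A minimal standalone sketch, with assumed shapes rather than the real Mask2Former tensors:

```python
# Standalone illustration (assumed shapes, not the real Mask2Former tensors) of the
# two export-friendly rewrites used in the diff above.
import torch


def mask_out_fully_true_rows(attention_mask: torch.Tensor) -> torch.Tensor:
    # The old code did:
    #   attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
    # torch.where produces a data-dependent number of indices, which strict
    # torch.export cannot trace. Multiplying by a {0, 1} mask keeps the graph static.
    keep = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
    return attention_mask * keep.unsqueeze(-1)


def split_by_level(hidden_state: torch.Tensor, spatial_shapes_list):
    # Deriving the per-level sizes from a Python list of (height, width) tuples keeps
    # them as plain ints, so torch.split sees static sizes instead of 0-d tensors that
    # would need .item() calls (a pattern that breaks torch.export).
    split_sizes = [height * width for height, width in spatial_shapes_list]
    return torch.split(hidden_state, split_sizes, dim=1)


# Dummy usage
mask = torch.ones(2, 8, 100)
print(mask_out_fully_true_rows(mask).abs().sum())  # rows that attend everywhere are zeroed out -> 0

levels = [(8, 8), (4, 4)]
sequence = torch.randn(2, 8 * 8 + 4 * 4, 32)
print([chunk.shape for chunk in split_by_level(sequence, levels)])
```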
26 changes: 26 additions & 0 deletions tests/models/mask2former/test_modeling_mask2former.py
@@ -20,6 +20,7 @@

from tests.test_modeling_common import floats_tensor
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from transformers.testing_utils import (
require_timm,
require_torch,
@@ -481,3 +482,28 @@ def test_with_segmentation_maps_and_loss(self):
outputs = model(**inputs)

self.assertTrue(outputs.loss is not None)

def test_export(self):
if not is_torch_greater_or_equal_than_2_4:
self.skipTest(reason="This test requires torch >= 2.4 to run.")
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(image, return_tensors="pt").to(torch_device)

exported_program = torch.export.export(
model,
args=(inputs["pixel_values"], inputs["pixel_mask"]),
strict=True,
)
with torch.no_grad():
eager_outputs = model(**inputs)
exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["pixel_mask"])
self.assertEqual(eager_outputs.masks_queries_logits.shape, exported_outputs.masks_queries_logits.shape)
self.assertTrue(
torch.allclose(eager_outputs.masks_queries_logits, exported_outputs.masks_queries_logits, atol=TOLERANCE)
)
self.assertEqual(eager_outputs.class_queries_logits.shape, exported_outputs.class_queries_logits.shape)
self.assertTrue(
torch.allclose(eager_outputs.class_queries_logits, exported_outputs.class_queries_logits, atol=TOLERANCE)
)
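
For reference, a hypothetical end-user snippet mirroring what the new test exercises; the checkpoint name and input resolution are assumptions, and torch >= 2.4 is required per the skip guard above.

```python
# Hypothetical usage sketch; the checkpoint name and input resolution are assumptions.
import torch
from transformers import Mask2FormerForUniversalSegmentation

model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-tiny-coco-instance"  # assumed checkpoint
).eval()

pixel_values = torch.randn(1, 3, 384, 384)  # dummy image batch

exported = torch.export.export(model, args=(pixel_values,), strict=True)
outputs = exported.module()(pixel_values)
print(outputs.masks_queries_logits.shape, outputs.class_queries_logits.shape)
```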