From 5367e8b0639ea25598db22b9e404336a60aa6b47 Mon Sep 17 00:00:00 2001
From: Phillip Kuznetsov
Date: Tue, 22 Oct 2024 15:01:03 -0700
Subject: [PATCH 1/9] fix(Mask2Former): torch export

Signed-off-by: Phillip Kuznetsov
---
 .../mask2former/modeling_mask2former.py | 33 ++++++++++++++-----
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py
index f4aea415adf5e6..fd097b2eca0a2e 100644
--- a/src/transformers/models/mask2former/modeling_mask2former.py
+++ b/src/transformers/models/mask2former/modeling_mask2former.py
@@ -927,6 +927,7 @@ def forward(
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -936,7 +937,8 @@ def forward(
 
         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
-        if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+        total_elements = sum(height * width for height, width in spatial_shapes_list)
+        if total_elements != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
             )
@@ -970,7 +972,7 @@ def forward(
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
 
-        output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+        output = multi_scale_deformable_attention(value, spatial_shapes_list, sampling_locations, attention_weights)
         output = self.output_proj(output)
 
         return output, attention_weights
@@ -1002,6 +1004,7 @@ def forward(
         position_embeddings: torch.Tensor = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -1017,6 +1020,8 @@ def forward(
                 Reference points.
             spatial_shapes (`torch.LongTensor`, *optional*):
                 Spatial shapes of the backbone feature maps.
+            spatial_shapes_list (`list` of `tuple`):
+                Spatial shapes of the backbone feature maps as a list of tuples.
             level_start_index (`torch.LongTensor`, *optional*):
                 Level start index.
             output_attentions (`bool`, *optional*):
@@ -1034,6 +1039,7 @@ def forward(
             position_embeddings=position_embeddings,
             reference_points=reference_points,
             spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             output_attentions=output_attentions,
         )
@@ -1123,6 +1129,7 @@ def forward(
         attention_mask=None,
         position_embeddings=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
         output_attentions=None,
@@ -1142,6 +1149,8 @@ def forward(
                 Position embeddings that are added to the queries and keys in each self-attention layer.
             spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                 Spatial shapes of each feature map.
+            spatial_shapes_list (`list` of `tuple`):
+                Spatial shapes of each feature map as a list of tuples.
             level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
                 Starting index of each feature map.
             valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
@@ -1162,7 +1171,7 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         hidden_states = inputs_embeds
-        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
+        reference_points = self.get_reference_points(spatial_shapes_list, valid_ratios, device=inputs_embeds.device)
 
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -1177,6 +1186,7 @@ def forward(
                 position_embeddings=position_embeddings,
                 reference_points=reference_points,
                 spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 output_attentions=output_attentions,
             )
@@ -1302,16 +1312,18 @@ def forward(
         ]
 
         # Prepare encoder inputs (by flattening)
-        spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
+        spatial_shapes_list = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
         input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1)
-        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device)
+        spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=input_embeds_flat.device)
         masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1)
 
         position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings]
         level_pos_embed_flat = [x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings)]
         level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1)
 
-        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+        level_start_index = [0]
+        for h, w in spatial_shapes_list[:-1]:
+            level_start_index.append(level_start_index[-1] + h * w)
         valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1)
 
         # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
@@ -1321,6 +1333,7 @@ def forward(
                 attention_mask=masks_flat,
                 position_embeddings=level_pos_embed_flat,
                 spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 valid_ratios=valid_ratios,
                 output_attentions=output_attentions,
@@ -1338,11 +1351,11 @@ def forward(
             else:
                 split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
 
-        encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
+        encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
 
         # Compute final features
         outputs = [
-            x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1])
+            x.transpose(1, 2).view(batch_size, -1, spatial_shapes_list[i][0], spatial_shapes_list[i][1])
             for i, x in enumerate(encoder_output)
         ]
 
@@ -1876,7 +1889,9 @@ def forward(
             else:
                 level_index = idx % self.num_feature_levels
 
-                attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
+                where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
+                # Multiply the attention mask instead of indexing to avoid issue in torch.export.
+                attention_mask = attention_mask * where.unsqueeze(-1)
 
                 layer_outputs = decoder_layer(
                     hidden_states,

From 01bfc4129fc411b3b8992e9692634569a6c40075 Mon Sep 17 00:00:00 2001
From: Phillip Kuznetsov
Date: Fri, 25 Oct 2024 06:11:38 -0700
Subject: [PATCH 2/9] revert level_start_index and create a level_start_index_list

Signed-off-by: Phillip Kuznetsov
---
 .../models/mask2former/modeling_mask2former.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py
index fd097b2eca0a2e..e05201b30bdc2b 100644
--- a/src/transformers/models/mask2former/modeling_mask2former.py
+++ b/src/transformers/models/mask2former/modeling_mask2former.py
@@ -1321,9 +1321,7 @@ def forward(
         level_pos_embed_flat = [x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings)]
         level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1)
 
-        level_start_index = [0]
-        for h, w in spatial_shapes_list[:-1]:
-            level_start_index.append(level_start_index[-1] + h * w)
+        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
         valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1)
 
         # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
@@ -1344,12 +1342,15 @@ def forward(
         last_hidden_state = encoder_outputs.last_hidden_state
         batch_size = last_hidden_state.shape[0]
 
+        level_start_index_list = [0]
+        for h, w in spatial_shapes_list[:-1]:
+            level_start_index_list.append(level_start_index_list[-1] + h * w)
         split_sizes = [None] * self.num_feature_levels
         for i in range(self.num_feature_levels):
             if i < self.num_feature_levels - 1:
-                split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
+                split_sizes[i] = level_start_index_list[i + 1] - level_start_index_list[i]
             else:
-                split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
+                split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i]
 
         encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
 

From a3f8377edaa6187a68fa5c95547fba6ab962aff3 Mon Sep 17 00:00:00 2001
From: Phillip Kuznetsov
Date: Mon, 28 Oct 2024 11:52:00 -0700
Subject: [PATCH 3/9] Add a comment to explain the level_start_index_list

Signed-off-by: Phillip Kuznetsov
---
 src/transformers/models/mask2former/modeling_mask2former.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py
index e05201b30bdc2b..384ce3a369c494 100644
--- a/src/transformers/models/mask2former/modeling_mask2former.py
+++ b/src/transformers/models/mask2former/modeling_mask2former.py
@@ -1342,6 +1342,8 @@ def forward(
         last_hidden_state = encoder_outputs.last_hidden_state
         batch_size = last_hidden_state.shape[0]
 
+        # We compute level_start_index_list separately from the tensor version level_start_index
+        # to avoid iterating over a tensor which breaks torch.compile/export.
         level_start_index_list = [0]
         for h, w in spatial_shapes_list[:-1]:
             level_start_index_list.append(level_start_index_list[-1] + h * w)

From 183d1fc4ad43491d66fd0f9e4df87e7f14d17177 Mon Sep 17 00:00:00 2001
From: Phillip Kuznetsov
Date: Tue, 29 Oct 2024 10:37:16 -0700
Subject: [PATCH 4/9] Address comment

Signed-off-by: Phillip Kuznetsov
---
 src/transformers/models/mask2former/modeling_mask2former.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py
index 384ce3a369c494..a205712960dd3d 100644
--- a/src/transformers/models/mask2former/modeling_mask2former.py
+++ b/src/transformers/models/mask2former/modeling_mask2former.py
@@ -1345,8 +1345,8 @@ def forward(
         # We compute level_start_index_list separately from the tensor version level_start_index
         # to avoid iterating over a tensor which breaks torch.compile/export.
         level_start_index_list = [0]
-        for h, w in spatial_shapes_list[:-1]:
-            level_start_index_list.append(level_start_index_list[-1] + h * w)
+        for height, width in spatial_shapes_list[:-1]:
+            level_start_index_list.append(level_start_index_list[-1] + height * width)
         split_sizes = [None] * self.num_feature_levels
         for i in range(self.num_feature_levels):
             if i < self.num_feature_levels - 1:

From 609de2609a47a6b8eaaff823e1b90c641297a59d Mon Sep 17 00:00:00 2001
From: Phillip Kuznetsov
Date: Tue, 29 Oct 2024 10:55:33 -0700
Subject: [PATCH 5/9] add torch.export.export test

Signed-off-by: Phillip Kuznetsov
---
 .../mask2former/test_modeling_mask2former.py | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py
index ba78cf9ce3f7d6..c91fb9d07f57f8 100644
--- a/tests/models/mask2former/test_modeling_mask2former.py
+++ b/tests/models/mask2former/test_modeling_mask2former.py
@@ -17,6 +17,7 @@
 import unittest
 
 import numpy as np
+from packaging import version
 
 from tests.test_modeling_common import floats_tensor
 from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
@@ -481,3 +482,29 @@ def test_with_segmentation_maps_and_loss(self):
             outputs = model(**inputs)
 
         self.assertTrue(outputs.loss is not None)
+
+    @slow
+    def test_export(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+        model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(image, return_tensors="pt").to(torch_device)
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["pixel_values"], inputs["pixel_mask"]),
+            strict=True,
+        )
+        with torch.no_grad():
+            eager_outputs = model(**inputs)
+            exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["pixel_mask"])
+        self.assertEqual(eager_outputs.masks_queries_logits.shape, exported_outputs.masks_queries_logits.shape)
+        self.assertTrue(
+            torch.allclose(eager_outputs.masks_queries_logits, exported_outputs.masks_queries_logits, atol=TOLERANCE)
+        )
+        self.assertEqual(eager_outputs.class_queries_logits.shape, exported_outputs.class_queries_logits.shape)
+        self.assertTrue(
+            torch.allclose(eager_outputs.class_queries_logits, exported_outputs.class_queries_logits, atol=TOLERANCE)
+        )

From 99e90b2f36d9bcce225977273fd0a637659f96ae Mon Sep 17 00:00:00 2001
From: Phillip Kuznetsov
Date: Tue, 29 Oct 2024 11:51:19 -0700
Subject: [PATCH 6/9] rename arg

Signed-off-by: Phillip Kuznetsov
---
 .../models/mask2former/modeling_mask2former.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py
index a205712960dd3d..5d1c3f5f142353 100644
--- a/src/transformers/models/mask2former/modeling_mask2former.py
+++ b/src/transformers/models/mask2former/modeling_mask2former.py
@@ -1092,13 +1092,13 @@ def __init__(self, config: Mask2FormerConfig):
         )
 
     @staticmethod
-    def get_reference_points(spatial_shapes, valid_ratios, device):
+    def get_reference_points(spatial_shapes_list, valid_ratios, device):
         """
         Get reference points for each feature map. Used in decoder.
 
         Args:
-            spatial_shapes (`torch.LongTensor`):
-                Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`.
+            spatial_shapes_list (`list` of `tuple`):
+                Spatial shapes of the backbone feature maps as a list of tuples.
             valid_ratios (`torch.FloatTensor`):
                 Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`.
             device (`torch.device`):
@@ -1107,7 +1107,7 @@ def get_reference_points(spatial_shapes, valid_ratios, device):
             `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
         """
         reference_points_list = []
-        for lvl, (height, width) in enumerate(spatial_shapes):
+        for lvl, (height, width) in enumerate(spatial_shapes_list):
             ref_y, ref_x = torch.meshgrid(
                 torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
                 torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),

From 104b16d0cf4848132fe903422c58ff4077bb06e0 Mon Sep 17 00:00:00 2001
From: Phillip Kuznetsov
Date: Tue, 29 Oct 2024 13:15:19 -0700
Subject: [PATCH 7/9] remove spatial_shapes

Signed-off-by: Phillip Kuznetsov
---
 .../models/mask2former/modeling_mask2former.py | 16 +++++-----------
 .../mask2former/test_modeling_mask2former.py   |  1 -
 2 files changed, 5 insertions(+), 12 deletions(-)

diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py
index 5d1c3f5f142353..4cc96b1652dbf9 100644
--- a/src/transformers/models/mask2former/modeling_mask2former.py
+++ b/src/transformers/models/mask2former/modeling_mask2former.py
@@ -926,7 +926,6 @@ def forward(
         encoder_attention_mask=None,
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
-        spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
@@ -959,7 +958,11 @@ def forward(
         )
         # batch_size, num_queries, n_heads, n_levels, n_points, 2
         if reference_points.shape[-1] == 2:
-            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+            offset_normalizer = torch.tensor(
+                [[shape[1], shape[0]] for shape in spatial_shapes_list],
+                dtype=torch.long,
+                device=reference_points.device,
+            )
             sampling_locations = (
                 reference_points[:, :, None, :, None, :]
                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
@@ -1003,7 +1006,6 @@ def forward(
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor = None,
         reference_points=None,
-        spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
@@ -1018,8 +1020,6 @@ def forward(
                 Position embeddings, to be added to `hidden_states`.
             reference_points (`torch.FloatTensor`, *optional*):
                 Reference points.
-            spatial_shapes (`torch.LongTensor`, *optional*):
-                Spatial shapes of the backbone feature maps.
             spatial_shapes_list (`list` of `tuple`):
                 Spatial shapes of the backbone feature maps as a list of tuples.
             level_start_index (`torch.LongTensor`, *optional*):
@@ -1038,7 +1038,6 @@ def forward(
             encoder_attention_mask=attention_mask,
             position_embeddings=position_embeddings,
             reference_points=reference_points,
-            spatial_shapes=spatial_shapes,
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             output_attentions=output_attentions,
@@ -1128,7 +1127,6 @@ def forward(
         inputs_embeds=None,
         attention_mask=None,
         position_embeddings=None,
-        spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
@@ -1147,8 +1145,6 @@ def forward(
                 [What are attention masks?](../glossary#attention-mask)
             position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Position embeddings that are added to the queries and keys in each self-attention layer.
-            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
-                Spatial shapes of each feature map.
             spatial_shapes_list (`list` of `tuple`):
                 Spatial shapes of each feature map as a list of tuples.
             level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
@@ -1185,7 +1181,6 @@ def forward(
                 attention_mask,
                 position_embeddings=position_embeddings,
                 reference_points=reference_points,
-                spatial_shapes=spatial_shapes,
                 spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 output_attentions=output_attentions,
@@ -1330,7 +1325,6 @@ def forward(
                 inputs_embeds=input_embeds_flat,
                 attention_mask=masks_flat,
                 position_embeddings=level_pos_embed_flat,
-                spatial_shapes=spatial_shapes,
                 spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 valid_ratios=valid_ratios,
diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py
index c91fb9d07f57f8..986c150b926a94 100644
--- a/tests/models/mask2former/test_modeling_mask2former.py
+++ b/tests/models/mask2former/test_modeling_mask2former.py
@@ -483,7 +483,6 @@ def test_with_segmentation_maps_and_loss(self):
 
         self.assertTrue(outputs.loss is not None)
 
-    @slow
     def test_export(self):
         if version.parse(torch.__version__) < version.parse("2.4.0"):
             self.skipTest(reason="This test requires torch >= 2.4 to run.")

From a341d60ad0cbe1393646fa0a8de75a75edebae1a Mon Sep 17 00:00:00 2001
From: Phillip Kuznetsov
Date: Tue, 29 Oct 2024 14:49:05 -0700
Subject: [PATCH 8/9] Use the version check from pytorch_utils

Signed-off-by: Phillip Kuznetsov
---
 tests/models/mask2former/test_modeling_mask2former.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py
index 986c150b926a94..a3caefe14ab501 100644
--- a/tests/models/mask2former/test_modeling_mask2former.py
+++ b/tests/models/mask2former/test_modeling_mask2former.py
@@ -17,10 +17,10 @@
 import unittest
 
 import numpy as np
-from packaging import version
 
 from tests.test_modeling_common import floats_tensor
 from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import (
     require_timm,
     require_torch,
@@ -484,7 +484,7 @@ def test_with_segmentation_maps_and_loss(self):
         self.assertTrue(outputs.loss is not None)
 
     def test_export(self):
-        if version.parse(torch.__version__) < version.parse("2.4.0"):
+        if not is_torch_greater_or_equal_than_2_4:
             self.skipTest(reason="This test requires torch >= 2.4 to run.")
         model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
         image_processor = self.default_image_processor

From 704b79f9cd2414048fbd3bb27f1a3dcdc89f3090 Mon Sep 17 00:00:00 2001
From: Phillip Kuznetsov
Date: Tue, 29 Oct 2024 14:54:50 -0700
Subject: [PATCH 9/9] [run_slow] mask2former

Signed-off-by: Phillip Kuznetsov
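
Taken together, the series replaces two patterns that torch.export cannot trace through: data-dependent boolean indexing and iteration over a tensor of per-level shapes. The attention-mask rewrite from PATCH 1 can be sketched in isolation as follows; this is a minimal sketch with an illustrative function name, not the in-place code of Mask2FormerMaskedAttentionDecoder:

    import torch

    def reset_fully_masked_rows(attention_mask: torch.Tensor) -> torch.Tensor:
        # Rows that are masked everywhere (all True) would block attention entirely.
        # The old code reset them with data-dependent boolean indexing:
        #     attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
        # The export-friendly version expresses the same reset as a multiplication.
        keep = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
        return attention_mask * keep.unsqueeze(-1)

    mask = torch.zeros(2, 3, 4, dtype=torch.bool)
    mask[0, 1] = True  # one fully masked row; it comes back all False
    print(reset_fully_masked_rows(mask))

Both versions leave partially masked rows untouched and clear only the rows whose mask sums to the full row length, so the decoder behaves the same while the graph stays traceable.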
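
The spatial_shapes_list and level_start_index_list changes follow the same principle: per-level sizes are carried as plain Python integers, so torch.split and the sequence-length check receive static values instead of tensor items. A standalone sketch of that bookkeeping, with made-up example shapes (the patch derives the split sizes from consecutive level_start_index_list entries, which equals the height * width products shown here):

    import torch

    spatial_shapes_list = [(32, 32), (16, 16), (8, 8)]  # (height, width) per feature level

    # Start offset of each level in the flattened sequence, computed in Python
    # rather than via spatial_shapes.prod(1).cumsum(0), so no tensor is iterated.
    level_start_index_list = [0]
    for height, width in spatial_shapes_list[:-1]:
        level_start_index_list.append(level_start_index_list[-1] + height * width)

    sequence_length = sum(height * width for height, width in spatial_shapes_list)
    hidden_states = torch.randn(2, sequence_length, 256)  # (batch, sum of H*W, channels)

    # Split sizes are plain ints, so torch.split no longer needs size.item() calls.
    split_sizes = [height * width for height, width in spatial_shapes_list]
    per_level = torch.split(hidden_states, split_sizes, dim=1)
    print(level_start_index_list, [tuple(t.shape) for t in per_level])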
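
With the list-based bookkeeping in place, the new test_export exercises torch.export.export end to end. The same flow outside the test harness looks roughly like the sketch below; it assumes torch >= 2.4, the public facebook/mask2former-swin-small-coco-instance checkpoint, and the standard COCO sample image used throughout the Transformers documentation:

    import requests
    import torch
    from PIL import Image

    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

    checkpoint = "facebook/mask2former-swin-small-coco-instance"
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint).eval()

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = image_processor(image, return_tensors="pt")

    # Export with fixed example inputs, mirroring the added test.
    exported_program = torch.export.export(
        model,
        args=(inputs["pixel_values"], inputs["pixel_mask"]),
        strict=True,
    )

    with torch.no_grad():
        eager_outputs = model(**inputs)
        exported_outputs = exported_program.module()(inputs["pixel_values"], inputs["pixel_mask"])

    print(torch.allclose(eager_outputs.masks_queries_logits, exported_outputs.masks_queries_logits, atol=1e-4))

The exported program can then be saved with torch.export.save or lowered further, which is the use case this series unblocks.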