diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py
index fd097b2eca0a2e..e05201b30bdc2b 100644
--- a/src/transformers/models/mask2former/modeling_mask2former.py
+++ b/src/transformers/models/mask2former/modeling_mask2former.py
@@ -1321,9 +1321,7 @@ def forward(
         level_pos_embed_flat = [x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings)]
         level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1)
 
-        level_start_index = [0]
-        for h, w in spatial_shapes_list[:-1]:
-            level_start_index.append(level_start_index[-1] + h * w)
+        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
         valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1)
 
         # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
@@ -1344,12 +1342,15 @@ def forward(
         last_hidden_state = encoder_outputs.last_hidden_state
         batch_size = last_hidden_state.shape[0]
 
+        level_start_index_list = [0]
+        for h, w in spatial_shapes_list[:-1]:
+            level_start_index_list.append(level_start_index_list[-1] + h * w)
         split_sizes = [None] * self.num_feature_levels
         for i in range(self.num_feature_levels):
             if i < self.num_feature_levels - 1:
-                split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
+                split_sizes[i] = level_start_index_list[i + 1] - level_start_index_list[i]
             else:
-                split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
+                split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i]
 
         encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
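
For context, here is a minimal standalone sketch (not part of the patch) showing that the tensor-based `level_start_index` from the first hunk and the loop-based `level_start_index_list` from the second hunk compute the same quantity, and how the latter yields the plain-Python `split_sizes` that `torch.split` expects. The feature-map shapes are invented for illustration; `spatial_shapes_list` and `spatial_shapes` mirror the variables of the same names in `Mask2FormerPixelDecoder.forward`.

```python
# Minimal sketch, assuming made-up (height, width) pairs per feature level.
import torch

spatial_shapes_list = [(64, 64), (32, 32), (16, 16), (8, 8)]
spatial_shapes = torch.tensor(spatial_shapes_list, dtype=torch.long)

# Tensor path (first hunk): h * w per level, running sum, leading zero.
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)

# List path (second hunk): the same start indices built from plain Python ints.
level_start_index_list = [0]
for h, w in spatial_shapes_list[:-1]:
    level_start_index_list.append(level_start_index_list[-1] + h * w)

assert level_start_index.tolist() == level_start_index_list  # [0, 4096, 5120, 5376]

# split_sizes derived from the list version are the ints torch.split expects.
sequence_length = sum(h * w for h, w in spatial_shapes_list)  # 5440 flattened tokens
split_sizes = [
    level_start_index_list[i + 1] - level_start_index_list[i]
    if i < len(spatial_shapes_list) - 1
    else sequence_length - level_start_index_list[i]
    for i in range(len(spatial_shapes_list))
]
assert split_sizes == [4096, 1024, 256, 64]
```

Keeping both forms presumably lets the pre-encoder computation stay a pure tensor operation (no Python iteration over a tensor, which is friendlier to `torch.compile`/`torch.export`), while `torch.split` still receives the list of Python ints it requires for its section sizes.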