Skip to content

Commit

Permalink
revert level_start_index and create a level_start_index_list
Browse files Browse the repository at this point in the history
Signed-off-by: Phillip Kuznetsov <[email protected]>
  • Loading branch information
philkuz committed Oct 25, 2024
1 parent c188015 commit cfaef32
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/transformers/models/mask2former/modeling_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,9 +1321,7 @@ def forward(
level_pos_embed_flat = [x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings)]
level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1)

level_start_index = [0]
for h, w in spatial_shapes_list[:-1]:
level_start_index.append(level_start_index[-1] + h * w)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1)

# Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
Expand All @@ -1344,12 +1342,15 @@ def forward(
last_hidden_state = encoder_outputs.last_hidden_state
batch_size = last_hidden_state.shape[0]

level_start_index_list = [0]
for h, w in spatial_shapes_list[:-1]:
level_start_index_list.append(level_start_index_list[-1] + h * w)
split_sizes = [None] * self.num_feature_levels
for i in range(self.num_feature_levels):
if i < self.num_feature_levels - 1:
split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
split_sizes[i] = level_start_index_list[i + 1] - level_start_index_list[i]
else:
split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i]

encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)

Expand Down

0 comments on commit cfaef32

Please sign in to comment.