From f34fa480b918e42cada2ae339d15733059f96478 Mon Sep 17 00:00:00 2001 From: Nick DeGroot Date: Mon, 4 Mar 2024 03:04:49 -0800 Subject: [PATCH] Fix OneFormer `post_process_instance_segmentation` for panoptic tasks (#29304) * :bug: Fix oneformer instance post processing when using panoptic task type * :white_check_mark: Add unit test for oneformer instance post processing panoptic bug --------- Co-authored-by: Nick DeGroot <1966472+nickthegroot@users.noreply.github.com> --- .../models/oneformer/image_processing_oneformer.py | 8 ++++---- .../oneformer/test_image_processing_oneformer.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/oneformer/image_processing_oneformer.py b/src/transformers/models/oneformer/image_processing_oneformer.py index d9b0c0168682ab..9f865f8efd9b94 100644 --- a/src/transformers/models/oneformer/image_processing_oneformer.py +++ b/src/transformers/models/oneformer/image_processing_oneformer.py @@ -1244,8 +1244,8 @@ def post_process_instance_segmentation( # if this is panoptic segmentation, we only keep the "thing" classes if task_type == "panoptic": keep = torch.zeros_like(scores_per_image).bool() - for i, lab in enumerate(labels_per_image): - keep[i] = lab in self.metadata["thing_ids"] + for j, lab in enumerate(labels_per_image): + keep[j] = lab in self.metadata["thing_ids"] scores_per_image = scores_per_image[keep] labels_per_image = labels_per_image[keep] @@ -1258,8 +1258,8 @@ def post_process_instance_segmentation( continue if "ade20k" in self.class_info_file and not is_demo and "instance" in task_type: - for i in range(labels_per_image.shape[0]): - labels_per_image[i] = self.metadata["thing_ids"].index(labels_per_image[i].item()) + for j in range(labels_per_image.shape[0]): + labels_per_image[j] = self.metadata["thing_ids"].index(labels_per_image[j].item()) # Get segmentation map and segment information of batch item target_size = target_sizes[i] if target_sizes is not None else None diff --git a/tests/models/oneformer/test_image_processing_oneformer.py b/tests/models/oneformer/test_image_processing_oneformer.py index 4a9e560463adf0..abec659a8bfc87 100644 --- a/tests/models/oneformer/test_image_processing_oneformer.py +++ b/tests/models/oneformer/test_image_processing_oneformer.py @@ -295,6 +295,19 @@ def test_post_process_instance_segmentation(self): el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width) ) + segmentation_with_opts = image_processor.post_process_instance_segmentation( + outputs, + threshold=0, + target_sizes=[(1, 4) for _ in range(self.image_processor_tester.batch_size)], + task_type="panoptic", + ) + self.assertTrue(len(segmentation_with_opts) == self.image_processor_tester.batch_size) + for el in segmentation_with_opts: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual(el["segmentation"].shape, (1, 4)) + def test_post_process_panoptic_segmentation(self): image_processor = self.image_processing_class( num_labels=self.image_processor_tester.num_classes,