diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 5a5dacbcf4c120..eeee25967e4f4d 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -787,7 +787,7 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor Computes the average number of target masks across the batch, for normalization purposes. """ num_masks = sum([len(classes) for classes in class_labels]) - num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) + num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device) return num_masks_pt diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 6572f5cdde864a..dc46a6e8798893 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -1193,7 +1193,7 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor Computes the average number of target masks across the batch, for normalization purposes. """ num_masks = sum([len(classes) for classes in class_labels]) - num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) + num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device) return num_masks_pt diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py index 8dc70d7c750d1b..1c4846947988ec 100644 --- a/tests/models/mask2former/test_modeling_mask2former.py +++ b/tests/models/mask2former/test_modeling_mask2former.py @@ -190,7 +190,7 @@ def comm_check_on_output(result): comm_check_on_output(result) self.parent.assertTrue(result.loss is not None) - self.parent.assertEqual(result.loss.shape, torch.Size([1])) + self.parent.assertEqual(result.loss.shape, torch.Size([])) @require_torch diff --git a/tests/models/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py index ffa77a051259e0..c4f014c5bbb5bb 100644 --- a/tests/models/maskformer/test_modeling_maskformer.py +++ b/tests/models/maskformer/test_modeling_maskformer.py @@ -189,7 +189,7 @@ def comm_check_on_output(result): comm_check_on_output(result) self.parent.assertTrue(result.loss is not None) - self.parent.assertEqual(result.loss.shape, torch.Size([1])) + self.parent.assertEqual(result.loss.shape, torch.Size([])) @require_torch