Skip to content

Commit

Permalink
enable training mask2former and maskformer for transformers trainer (#…
Browse files Browse the repository at this point in the history
…28277)

* fix get_num_masks output as [int] to int

* fix loss size from torch.Size([1]) to torch.Size([])
  • Loading branch information
SangbumChoi authored Jan 4, 2024
1 parent 6b8ec25 commit 4a66c0d
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
return num_masks_pt


Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/maskformer/modeling_maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
return num_masks_pt


Expand Down
2 changes: 1 addition & 1 deletion tests/models/mask2former/test_modeling_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def comm_check_on_output(result):
comm_check_on_output(result)

self.parent.assertTrue(result.loss is not None)
self.parent.assertEqual(result.loss.shape, torch.Size([1]))
self.parent.assertEqual(result.loss.shape, torch.Size([]))


@require_torch
Expand Down
2 changes: 1 addition & 1 deletion tests/models/maskformer/test_modeling_maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def comm_check_on_output(result):
comm_check_on_output(result)

self.parent.assertTrue(result.loss is not None)
self.parent.assertEqual(result.loss.shape, torch.Size([1]))
self.parent.assertEqual(result.loss.shape, torch.Size([]))


@require_torch
Expand Down

0 comments on commit 4a66c0d

Please sign in to comment.