Skip to content

Commit

Permalink
Fix padding_index in test and typo
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Apr 16, 2024
1 parent bcbe83f commit 7deb60f
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/trainer/test_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def test_eval_loop_container(self):
(torch.ones([4, 2, 3]), torch.ones([4, 6])),
]

concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=False)
concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
concat_container.add(batch_1)
concat_container.add(batch_2)
concat_container.to_cpu_and_numpy()
Expand All @@ -529,8 +529,10 @@ def test_eval_loop_container(self):
self.assertEqual(len(arrays[2]), 2)
self.assertEqual(arrays[2][0].shape, (12, 2, 3))
self.assertEqual(arrays[2][1].shape, (12, 6))
# check that first batch padded with padding index -100 after concatenation
self.assertEqual(arrays[2][1][0][2], -100)

# Test tow batches with no concatenation
# Test two batches with no concatenation
list_container = EvalLoopContainer(do_nested_concat=False)
list_container.add(batch_1)
list_container.add(batch_2)
Expand Down Expand Up @@ -562,14 +564,14 @@ def test_eval_loop_container(self):
self.assertEqual(np_batch_2[2][1].shape, (4, 6))

# Test no batches
none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=False).get_arrays()
none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=-100).get_arrays()
self.assertIsNone(none_arr)

none_arr = EvalLoopContainer(do_nested_concat=False).get_arrays()
self.assertIsNone(none_arr)

# Test one batch
concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=False)
concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
concat_container.add(batch_1)
arrays = concat_container.get_arrays()
self.assertIsInstance(arrays, list)
Expand Down

0 comments on commit 7deb60f

Please sign in to comment.