From 7deb60f167bbef448a1cbf53cb2e62e1c77b4ede Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 16 Apr 2024 17:53:30 +0000 Subject: [PATCH] Fix padding_index in test and typo --- tests/trainer/test_trainer_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py index 0a72be47091144..f14f1093c04490 100644 --- a/tests/trainer/test_trainer_utils.py +++ b/tests/trainer/test_trainer_utils.py @@ -510,7 +510,7 @@ def test_eval_loop_container(self): (torch.ones([4, 2, 3]), torch.ones([4, 6])), ] - concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=False) + concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100) concat_container.add(batch_1) concat_container.add(batch_2) concat_container.to_cpu_and_numpy() @@ -529,8 +529,10 @@ def test_eval_loop_container(self): self.assertEqual(len(arrays[2]), 2) self.assertEqual(arrays[2][0].shape, (12, 2, 3)) self.assertEqual(arrays[2][1].shape, (12, 6)) + # check that first batch padded with padding index -100 after concatenation + self.assertEqual(arrays[2][1][0][2], -100) - # Test tow batches with no concatenation + # Test two batches with no concatenation list_container = EvalLoopContainer(do_nested_concat=False) list_container.add(batch_1) list_container.add(batch_2) @@ -562,14 +564,14 @@ def test_eval_loop_container(self): self.assertEqual(np_batch_2[2][1].shape, (4, 6)) # Test no batches - none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=False).get_arrays() + none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=-100).get_arrays() self.assertIsNone(none_arr) none_arr = EvalLoopContainer(do_nested_concat=False).get_arrays() self.assertIsNone(none_arr) # Test one batch - concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=False) + concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100) concat_container.add(batch_1) arrays = concat_container.get_arrays() self.assertIsInstance(arrays, list)