From 349cbdcba86f154cbc1f0c347990f226d04f5975 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Fri, 29 Sep 2023 05:12:37 -0500
Subject: [PATCH] Update concat seq test for corrected padding

Signed-off-by: Alex-Brooks
---
 tests/resources/test_pretrained_model.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/resources/test_pretrained_model.py b/tests/resources/test_pretrained_model.py
index 1d60ea6e..c9251068 100644
--- a/tests/resources/test_pretrained_model.py
+++ b/tests/resources/test_pretrained_model.py
@@ -242,7 +242,7 @@ def test_causal_lm_as_a_sequence_problem_no_truncation(models_cache_dir, padding
     sample = GenerationTrainRecord(
         input="Hello world", output="How are you doing today?!"
     )
-    max_lengths = 20  # Must be longer than the concatenated sequence
+    max_lengths = 20
     # First, build the output we expect for left / right respectively...
     input_tok = causal_lm.tokenizer.encode(sample.input)
     output_tok = causal_lm.tokenizer.encode(sample.output) + [causal_lm.tokenizer.eos_token_id]
@@ -250,8 +250,9 @@ def test_causal_lm_as_a_sequence_problem_no_truncation(models_cache_dir, padding
     masked_res = ([-100] * len(input_tok)) + output_tok
     # This must true because otherwise no padding was needed, e.g., truncation
-    assert len(concat_res) < max_lengths
-    pads_needed = max_lengths - len(concat_res)
+    assert len(input_tok) < max_lengths
+    assert len(output_tok) < (max_lengths + 1)
+    pads_needed = (1 + 2 * max_lengths) - len(concat_res)
     if causal_lm.tokenizer.padding_side.lower() == "left":
         expected_input_ids = torch.tensor([causal_lm.tokenizer.pad_token_id] * pads_needed + concat_res)
         expected_attn_mask = torch.tensor([0] * pads_needed + [1] * len(concat_res))
@@ -262,8 +263,6 @@ def test_causal_lm_as_a_sequence_problem_no_truncation(models_cache_dir, padding
         expected_labels = torch.tensor(masked_res + [-100] * pads_needed)

     # Now build the analogous tokenizer closure and compare the tensors
-
-    # Concatenated sequence has length 18; we don't truncate anything here
     (tok_func, _) = causal_lm.build_task_tokenize_closure(
         tokenizer=causal_lm.tokenizer,
         max_source_length=max_lengths,