Skip to content

Commit

Permalink
Update concat seq test for corrected padding
Browse files Browse the repository at this point in the history
Signed-off-by: Alex-Brooks <[email protected]>
  • Loading branch information
alex-jw-brooks committed Sep 29, 2023
1 parent 73e804e commit 349cbdc
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/resources/test_pretrained_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,17 @@ def test_causal_lm_as_a_sequence_problem_no_truncation(models_cache_dir, padding
sample = GenerationTrainRecord(
input="Hello world", output="How are you doing today?!"
)
max_lengths = 20 # Must be longer than the concatenated sequence
max_lengths = 20
# First, build the output we expect for left / right respectively...
input_tok = causal_lm.tokenizer.encode(sample.input)
output_tok = causal_lm.tokenizer.encode(sample.output) + [causal_lm.tokenizer.eos_token_id]
concat_res = input_tok + output_tok
masked_res = ([-100] * len(input_tok)) + output_tok

# This must true because otherwise no padding was needed, e.g., truncation
assert len(concat_res) < max_lengths
pads_needed = max_lengths - len(concat_res)
assert len(input_tok) < max_lengths
assert len(output_tok) < (max_lengths + 1)
pads_needed = (1 + 2 * max_lengths) - len(concat_res)
if causal_lm.tokenizer.padding_side.lower() == "left":
expected_input_ids = torch.tensor([causal_lm.tokenizer.pad_token_id] * pads_needed + concat_res)
expected_attn_mask = torch.tensor([0] * pads_needed + [1] * len(concat_res))
Expand All @@ -262,8 +263,6 @@ def test_causal_lm_as_a_sequence_problem_no_truncation(models_cache_dir, padding
expected_labels = torch.tensor(masked_res + [-100] * pads_needed)

# Now build the analogous tokenizer closure and compare the tensors

# Concatenated sequence has length 18; we don't truncate anything here
(tok_func, _) = causal_lm.build_task_tokenize_closure(
tokenizer=causal_lm.tokenizer,
max_source_length=max_lengths,
Expand Down

0 comments on commit 349cbdc

Please sign in to comment.