From ed1f8e61857fb1e216ec4f3814f6c0f1649db1ac Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Fri, 29 Sep 2023 05:13:01 -0500
Subject: [PATCH] Fix legacy ported sequence length bug

Signed-off-by: Alex-Brooks
---
 .../pretrained_model/hf_auto_causal_lm.py | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
index 1b517b07..6c138726 100644
--- a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
+++ b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -356,17 +356,17 @@ def _causal_lm_padding_as_seq2seq(
         shifted input sequence. For now, this is a logical port of the old tokenization logic.
+        NOTE: In this tokenization strategy, where we concat the texts, the concatenated sequence
+        length is max_source_length + max_target_length + 1.
         """
         IGNORE_ID = -100
         # ID of the token to append after our target string; this should generally be pad / EOS
         FINAL_TOK_ID = tokenizer.eos_token_id
+        max_concat_length = max_source_length + max_target_length + 1
 
-        # TODO: Add a check to verify if source and example.output both are str or not
-        # max_length=None => use the model max length (it's actually the default)
-        # For mrpc [default example], we have 2 sentences + labels to see if they are
-        # semantically equivalent
-        model_inputs = tokenizer(source, truncation=True)
-        labels = tokenizer(target, truncation=True)
+        # Truncate based on max source or max target length before considering as a joined sequence
+        model_inputs = tokenizer(source, truncation=True, max_length=max_source_length)
+        labels = tokenizer(target, truncation=True, max_length=max_target_length + 1)
 
         # Combine the source + target strings into the source input IDs
         # This makes the source and target the same length, and then masks the source out of the
@@ -387,26 +387,26 @@ def _causal_lm_padding_as_seq2seq(
         model_inputs = tokenizer.pad(
             model_inputs,
             padding="max_length",
-            max_length=max_source_length
+            max_length=max_concat_length
         )
 
         if tokenizer.padding_side.lower() == "left":
             labels["input_ids"] = [IGNORE_ID] * (
-                max_source_length - len(sample_input_ids)
+                max_concat_length - len(sample_input_ids)
             ) + label_input_ids
         else:
             labels["input_ids"] = label_input_ids + [IGNORE_ID] * (
-                max_source_length - len(sample_input_ids)
+                max_concat_length - len(sample_input_ids)
            )
 
         model_inputs["input_ids"] = torch.tensor(
-            model_inputs["input_ids"][:max_source_length]
+            model_inputs["input_ids"][:max_concat_length]
         )
         model_inputs["attention_mask"] = torch.tensor(
-            model_inputs["attention_mask"][:max_source_length]
+            model_inputs["attention_mask"][:max_concat_length]
         )
-        # TODO: This is a bug, but it was present in the thing Alex is WIP porting
-        labels["input_ids"] = torch.tensor(labels["input_ids"][:max_source_length])
+
+        labels["input_ids"] = torch.tensor(labels["input_ids"][:max_concat_length])
         model_inputs["labels"] = labels["input_ids"]
         model_inputs["task_ids"] = task_ids
         return model_inputs
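
Below is a minimal, self-contained sketch, not part of the patch or of caikit_nlp, illustrating the length arithmetic the fix relies on: the causal-LM example is built as source IDs + target IDs + one final EOS/pad token, so any padding or truncation bound has to be max_source_length + max_target_length + 1 rather than max_source_length alone. The helper name build_concat_example and the toy token IDs are assumptions for illustration only.

# Illustrative sketch only -- NOT part of the patch or of caikit_nlp; the helper
# name and the toy token IDs below are hypothetical. It mimics the concatenation
# that _causal_lm_padding_as_seq2seq performs, to show why the truncation bound
# must be max_source_length + max_target_length + 1: the example fed to the
# causal LM is source_ids + target_ids + [final EOS/pad token], so slicing with
# max_source_length alone (the old bug) could cut off part or all of the target.

IGNORE_ID = -100  # label value ignored by the loss


def build_concat_example(source_ids, target_ids, final_tok_id,
                         max_source_length, max_target_length):
    """Concatenate pre-truncated source/target IDs the way the patched code does."""
    source_ids = source_ids[:max_source_length]
    # Mirror the patch: truncate the target to max_target_length + 1 tokens
    target_ids = target_ids[: max_target_length + 1]

    max_concat_length = max_source_length + max_target_length + 1

    input_ids = source_ids + target_ids + [final_tok_id]
    # Mask the source out of the labels so loss is only computed on target + final token
    labels = [IGNORE_ID] * len(source_ids) + target_ids + [final_tok_id]

    # Truncating to max_concat_length keeps the whole target; the old bound
    # (max_source_length) would have truncated it here.
    return input_ids[:max_concat_length], labels[:max_concat_length]


if __name__ == "__main__":
    src = [10, 11, 12, 13, 14, 15]   # pretend source token IDs
    tgt = [100, 101, 102, 103]       # pretend target token IDs
    inp, lbl = build_concat_example(src, tgt, final_tok_id=2,
                                    max_source_length=6, max_target_length=4)
    assert len(inp) == len(lbl) == 6 + 4 + 1
    print(inp)  # [10, 11, 12, 13, 14, 15, 100, 101, 102, 103, 2]
    print(lbl)  # [-100, -100, -100, -100, -100, -100, 100, 101, 102, 103, 2]

With the old bound of max_source_length, the same example would have been sliced to its first 6 positions: the input would contain only the source tokens and the label row would be entirely IGNORE_ID, which is the sequence length bug this patch removes.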