Fix legacy ported sequence length bug
Signed-off-by: Alex-Brooks <[email protected]>
alex-jw-brooks committed Sep 29, 2023
1 parent 349cbdc commit ed1f8e6
Showing 1 changed file with 13 additions and 13 deletions.

caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -356,17 +356,17 @@ def _causal_lm_padding_as_seq2seq(
         shifted input sequence.
         For now, this is a logical port of the old tokenization logic.
         NOTE: In this tokenization strategy, where we concat the texts, the concatenated sequence
         length is max_source_length + max_target_length + 1.
         """
         IGNORE_ID = -100
         # ID of the token to append after our target string; this should generally be pad / EOS
         FINAL_TOK_ID = tokenizer.eos_token_id
         max_concat_length = max_source_length + max_target_length + 1

-        # TODO: Add a check to verify if source and example.output both are str or not
-        # max_length=None => use the model max length (it's actually the default)
-        # For mrpc [default example], we have 2 sentences + labels to see if they are
-        # semantically equivalent
-        model_inputs = tokenizer(source, truncation=True)
-        labels = tokenizer(target, truncation=True)
+        # Truncate based on max source or max target length before considering as a joined sequence
+        model_inputs = tokenizer(source, truncation=True, max_length=max_source_length)
+        labels = tokenizer(target, truncation=True, max_length=max_target_length + 1)

         # Combine the source + target strings into the source input IDs
         # This makes the source and target the same length, and then masks the source out of the
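Editor's note on the truncation budget above: the target is clipped to max_target_length + 1 rather than max_target_length because one extra slot is reserved for FINAL_TOK_ID, the pad/EOS token appended after the target string. A minimal sketch of the length arithmetic, with made-up numbers for illustration (not values from the commit):

    # Illustrative numbers only; real values are passed in by the caller.
    max_source_length = 5    # token budget for the source text
    max_target_length = 3    # token budget for the target text
    # One extra slot for the EOS/pad token appended after the target.
    max_concat_length = max_source_length + max_target_length + 1  # = 9
    # The joined source + target sequence can therefore reach 9 tokens;
    # padding or truncating it to max_source_length (5) could cut off
    # the entire target, which is the bug this commit fixes below.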
@@ -387,26 +387,26 @@ def _causal_lm_padding_as_seq2seq(
         model_inputs = tokenizer.pad(
             model_inputs,
             padding="max_length",
-            max_length=max_source_length
+            max_length=max_concat_length
         )

         if tokenizer.padding_side.lower() == "left":
             labels["input_ids"] = [IGNORE_ID] * (
-                max_source_length - len(sample_input_ids)
+                max_concat_length - len(sample_input_ids)
             ) + label_input_ids
         else:
             labels["input_ids"] = label_input_ids + [IGNORE_ID] * (
-                max_source_length - len(sample_input_ids)
+                max_concat_length - len(sample_input_ids)
             )

         model_inputs["input_ids"] = torch.tensor(
-            model_inputs["input_ids"][:max_source_length]
+            model_inputs["input_ids"][:max_concat_length]
         )
         model_inputs["attention_mask"] = torch.tensor(
-            model_inputs["attention_mask"][:max_source_length]
+            model_inputs["attention_mask"][:max_concat_length]
         )
-        # TODO: This is a bug, but it was present in the thing Alex is WIP porting
-        labels["input_ids"] = torch.tensor(labels["input_ids"][:max_source_length])
+        labels["input_ids"] = torch.tensor(labels["input_ids"][:max_concat_length])
         model_inputs["labels"] = labels["input_ids"]
         model_inputs["task_ids"] = task_ids
         return model_inputs
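For reference, a self-contained sketch of what the fixed right-padding branch produces, using plain integer ids in place of real tokenizer output (all ids and lengths are hypothetical; this is not the caikit_nlp code itself):

    import torch

    IGNORE_ID = -100
    EOS_ID, PAD_ID = 2, 0  # hypothetical stand-ins for the tokenizer's special ids

    # Hypothetical pre-truncated token ids: 3 source tokens, 2 target tokens + EOS.
    source_ids = [11, 12, 13]       # already clipped to max_source_length = 4
    target_ids = [21, 22, EOS_ID]   # already clipped to max_target_length + 1 = 3
    max_concat_length = 4 + 2 + 1   # max_source_length + max_target_length + 1

    # Join source + target into one sequence, and mask the source out of the
    # labels so the loss is only computed over target tokens.
    sample_input_ids = source_ids + target_ids
    label_input_ids = [IGNORE_ID] * len(source_ids) + target_ids

    # Right-padding branch of the fixed logic: pad out to max_concat_length.
    # The old code padded and truncated to max_source_length (4), which here
    # would have dropped the target entirely.
    pad_len = max_concat_length - len(sample_input_ids)
    input_ids = torch.tensor(sample_input_ids + [PAD_ID] * pad_len)
    labels = torch.tensor(label_input_ids + [IGNORE_ID] * pad_len)
    attention_mask = torch.tensor([1] * len(sample_input_ids) + [0] * pad_len)

    assert input_ids.shape[0] == labels.shape[0] == max_concat_length  # 7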
