diff --git a/tests/test_utils.py b/tests/test_utils.py
index e79edc755f..226861d96f 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -15,12 +15,19 @@
 import unittest
 
 import torch
+from datasets import load_dataset
 from transformers import AutoTokenizer
 from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available
 
 from trl.trainer.model_config import ModelConfig
-from trl.trainer.utils import decode_and_strip_padding, generate_model_card, get_peft_config, pad
+from trl.trainer.utils import (
+    DataCollatorForChatML,
+    decode_and_strip_padding,
+    generate_model_card,
+    get_peft_config,
+    pad,
+)
 
 
 if is_peft_available():
@@ -169,3 +176,77 @@ def test_val_none(self):
         assert "my_model" in card_text
         assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
         assert "My Trainer" in card_text
+
+
+class TestDataCollatorForChatML(unittest.TestCase):
+    def setUp(self):
+        # Initialize the tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Define token IDs
+        self.bos_token_id = self.tokenizer.bos_token_id if self.tokenizer.bos_token_id is not None else 1
+        self.eos_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 2
+        # Collator settings: label ignore index, maximum sequence length, and messages key
+        self.ignore_index = -100
+        self.max_length = 1024
+        self.messages_key = "messages"
+
+        # Example input
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
+        self.examples = dataset.to_list()
+
+        # Initialize the data collator
+        self.collator = DataCollatorForChatML(
+            tokenizer=self.tokenizer,
+            max_length=self.max_length,
+            ignore_index=self.ignore_index,
+        )
+
+    def test_data_collator_for_chatml(self):
+        # Process the data
+        data = self.collator(self.examples)
+
+        # Extract input_ids, labels, and prompt-only ids for verification
+        input_ids = data["input_ids"][0].tolist()
+        labels = data["labels"][0].tolist()
+        prompt_only = data["prompts"][0].tolist()
+
+        # Verify that input_ids start with optional padding tokens followed by exactly one BOS token
+        first_non_pad = next(token for token in input_ids if token != self.tokenizer.pad_token_id)
+        self.assertEqual(
+            first_non_pad, self.bos_token_id, "The first non-padding token of input_ids should be BOS token."
+        )
+        self.assertEqual(input_ids.count(self.bos_token_id), 1, "There should be exactly one BOS token in input_ids.")
+
+        # Verify that the assistant's response tokens are present in input_ids but not in prompt_only
+        last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
+        last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
+        response_in_input_ids = all(token in input_ids for token in last_assistant_response_tokens)
+        self.assertTrue(response_in_input_ids, "The assistant's response should be present in input_ids.")
+
+        # Check that the last assistant's response tokens are not in prompt_only
+        response_in_prompt = all(token in prompt_only for token in last_assistant_response_tokens)
+        self.assertFalse(response_in_prompt, "The assistant's response should not be present in prompt_only.")
+
+        # Verify that EOS token is at the end of input_ids
+        self.assertEqual(input_ids[-1], self.eos_token_id, "The last token of input_ids should be EOS token.")
+
+        # Verify that the labels preserved the target string (last_assistant_response)
+        last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
+        last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
+
+        # Find the start and end of the last assistant's response in the labels
+        response_start = next(i for i, label in enumerate(labels) if label != self.ignore_index)
+        response_end = next(i for i in range(len(labels) - 1, -1, -1) if labels[i] != self.ignore_index)
+
+        actual_response = labels[response_start : response_end - 1]
+        self.assertEqual(
+            actual_response,
+            last_assistant_response_tokens,
+            "The labels should preserve the last assistant's response tokens.",
+        )
+
+        # Verify that EOS token is at the end of labels
+        self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.")
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 1591b75ada..81e826f971 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -244,6 +244,7 @@ class DataCollatorForChatML:
     tokenizer: PreTrainedTokenizerBase
     ignore_index: int = -100
     max_length: int = None
+    prompt_key: str = "prompt"
     messages_key: str = "messages"
 
     def __post_init__(self):
@@ -254,67 +255,69 @@ def __post_init__(self):
             self.max_length = min(self.tokenizer.model_max_length, 1024)
 
     def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
-        prompts = []
-        completions = []
-
-        for example in examples:
-            messages = example[self.messages_key]
-            formatted_chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
-
-            # Split the formatted chat into prompt and completion
-            assistant_messages = [msg for msg in messages if msg["role"] == "assistant"]
-            last_assistant_message = assistant_messages[-1]["content"]
-            prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]
-            completion = last_assistant_message
-
-            prompts.append(prompt)
-            completions.append(completion)
-
-        # Tokenize prompts and completions
-        tokenized_prompts = self.tokenizer(
-            prompts, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
-        )
-        tokenized_completions = self.tokenizer(
-            completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
-        )
-
-        # Combine prompts and completions
         input_ids = []
         attention_mask = []
+        prompts_input_ids = []
+        prompt_attention_mask = []
         labels = []
 
-        for prompt, completion in zip(tokenized_prompts["input_ids"], tokenized_completions["input_ids"]):
-            combined_input_ids = prompt + completion
-            combined_attention_mask = [1] * len(combined_input_ids)
+        for example in examples:
+            formatted_prompt = example.get(self.prompt_key, None)
+            if formatted_prompt is None:
+                prompt = example[self.messages_key][:-1]
+                formatted_prompt = self.tokenizer.apply_chat_template(
+                    prompt, tokenize=False, add_generation_prompt=True
+                )
 
-            # Create labels for one-token ahead task, masking the prompt
-            combined_labels = [self.ignore_index] * len(prompt) + completion[:-1]
-            combined_labels.append(self.tokenizer.eos_token_id)  # Add EOS token as final target
+            if "input_ids" not in example:
+                message = example[self.messages_key]
+                formatted_message = self.tokenizer.apply_chat_template(
+                    message, tokenize=False, add_generation_prompt=True
+                )
+                tokenized_message = self.tokenizer(
+                    formatted_message,
+                    truncation=True,
+                    max_length=self.max_length,
+                    padding=False,
+                    return_tensors=None,
+                    add_special_tokens=False,
+                )
+                input_ids.append(tokenized_message["input_ids"])
+                attention_mask.append(tokenized_message["attention_mask"])
+            else:
+                input_ids.append(example["input_ids"])
+                attention_mask.append(example["attention_mask"])
+
+            tokenized_prompt = self.tokenizer(
+                formatted_prompt,
+                truncation=True,
+                max_length=len(input_ids[-1]),
+                padding=False,
+                return_tensors=None,
+                add_special_tokens=False,
+            )
 
-            input_ids.append(combined_input_ids)
-            attention_mask.append(combined_attention_mask)
-            labels.append(combined_labels)
+            prompts_input_ids.append(tokenized_prompt["input_ids"])
+            prompt_attention_mask.append(tokenized_prompt["attention_mask"])
 
-        # first convert to list of tensors
-        input_ids = [torch.tensor(ids) for ids in input_ids]
-        attention_mask = [torch.tensor(mask) for mask in attention_mask]
-        labels = [torch.tensor(label) for label in labels]
+            # Build the labels: every position except the completion tokens of example["input_ids"] is set to ignore_index
+            label = [self.ignore_index] * len(input_ids[-1])
+            completion_start_idx = len(tokenized_prompt["input_ids"])
+            label[completion_start_idx:] = input_ids[-1][completion_start_idx:]
+            labels.append(label)
 
-        # pad the input_ids, attention_mask and labels to the same length across the batch
+        # Convert to lists of tensors and pad to the same length across the batch
+        input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
+        attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask]
+        labels = [torch.tensor(label, dtype=torch.long) for label in labels]
         input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
         attention_mask = pad(attention_mask, padding_side="left", padding_value=0)
         labels = pad(labels, padding_side="left", padding_value=self.ignore_index)
 
-        # pad the tokenized_prompts on the left to the same length convert to tensor first
-        prompts_input_ids = [torch.tensor(ids) for ids in tokenized_prompts["input_ids"]]
+        prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids]
+        prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask]
         prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
-
-        # prompt attention mask
-        prompt_attention_mask = pad(
-            [torch.tensor([1] * len(ids)) for ids in tokenized_prompts["input_ids"]],
-            padding_side="left",
-            padding_value=0,
-        )
+        prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0)
 
         return {
             "input_ids": input_ids,
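
For reference, a minimal usage sketch of the collator changed above, mirroring the tokenizer and dataset from the new test. This is not part of the diff; the DataLoader wiring at the end is an illustrative assumption.

    from datasets import load_dataset
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from trl.trainer.utils import DataCollatorForChatML

    # Same tokenizer and dataset as in TestDataCollatorForChatML.setUp
    tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    collator = DataCollatorForChatML(tokenizer=tokenizer, max_length=1024)

    # Each example carries a "messages" list; the collator formats the prompt,
    # tokenizes, and masks everything except the completion in the labels.
    examples = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train").to_list()
    batch = collator(examples[:2])
    print(batch["input_ids"].shape, batch["labels"].shape, batch["prompts"].shape)

    # Assumed wiring, not shown in the diff: hand the collator to a DataLoader for batched training.
    loader = DataLoader(examples, batch_size=2, collate_fn=collator)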