diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md
index 823f2e56cc..49e40957c1 100644
--- a/docs/source/online_dpo_trainer.md
+++ b/docs/source/online_dpo_trainer.md
@@ -79,20 +79,17 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
 - judge = PairRMJudge()
 + reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
++ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")
 
   trainer = OnlineDPOTrainer(
       ...
 -     judge=judge,
 +     reward_model=reward_model,
++     reward_processing_class=reward_tokenizer,
+      ...
   )
 ```
 
-<Tip>
-
-Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
-
-</Tip>
-
 ### Encourage EOS token generation
 
 When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py
index 7c1114eedd..ed14461725 100644
--- a/examples/scripts/dpo.py
+++ b/examples/scripts/dpo.py
@@ -126,9 +126,11 @@
     )
 
     trainer.train()
-    metrics = trainer.evaluate()
-    trainer.log_metrics("eval", metrics)
-    trainer.save_metrics("eval", metrics)
+
+    if training_args.eval_strategy != "no":
+        metrics = trainer.evaluate()
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
 
     # Save and push to hub
     trainer.save_model(training_args.output_dir)
diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py
index 33dea12ff2..58d7c4a2c0 100644
--- a/examples/scripts/dpo_online.py
+++ b/examples/scripts/dpo_online.py
@@ -93,8 +93,15 @@
             trust_remote_code=model_config.trust_remote_code,
             **model_kwargs,
         )
+        reward_tokenizer = AutoTokenizer.from_pretrained(
+            training_args.reward_model_path,
+            trust_remote_code=model_config.trust_remote_code,
+            truncation=True,
+            truncation_side="left",  # since we judge the completion, truncating left is more appropriate
+        )
     else:
         reward_model = None
+        reward_tokenizer = None
 
     if training_args.judge is not None:
         judge_cls = JUDGES[training_args.judge]
@@ -123,13 +130,17 @@
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
         processing_class=tokenizer,
+        reward_processing_class=reward_tokenizer,
         peft_config=get_peft_config(model_config),
     )
-    generation_config = GenerationConfig(
-        max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
-    )
-    completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
-    trainer.add_callback(completions_callback)
+
+    if training_args.eval_strategy != "no":
+        generation_config = GenerationConfig(
+            max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
+        )
+        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
+        trainer.add_callback(completions_callback)
+
     trainer.train()
 
     # Save and push to hub
diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py
index 1ceaf5b085..ed5d07a47e 100644
--- a/examples/scripts/gkd.py
+++ b/examples/scripts/gkd.py
@@ -46,7 +46,7 @@
 from accelerate import PartialState
 from datasets import load_dataset
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, GenerationConfig
 
 from trl import (
     GKDConfig,
@@ -125,8 +125,14 @@
         processing_class=tokenizer,
         peft_config=get_peft_config(model_config),
     )
-    completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
-    trainer.add_callback(completions_callback)
+
+    if training_args.eval_strategy != "no":
+        generation_config = GenerationConfig(
+            max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
+        )
+        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
+        trainer.add_callback(completions_callback)
+
     trainer.train()
 
     # Save and push to hub
diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py
index c492faf76f..b9492e6042 100644
--- a/examples/scripts/nash_md.py
+++ b/examples/scripts/nash_md.py
@@ -132,12 +132,14 @@
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
         processing_class=tokenizer,
     )
-    generation_config = GenerationConfig(
-        max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
-    )
-    completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
-    trainer.add_callback(completions_callback)
-    # train the model
+
+    if training_args.eval_strategy != "no":
+        generation_config = GenerationConfig(
+            max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
+        )
+        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
+        trainer.add_callback(completions_callback)
+
     trainer.train()
 
     # Save and push to hub
diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py
index 40924a3ad3..073016bc77 100644
--- a/examples/scripts/reward_modeling.py
+++ b/examples/scripts/reward_modeling.py
@@ -124,9 +124,11 @@
     # Save model and push to Hub
     ############################
     trainer.save_model(training_args.output_dir)
-    metrics = trainer.evaluate()
-    trainer.log_metrics("eval", metrics)
-    trainer.save_metrics("eval", metrics)
+
+    if training_args.eval_strategy != "no":
+        metrics = trainer.evaluate()
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
 
     # Save and push to hub
     trainer.save_model(training_args.output_dir)
diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py
index 8a07e96b3d..1c465d90e6 100644
--- a/examples/scripts/xpo.py
+++ b/examples/scripts/xpo.py
@@ -117,12 +117,14 @@
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
         processing_class=tokenizer,
     )
-    generation_config = GenerationConfig(
-        max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
-    )
-    completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
-    trainer.add_callback(completions_callback)
-    # train the model
+
+    if training_args.eval_strategy != "no":
+        generation_config = GenerationConfig(
+            max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
+        )
+        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
+        trainer.add_callback(completions_callback)
+
     trainer.train()
 
     # Save and push to hub
diff --git a/tests/test_judges.py b/tests/test_judges.py
index def75066f1..0b393ff9fd 100644
--- a/tests/test_judges.py
+++ b/tests/test_judges.py
@@ -18,6 +18,12 @@
 
 
 class TestJudges(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Initialize once to download the model. This ensures it’s downloaded before running tests, preventing issues
+        # where concurrent tests attempt to load the model while it’s still downloading.
+        PairRMJudge()
+
     def _get_prompts_and_completions(self):
         prompts = ["The capital of France is", "The biggest planet in the solar system is"]
         completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index 9a4e7680a2..9058462bed 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -20,7 +20,8 @@
 from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available
 
-from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge, is_llmblender_available
+from trl import OnlineDPOConfig, OnlineDPOTrainer, RandomPairwiseJudge, is_llmblender_available
+from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 if is_peft_available():
@@ -33,6 +34,9 @@ def setUp(self):
         self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
         self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
         self.reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
+        self.reward_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
+        self.reward_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
+        self.reward_tokenizer.pad_token = self.reward_tokenizer.eos_token
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -53,9 +57,10 @@ def test_training(self, config_name):
                 model=self.model,
                 reward_model=self.reward_model,
                 args=training_args,
-                processing_class=self.tokenizer,
                 train_dataset=dummy_dataset["train"],
                 eval_dataset=dummy_dataset["test"],
+                processing_class=self.tokenizer,
+                reward_processing_class=self.reward_tokenizer,
             )
 
             trainer.train()
@@ -79,9 +84,10 @@ def test_training_with_ref_model(self):
                 ref_model=self.ref_model,
                 reward_model=self.reward_model,
                 args=training_args,
-                processing_class=self.tokenizer,
                 train_dataset=dummy_dataset["train"],
                 eval_dataset=dummy_dataset["test"],
+                processing_class=self.tokenizer,
+                reward_processing_class=self.reward_tokenizer,
             )
 
             trainer.train()
@@ -103,9 +109,11 @@ def test_ref_model_is_model(self):
                 OnlineDPOTrainer(
                     model=self.model,
                     ref_model=self.model,  # ref_model can't be the same as model
+                    reward_model=self.reward_model,
                     args=training_args,
-                    processing_class=self.tokenizer,
                     train_dataset=dummy_dataset["train"],
+                    processing_class=self.tokenizer,
+                    reward_processing_class=self.reward_tokenizer,
                 )
 
     @require_peft
@@ -126,9 +134,10 @@ def test_training_with_peft(self):
                 model=self.model,
                 reward_model=self.reward_model,
                 args=training_args,
-                processing_class=self.tokenizer,
                 train_dataset=dummy_dataset["train"],
                 eval_dataset=dummy_dataset["test"],
+                processing_class=self.tokenizer,
+                reward_processing_class=self.reward_tokenizer,
                 peft_config=lora_config,
             )
@@ -156,9 +165,10 @@ def test_training_with_peft_and_ref_model(self):
                 ref_model=self.ref_model,
                 reward_model=self.reward_model,
                 args=training_args,
-                processing_class=self.tokenizer,
                 train_dataset=dummy_dataset["train"],
                 eval_dataset=dummy_dataset["test"],
+                processing_class=self.tokenizer,
+                reward_processing_class=self.reward_tokenizer,
                 peft_config=lora_config,
             )
@@ -188,9 +198,10 @@ def test_training_with_peft_model_and_peft_config(self):
                 model=model,
                 reward_model=self.reward_model,
                 args=training_args,
-                processing_class=self.tokenizer,
                 train_dataset=dummy_dataset["train"],
                 eval_dataset=dummy_dataset["test"],
+                processing_class=self.tokenizer,
+                reward_processing_class=self.reward_tokenizer,
                 peft_config=lora_train_config,
             )
@@ -200,7 +211,8 @@ def test_training_with_peft_model_and_peft_config(self):
             self.assertIn("train_loss", trainer.state.log_history[-1])
 
     @unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
-    def test_training_with_judge(self):
+    @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
+    def test_training_with_judge(self, config_name):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = OnlineDPOConfig(
                 output_dir=tmp_dir,
@@ -210,15 +222,15 @@
                 eval_strategy="steps",
                 report_to="none",
             )
-            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
+            dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
 
             trainer = OnlineDPOTrainer(
                 model=self.model,
-                judge=PairRMJudge(),
+                judge=RandomPairwiseJudge(),
                 args=training_args,
-                processing_class=self.tokenizer,
                 train_dataset=dummy_dataset["train"],
                 eval_dataset=dummy_dataset["test"],
+                processing_class=self.tokenizer,
             )
 
             trainer.train()
diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py
index c7994b5564..0e49d8dd98 100644
--- a/tests/test_trainers_args.py
+++ b/tests/test_trainers_args.py
@@ -272,12 +272,13 @@ def test_online_dpo(self, beta_list):
         reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
         trainer = OnlineDPOTrainer(
-            args=training_args,
-            processing_class=tokenizer,
             model=model,
             ref_model=ref_model,
             reward_model=reward_model,
+            args=training_args,
             train_dataset=dataset,
+            processing_class=tokenizer,
+            reward_processing_class=tokenizer,
         )
         self.assertEqual(trainer.args.max_new_tokens, 42)
         self.assertEqual(trainer.args.temperature, 0.5)
diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py
index e03b855d5a..ad8050ac00 100644
--- a/tests/test_xpo_trainer.py
+++ b/tests/test_xpo_trainer.py
@@ -20,7 +20,7 @@
 from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available
 
-from trl import PairRMJudge, XPOConfig, XPOTrainer, is_llmblender_available
+from trl import RandomPairwiseJudge, XPOConfig, XPOTrainer, is_llmblender_available
 
 
 if is_peft_available():
@@ -171,7 +171,7 @@ def test_xpo_trainer_judge_training(self, config_name):
                 report_to="none",
             )
             dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
-            judge = PairRMJudge()
+            judge = RandomPairwiseJudge()
 
             trainer = XPOTrainer(
                 model=self.model,
diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py
index 68b40db4b9..c998174765 100644
--- a/trl/trainer/nash_md_trainer.py
+++ b/trl/trainer/nash_md_trainer.py
@@ -122,6 +122,7 @@ def __init__(
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,
+            reward_processing_class=processing_class,  # for now, NashMDTrainer can't use any reward model
             peft_config=peft_config,
             compute_metrics=compute_metrics,
             callbacks=callbacks,
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 21ce78d8bc..2575c2f886 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -44,7 +44,7 @@
 from transformers.training_args import OptimizerNames
 from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
 
-from ..data_utils import is_conversational, maybe_apply_chat_template
+from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ..models import create_reference_model
 from ..models.utils import unwrap_model_for_generation
 from .judges import BasePairwiseJudge
@@ -137,6 +137,7 @@ def __init__(
         processing_class: Optional[
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ] = None,
+        reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
         peft_config: Optional[Dict] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         callbacks: Optional[List[TrainerCallback]] = None,
@@ -161,6 +162,7 @@
             raise ValueError("Either `reward_model` or `judge` must be provided.")
 
         self.reward_model = reward_model
+        self.reward_processing_class = reward_processing_class
         self.judge = judge
 
         if args.missing_eos_penalty is not None and judge is not None:
@@ -428,18 +430,23 @@
             ref_logprobs = torch.take_along_dim(ref_all_logprobs, completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
             del ref_output, ref_logits, ref_all_logprobs  # free memory
 
-        # Get the reward from the reward model or judge:
-        if self.judge is not None:
-            completions = self.processing_class.batch_decode(
-                prompt_completion_ids[:, context_length:], skip_special_tokens=True
-            )
-            completions = [completion.strip() for completion in completions]  # remove the leading space
+        # Decode the completions, and format them if the input is conversational
+        device = prompt_completion_ids.device
+        completions_ids = prompt_completion_ids[:, context_length:]
+        completions = self.processing_class.batch_decode(completions_ids, skip_special_tokens=True)
+        if is_conversational({"prompt": prompts[0]}):
+            completions = [[{"role": "assistant", "content": completion}] for completion in completions]
 
+        # Get the reward from the reward model or judge
+        if self.judge is not None:
+            # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
+            # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
+            # independent of the model's chat template, we use the raw conversation data, and apply our own chat
+            # template to it.
             if is_conversational({"prompt": prompts[0]}):
-                completions = [[{"role": "assistant", "content": completion}] for completion in completions]
                 environment = jinja2.Environment()
                 template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
-                prompts = [template.render(messages=message) for message in prompts]
+                prompts = [template.render(messages=prompt) for prompt in prompts]
                 completions = [template.render(messages=completion) for completion in completions]
 
             ranks_of_first_completion = self.judge.judge(
@@ -449,16 +456,39 @@
             # convert ranks to a True/False mask:
             # when rank == 0, it means the first completion is the best
             # when rank == 1, it means the second completion is the best
-            mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=prompt_completion_ids.device)
+            mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
         else:
-            _, scores, _ = get_reward(
-                self.reward_model, prompt_completion_ids, self.processing_class.pad_token_id, context_length
-            )
+            # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
+            # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
+            prompts = 2 * prompts  # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
+            if is_conversational({"prompt": prompts[0]}):
+                examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
+                examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
+                prompts = [example["prompt"] for example in examples]
+                completions = [example["completion"] for example in examples]
+
+            # Tokenize the prompts
+            prompts_ids = self.reward_processing_class(
+                prompts, padding=True, return_tensors="pt", padding_side="left"
+            )["input_ids"].to(device)
+            context_length = prompts_ids.shape[1]
+
+            # Tokenize the completions
+            completions_ids = self.reward_processing_class(
+                completions, padding=True, return_tensors="pt", padding_side="right"
+            )["input_ids"].to(device)
+
+            # Concatenate the prompts and completions and get the reward
+            prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
+            with torch.inference_mode():
+                _, scores, _ = get_reward(
+                    self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
+                )
 
-        # Filter completion. Ensure that the sample contains stop_token_id
-        # Completions not passing that filter will receive a lower score.
-        if self.args.missing_eos_penalty is not None:
-            scores[~contain_eos_token] -= self.args.missing_eos_penalty
+            # Filter completion. Ensure that the sample contains stop_token_id
+            # Completions not passing that filter will receive a lower score.
+            if self.args.missing_eos_penalty is not None:
+                scores[~contain_eos_token] -= self.args.missing_eos_penalty
 
         # Split the scores in 2 (the prompts of the first half are the same as the second half)
         first_half, second_half = scores.split(num_examples)
@@ -466,7 +496,7 @@
 
         # Get the indices of the chosen and rejected examples
         mask = first_half >= second_half
-        num_examples_range = torch.arange(num_examples, device=prompt_completion_ids.device)
+        num_examples_range = torch.arange(num_examples, device=device)
         chosen_indices = num_examples_range + (~mask * num_examples)
         rejected_indices = num_examples_range + (mask * num_examples)
diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py
index d643dab651..4ec501c7f0 100644
--- a/trl/trainer/xpo_trainer.py
+++ b/trl/trainer/xpo_trainer.py
@@ -121,6 +121,7 @@ def __init__(
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,
+            reward_processing_class=processing_class,  # for now, XPOTrainer can't use any reward model
             peft_config=peft_config,
             compute_metrics=compute_metrics,
             callbacks=callbacks,
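To make the documentation change above easier to try out, here is a self-contained sketch of the new `reward_processing_class` workflow. Only what the patch itself introduces is taken as given (`OnlineDPOTrainer`, `OnlineDPOConfig`, `reward_processing_class`, `missing_eos_penalty`, and the `trl-lib/Qwen2-0.5B-Reward` checkpoint); the policy model id and the prompt dataset are illustrative placeholders, not part of this PR.

```python
# Illustrative sketch only: the policy model and dataset names are placeholders, not part of this patch.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl import OnlineDPOConfig, OnlineDPOTrainer

model_id = "Qwen/Qwen2-0.5B-Instruct"  # placeholder policy model
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The reward model keeps its own tokenizer and chat template; it no longer has to match the policy's.
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
reward_tokenizer = AutoTokenizer.from_pretrained(
    "trl-lib/Qwen2-0.5B-Reward",
    truncation=True,
    truncation_side="left",  # we score the completion, so drop prompt tokens first if the input is too long
)

train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")  # placeholder prompt-only dataset

training_args = OnlineDPOConfig(
    output_dir="online-dpo-sketch",
    missing_eos_penalty=1.0,  # penalize completions that never emit an EOS token
)
trainer = OnlineDPOTrainer(
    model=model,
    reward_model=reward_model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    reward_processing_class=reward_tokenizer,  # new argument introduced by this patch
)
trainer.train()
```

Loading the reward tokenizer with `truncation_side="left"` follows the same reasoning as the `dpo_online.py` change: when prompt plus completion exceed the reward model's context, it is the completion being scored, so the prompt side is the safer part to drop.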
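The core of the trainer change is the reward-scoring path: completions are decoded to text, optionally re-templated with the reward tokenizer's chat template, and re-tokenized with `reward_processing_class` before being passed to `get_reward`. Below is a simplified, standalone sketch of that logic, not the trainer code itself; it assumes `prompts` and `completions` are already plain strings (for conversational data the trainer first calls `apply_chat_template` with the reward tokenizer) and that `get_reward` is imported from `trl.trainer.utils`, as in the patch.

```python
# Simplified sketch of the scoring path added in this patch (illustration only).
import torch

from trl.trainer.utils import get_reward


def score_completions(reward_model, reward_tokenizer, prompts, completions, device):
    # Left-pad prompts so every completion starts at the same position, `context_length`.
    prompt_ids = reward_tokenizer(prompts, padding=True, return_tensors="pt", padding_side="left")["input_ids"]
    prompt_ids = prompt_ids.to(device)
    context_length = prompt_ids.shape[1]

    # Right-pad completions so each row stays one contiguous prompt + completion block after concatenation.
    completion_ids = reward_tokenizer(completions, padding=True, return_tensors="pt", padding_side="right")["input_ids"]
    completion_ids = completion_ids.to(device)

    prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
    with torch.inference_mode():
        # `get_reward` returns (per-token logits, final scores, sequence lengths); only the scores are needed here.
        _, scores, _ = get_reward(reward_model, prompt_completion_ids, reward_tokenizer.pad_token_id, context_length)
    return scores
```

Because the prompts and completions are re-tokenized from strings with the reward model's own tokenizer, the reward model no longer needs to share a vocabulary or chat template with the policy, which is what made the removed documentation tip obsolete.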
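For the judge path, the patch renders the raw conversation with TRL's `SIMPLE_CHAT_TEMPLATE` so the judge never sees model-specific special tokens such as `<|im_start|>`. The sketch below reproduces that rendering step outside the trainer; `RandomPairwiseJudge` is used only to keep the example runnable offline (the tests in this PR make the same substitution for `PairRMJudge`), and the prompt/completion contents are made up for illustration.

```python
# Illustrative sketch of the judge-side rendering done in `training_step` (not the trainer code itself).
import jinja2

from trl import RandomPairwiseJudge
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

# Made-up conversational example: one prompt with a pair of candidate completions.
prompts = [[{"role": "user", "content": "What color is the sky?"}]]
completion_pairs = [
    (
        [{"role": "assistant", "content": "Blue."}],
        [{"role": "assistant", "content": "Green."}],
    )
]

# Render conversations with TRL's plain template instead of the policy model's chat template.
template = jinja2.Environment().from_string(SIMPLE_CHAT_TEMPLATE)
rendered_prompts = [template.render(messages=prompt) for prompt in prompts]
rendered_pairs = [
    [template.render(messages=first), template.render(messages=second)] for first, second in completion_pairs
]

judge = RandomPairwiseJudge()  # stand-in for PairRMJudge, mirroring the test change in this PR
ranks = judge.judge(rendered_prompts, rendered_pairs)
print(ranks)  # e.g. [0] when the first completion of the pair is ranked best
```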