🍬 Use any reward model for online methods (#2276)
* Refactor reward processing in OnlineDPOTrainer

* Refactor completion decoding and reward processing

* remove strip

* remove warning

* Add reward_tokenizer to training script

* Add reward_tokenizer and reward_processing_class to OnlineDPOTrainer test

* propagate to xpo and nash

* style

* reduce memory requirement with inference_mode

* fix tests

* pairrm judge llmblender

* setUpClass(cls)

* Add setUpClass method to TestJudges class

* truncation left for reward tokenizer

* don't logcompletion without eval dataset

* only eval when possible
qgallouedec authored Oct 28, 2024
1 parent 0ce3b65 commit b269657
Showing 14 changed files with 138 additions and 65 deletions.
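In short: the online trainers can now score completions with any sequence-classification reward model. Completions are decoded with the policy tokenizer, re-tokenized with the reward model's own tokenizer, and scored under `inference_mode` to keep memory down. A minimal sketch of that flow, with illustrative checkpoints and a hypothetical helper rather than the trainer's actual internals:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative checkpoints; any causal-LM tokenizer and any num_labels=1
# sequence-classification reward model should fit the same pattern.
policy_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "trl-lib/Qwen2-0.5B-Reward", num_labels=1
)
if reward_tokenizer.pad_token is None:
    reward_tokenizer.pad_token = reward_tokenizer.eos_token


def score_completions(prompts, completion_ids):
    # 1) Decode generated token ids with the *policy* tokenizer.
    completions = policy_tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
    # 2) Re-tokenize prompt + completion with the *reward* tokenizer.
    texts = [p + c for p, c in zip(prompts, completions)]
    inputs = reward_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # 3) Score without autograd state to reduce memory.
    with torch.inference_mode():
        return reward_model(**inputs).logits[:, 0]
```

The diffs below wire the same idea through the example scripts (a dedicated reward tokenizer with left truncation) and the trainers (a new `reward_processing_class` argument).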
9 changes: 3 additions & 6 deletions docs/source/online_dpo_trainer.md
@@ -79,20 +79,17 @@ Instead of a judge, you can choose to use a reward model -- see [Reward Bench](ht

- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
+ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")

trainer = OnlineDPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
+ reward_processing_class=reward_tokenizer,
...
)
```

<Tip warning={true}>

Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.

</Tip>

### Encourage EOS token generation

When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
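As an aside, a minimal sketch of such a config with illustrative values:

```python
from trl import OnlineDPOConfig

# Illustrative settings; `output_dir` comes from the underlying TrainingArguments.
training_args = OnlineDPOConfig(
    output_dir="online-dpo-qwen",
    max_new_tokens=64,        # cap completions at 64 newly generated tokens
    missing_eos_penalty=1.0,  # penalty applied when no EOS appears before the cap
)
```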
8 changes: 5 additions & 3 deletions examples/scripts/dpo.py
@@ -126,9 +126,11 @@
)

trainer.train()
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

if training_args.eval_strategy != "no":
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Save and push to hub
trainer.save_model(training_args.output_dir)
21 changes: 16 additions & 5 deletions examples/scripts/dpo_online.py
@@ -93,8 +93,15 @@
trust_remote_code=model_config.trust_remote_code,
**model_kwargs,
)
reward_tokenizer = AutoTokenizer.from_pretrained(
training_args.reward_model_path,
trust_remote_code=model_config.trust_remote_code,
truncation=True,
truncation_side="left", # since we judge the completion, truncating left is more appropriate
)
else:
reward_model = None
reward_tokenizer = None

if training_args.judge is not None:
judge_cls = JUDGES[training_args.judge]
@@ -123,13 +130,17 @@
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
reward_processing_class=reward_tokenizer,
peft_config=get_peft_config(model_config),
)
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
)
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
trainer.add_callback(completions_callback)

if training_args.eval_strategy != "no":
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
)
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
trainer.add_callback(completions_callback)

trainer.train()

# Save and push to hub
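The `truncation_side="left"` setting in the diff above matters because the reward model judges the completion: when a prompt-plus-completion pair exceeds the reward tokenizer's maximum length, left truncation drops prompt tokens and keeps the tail that is actually being scored. A small standalone illustration, with an arbitrary tokenizer and length:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")  # arbitrary small tokenizer
tok.truncation_side = "left"

text = "A very long prompt " * 50 + "FINAL ANSWER: 42"
ids = tok(text, truncation=True, max_length=16)["input_ids"]
# With left truncation the kept window ends with the completion,
# so the decoded text still contains "FINAL ANSWER: 42".
print(tok.decode(ids))
```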
12 changes: 9 additions & 3 deletions examples/scripts/gkd.py
@@ -46,7 +46,7 @@

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GenerationConfig

from trl import (
GKDConfig,
@@ -125,8 +125,14 @@
processing_class=tokenizer,
peft_config=get_peft_config(model_config),
)
completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
trainer.add_callback(completions_callback)

if training_args.eval_strategy != "no":
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
)
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
trainer.add_callback(completions_callback)

trainer.train()

# Save and push to hub
14 changes: 8 additions & 6 deletions examples/scripts/nash_md.py
@@ -132,12 +132,14 @@
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
)
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
)
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
trainer.add_callback(completions_callback)
# train the model

if training_args.eval_strategy != "no":
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
)
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
trainer.add_callback(completions_callback)

trainer.train()

# Save and push to hub
8 changes: 5 additions & 3 deletions examples/scripts/reward_modeling.py
@@ -124,9 +124,11 @@
# Save model and push to Hub
############################
trainer.save_model(training_args.output_dir)
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

if training_args.eval_strategy != "no":
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Save and push to hub
trainer.save_model(training_args.output_dir)
14 changes: 8 additions & 6 deletions examples/scripts/xpo.py
@@ -117,12 +117,14 @@
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
)
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
)
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
trainer.add_callback(completions_callback)
# train the model

if training_args.eval_strategy != "no":
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
)
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
trainer.add_callback(completions_callback)

trainer.train()

# Save and push to hub
6 changes: 6 additions & 0 deletions tests/test_judges.py
@@ -18,6 +18,12 @@


class TestJudges(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Initialize once to download the model. This ensures it’s downloaded before running tests, preventing issues
# where concurrent tests attempt to load the model while it’s still downloading.
PairRMJudge()

def _get_prompts_and_completions(self):
prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
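The `setUpClass` hook added above exists only to trigger the PairRM download once before any test runs, so concurrent tests don't race on a half-downloaded model. The same pattern in isolation, shown with a tokenizer instead of the judge so it runs without `llm-blender` (model name illustrative):

```python
import unittest

from transformers import AutoTokenizer


class TestWithSharedDownload(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Download/cache once for the whole class so individual tests
        # don't each trigger (and possibly race on) the same Hub fetch.
        cls.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")

    def test_tokenize(self):
        self.assertGreater(len(self.tokenizer("hello")["input_ids"]), 0)
```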
34 changes: 23 additions & 11 deletions tests/test_online_dpo_trainer.py
@@ -20,7 +20,8 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge, is_llmblender_available
from trl import OnlineDPOConfig, OnlineDPOTrainer, RandomPairwiseJudge, is_llmblender_available
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


if is_peft_available():
@@ -33,6 +34,9 @@ def setUp(self):
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
self.reward_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
self.reward_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
self.reward_tokenizer.pad_token = self.reward_tokenizer.eos_token
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.tokenizer.pad_token = self.tokenizer.eos_token

@@ -53,9 +57,10 @@ def test_training(self, config_name):
model=self.model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
)
trainer.train()

@@ -79,9 +84,10 @@ def test_training_with_ref_model(self):
ref_model=self.ref_model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
)
trainer.train()

@@ -103,9 +109,11 @@ def test_ref_model_is_model(self):
OnlineDPOTrainer(
model=self.model,
ref_model=self.model, # ref_model can't be the same as model
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
)

@require_peft
Expand All @@ -126,9 +134,10 @@ def test_training_with_peft(self):
model=self.model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
peft_config=lora_config,
)

@@ -156,9 +165,10 @@ def test_training_with_peft_and_ref_model(self):
ref_model=self.ref_model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
peft_config=lora_config,
)

@@ -188,9 +198,10 @@ def test_training_with_peft_model_and_peft_config(self):
model=model,
reward_model=self.reward_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
reward_processing_class=self.reward_tokenizer,
peft_config=lora_train_config,
)

@@ -200,7 +211,8 @@ def test_training_with_peft_model_and_peft_config(self):
self.assertIn("train_loss", trainer.state.log_history[-1])

@unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
def test_training_with_judge(self):
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
def test_training_with_judge(self, config_name):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = OnlineDPOConfig(
output_dir=tmp_dir,
@@ -210,15 +222,15 @@ def test_training_with_judge(self):
eval_strategy="steps",
report_to="none",
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)

trainer = OnlineDPOTrainer(
model=self.model,
judge=PairRMJudge(),
judge=RandomPairwiseJudge(),
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
processing_class=self.tokenizer,
)
trainer.train()

5 changes: 3 additions & 2 deletions tests/test_trainers_args.py
@@ -272,12 +272,13 @@ def test_online_dpo(self, beta_list):
reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = OnlineDPOTrainer(
args=training_args,
processing_class=tokenizer,
model=model,
ref_model=ref_model,
reward_model=reward_model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
reward_processing_class=tokenizer,
)
self.assertEqual(trainer.args.max_new_tokens, 42)
self.assertEqual(trainer.args.temperature, 0.5)
4 changes: 2 additions & 2 deletions tests/test_xpo_trainer.py
@@ -20,7 +20,7 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import PairRMJudge, XPOConfig, XPOTrainer, is_llmblender_available
from trl import RandomPairwiseJudge, XPOConfig, XPOTrainer, is_llmblender_available


if is_peft_available():
@@ -171,7 +171,7 @@ def test_xpo_trainer_judge_training(self, config_name):
report_to="none",
)
dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
judge = PairRMJudge()
judge = RandomPairwiseJudge()

trainer = XPOTrainer(
model=self.model,
1 change: 1 addition & 0 deletions trl/trainer/nash_md_trainer.py
@@ -122,6 +122,7 @@ def __init__(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
peft_config=peft_config,
compute_metrics=compute_metrics,
callbacks=callbacks,