Allow WinRateCallback to be used without reference model (#2013)
* tests

* make ref model optional

* style

* remove attribute error
qgallouedec authored Sep 3, 2024
1 parent 1f6a1d2 commit 6840380
Showing 2 changed files with 125 additions and 87 deletions.
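The upshot: `WinRateCallback` can now be attached to a plain `Trainer` that has no `ref_model` attribute; in that case the reference completions are generated by the model as it stands at the start of training. A minimal sketch of the new usage, distilled from the `test_without_ref_model` test below (`model`, `tokenizer`, the datasets, and `my_judge` are placeholders, not part of this commit):

```python
from transformers import GenerationConfig, Trainer, TrainingArguments
from trl import WinRateCallback

# Sketch only: `model`, `tokenizer`, `train_dataset`, `eval_dataset`, and
# `my_judge` (any BasePairwiseJudge implementation) are assumed to exist.
# The eval dataset must contain a "prompt" column.
trainer = Trainer(
    model=model,  # a plain Trainer: no `ref_model` attribute required anymore
    args=TrainingArguments(output_dir="out", eval_strategy="steps", eval_steps=2, report_to="none"),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
win_rate_callback = WinRateCallback(
    judge=my_judge,
    trainer=trainer,
    generation_config=GenerationConfig(max_length=32),
)
trainer.add_callback(win_rate_callback)
trainer.train()  # "eval_win_rate" is logged at each evaluation step
```

Previously, constructing the callback with such a trainer raised an `AttributeError`, as the removed check in `trl/trainer/callbacks.py` below shows.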
190 changes: 114 additions & 76 deletions tests/test_callback.py
@@ -1,4 +1,5 @@
 import tempfile
+import unittest
 
 from datasets import Dataset, DatasetDict
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
@@ -18,87 +19,124 @@ def judge(self, prompts, completions, shuffle_order=True):
 class TrainerWithRefModel(Trainer):
     # This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional
     # ref_model attribute
-    def __init__(self, model, ref_model, args, trainer_dataset, eval_dataset, tokenizer):
+    def __init__(self, model, ref_model, args, train_dataset, eval_dataset, tokenizer):
         super().__init__(
-            model=model, args=args, train_dataset=trainer_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer
+            model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer
         )
         self.ref_model = ref_model
 
 
-def test_trainer_callback():
-    model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
-    ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
-    tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
-    tokenizer.pad_token = tokenizer.eos_token
-    dataset = DatasetDict(
-        {
-            "train": Dataset.from_dict(
-                {
-                    "prompt": [
-                        "Hello world!",
-                        "This is a test.",
-                        "We are creating a dataset.",
-                        "It has eight lines.",
-                        "Each line is a sentence.",
-                        "The sentences are simple.",
-                        "This is just for testing.",
-                        "Goodbye!",
-                    ]
-                }
-            ),
-            "test": Dataset.from_dict(
-                {
-                    "prompt": [
-                        "The sun sets in the west.",
-                        "Mountains are majestic.",
-                        "Rivers flow endlessly.",
-                        "Forests are full of life.",
-                        "Birds sing in the morning.",
-                        "Waves crash on the shore.",
-                        "The moon glows at night.",
-                        "Stars twinkle in the sky.",
-                    ]
-                }
-            ),
-        }
-    )
+class WinrateCallbackTester(unittest.TestCase):
+    def setUp(self):
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        dataset = DatasetDict(
+            {
+                "train": Dataset.from_dict(
+                    {
+                        "prompt": [
+                            "Hello world!",
+                            "This is a test.",
+                            "We are creating a dataset.",
+                            "It has eight lines.",
+                            "Each line is a sentence.",
+                            "The sentences are simple.",
+                            "This is just for testing.",
+                            "Goodbye!",
+                        ]
+                    }
+                ),
+                "test": Dataset.from_dict(
+                    {
+                        "prompt": [
+                            "The sun sets in the west.",
+                            "Mountains are majestic.",
+                            "Rivers flow endlessly.",
+                            "Forests are full of life.",
+                            "Birds sing in the morning.",
+                            "Waves crash on the shore.",
+                            "The moon glows at night.",
+                            "Stars twinkle in the sky.",
+                        ]
+                    }
+                ),
+            }
+        )
 
-    def tokenize_function(examples):
-        out = tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
-        out["labels"] = out["input_ids"].copy()
-        return out
+        def tokenize_function(examples):
+            out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
+            out["labels"] = out["input_ids"].copy()
+            return out
 
-    dataset = dataset.map(tokenize_function, batched=True)
+        self.dataset = dataset.map(tokenize_function, batched=True)
 
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        args = TrainingArguments(
-            output_dir=tmp_dir,
-            eval_strategy="steps",
-            eval_steps=2,  # evaluate every 2 steps
-            per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
-            per_device_eval_batch_size=2,
-            report_to="none",
-        )
-        trainer = TrainerWithRefModel(
-            model=model,
-            ref_model=ref_model,
-            args=args,
-            trainer_dataset=dataset["train"],
-            eval_dataset=dataset["test"],
-            tokenizer=tokenizer,
-        )
-        generation_config = GenerationConfig(max_length=32)
-        win_rate_callback = WinRateCallback(
-            judge=ThreeQuatersPairwiseJudge(), trainer=trainer, generation_config=generation_config
-        )
-        trainer.add_callback(win_rate_callback)
-        trainer.train()
-        winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
-        assert winrate_history == [
-            {"eval_win_rate": 0.75, "epoch": 0.5, "step": 2},
-            {"eval_win_rate": 0.75, "epoch": 1.0, "step": 4},
-            {"eval_win_rate": 0.75, "epoch": 1.5, "step": 6},
-            {"eval_win_rate": 0.75, "epoch": 2.0, "step": 8},
-            {"eval_win_rate": 0.75, "epoch": 2.5, "step": 10},
-            {"eval_win_rate": 0.75, "epoch": 3.0, "step": 12},
-        ]
+        self.generation_config = GenerationConfig(max_length=32)
+        self.judge = ThreeQuatersPairwiseJudge()
+
+    def test_basic(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = TrainingArguments(
+                output_dir=tmp_dir,
+                eval_strategy="steps",
+                eval_steps=2,  # evaluate every 2 steps
+                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
+                per_device_eval_batch_size=2,
+                report_to="none",
+            )
+            trainer = TrainerWithRefModel(
+                model=self.model,
+                ref_model=self.ref_model,
+                args=args,
+                train_dataset=self.dataset["train"],
+                eval_dataset=self.dataset["test"],
+                tokenizer=self.tokenizer,
+            )
+            win_rate_callback = WinRateCallback(
+                judge=self.judge, trainer=trainer, generation_config=self.generation_config
+            )
+            trainer.add_callback(win_rate_callback)
+            trainer.train()
+            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
+            assert winrate_history == [
+                {"eval_win_rate": 0.75, "epoch": 0.5, "step": 2},
+                {"eval_win_rate": 0.75, "epoch": 1.0, "step": 4},
+                {"eval_win_rate": 0.75, "epoch": 1.5, "step": 6},
+                {"eval_win_rate": 0.75, "epoch": 2.0, "step": 8},
+                {"eval_win_rate": 0.75, "epoch": 2.5, "step": 10},
+                {"eval_win_rate": 0.75, "epoch": 3.0, "step": 12},
+            ]
+
+    def test_without_ref_model(self):
+        # Same as before, but without the ref_model attribute. It should use the model attribute instead
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = TrainingArguments(
+                output_dir=tmp_dir,
+                eval_strategy="steps",
+                eval_steps=2,  # evaluate every 2 steps
+                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
+                per_device_eval_batch_size=2,
+                report_to="none",
+            )
+            trainer = Trainer(
+                model=self.model,
+                args=args,
+                train_dataset=self.dataset["train"],
+                eval_dataset=self.dataset["test"],
+                tokenizer=self.tokenizer,
+            )
+            win_rate_callback = WinRateCallback(
+                judge=self.judge, trainer=trainer, generation_config=self.generation_config
+            )
+            trainer.add_callback(win_rate_callback)
+            trainer.train()
+            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
+            assert winrate_history == [
+                {"eval_win_rate": 0.75, "epoch": 0.5, "step": 2},
+                {"eval_win_rate": 0.75, "epoch": 1.0, "step": 4},
+                {"eval_win_rate": 0.75, "epoch": 1.5, "step": 6},
+                {"eval_win_rate": 0.75, "epoch": 2.0, "step": 8},
+                {"eval_win_rate": 0.75, "epoch": 2.5, "step": 10},
+                {"eval_win_rate": 0.75, "epoch": 3.0, "step": 12},
+            ]
22 changes: 11 additions & 11 deletions trl/trainer/callbacks.py
@@ -158,9 +158,11 @@ class WinRateCallback(TrainerCallback):
     """
     A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference.
 
-    It uses prompts from the evaluation dataset to generate completions for both the model and the reference model.
-    At every evaluation step, it compares the completions generated by the model and the reference model using a judge
-    and computes the win rate. This win rate is then logged to the trainer under the key `"eval_win_rate"`.
+    It generates completions using prompts from the evaluation dataset and compares the trained model's outputs
+    against a reference. The reference is either the initial version of the model (before training) or the reference
+    model, if available in the trainer. During each evaluation step, a judge determines how often the trained model's
+    completions win against the reference. The win rate is then logged in the trainer's logs under the key
+    `"eval_win_rate"`.
 
     Usage:
     ```python
@@ -173,11 +175,10 @@ class WinRateCallback(TrainerCallback):
         judge (`BasePairwiseJudge`):
             The judge to use for comparing completions.
         trainer (`Trainer`):
-            The trainer. The trainer must comply with the following requirements:
-
-            - its evaluation dataset must have a column `"prompt"` that contains the prompts to generate completions for.
-            - it must have an attribute `ref_model` that contains the reference model.
+            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
+            column containing the prompts for generating completions. If the `Trainer` has a reference model (via the
+            `ref_model` attribute), it will use this reference model for generating the reference completions;
+            otherwise, it defaults to using the initial model.
         generation_config (`GenerationConfig`, *optional*):
             The generation config to use for generating completions.
         batch_size (`int`, *optional*):
@@ -196,8 +197,6 @@ def __init__(
         self.ref_completions = []
         self.trainer = trainer
         self.eval_dataset = self.trainer.eval_dataset
-        if not hasattr(trainer, "ref_model"):
-            raise AttributeError("Trainer must have a `ref_model` attribute.")
         self.batch_size = batch_size
 
     def generate_completions_for_model(self, model, tokenizer, prompts):
@@ -225,8 +224,9 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
         tokenizer = kwargs["tokenizer"]
         tokenizer.padding_side = "left"
         accelerator = self.trainer.accelerator
+        model = getattr(self.trainer, "ref_model", kwargs["model"])  # get the ref model if any, else use the model
         with accelerator.split_between_processes(self.eval_dataset["prompt"], apply_padding=True) as prompts:
-            self.ref_completions = self.generate_completions_for_model(self.trainer.ref_model, tokenizer, prompts)
+            self.ref_completions = self.generate_completions_for_model(model, tokenizer, prompts)
 
     def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         # At every evaluation step, we generate completions for the model and compare them with the reference
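A note on the judge contract these tests exercise: `judge` receives the prompts plus, for each prompt, a pair of candidate completions, and returns the index of the preferred completion per prompt. Given the 0.75 win rate asserted in the tests, an index of 1 appears to count as a win for the trained model. A hypothetical minimal judge, for illustration only (`CoinFlipPairwiseJudge` is an invented name, not part of this commit):

```python
import random

from trl import BasePairwiseJudge


class CoinFlipPairwiseJudge(BasePairwiseJudge):
    """Illustrative judge that picks a winner uniformly at random.

    `completions[i]` holds the two candidate completions for `prompts[i]`;
    the returned list gives, per prompt, the index of the preferred one.
    Plugged into WinRateCallback, this would log a win rate near 0.5.
    """

    def judge(self, prompts, completions, shuffle_order=True):
        return [random.randint(0, 1) for _ in prompts]
```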
