⚖️ Add use_soft_judge option to WinRateCallback (#2347)
* add `use_soft_judge` option to WinRateCallback

* formatting

* Update trl/trainer/callbacks.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* renamed soft_win_rate to avg_win_prob

* Update trl/trainer/callbacks.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* fix tests

* keep original

* formatting

* Update tests/test_callbacks.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update trl/trainer/callbacks.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update tests/test_callbacks.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update tests/test_callbacks.py

* fix test

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
kashif and qgallouedec authored Nov 15, 2024
1 parent 6239631 commit b8c9d9c
Showing 2 changed files with 75 additions and 6 deletions.
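
For orientation before the diff: a minimal usage sketch of the new option, based only on the signatures visible in this commit. The helper name `attach_win_rate_logging` and the generation settings are made up for illustration; the trainer and judge are whatever you already have (the judge must be a `BasePairwiseJudge` whose `judge()` accepts `return_scores=True`, as in the test below).

```python
# Illustrative sketch only -- not part of this commit.
from transformers import GenerationConfig
from trl import WinRateCallback


def attach_win_rate_logging(trainer, judge, max_new_tokens=32):
    """Attach WinRateCallback with the new soft-judge option to an existing trainer."""
    generation_config = GenerationConfig(max_new_tokens=max_new_tokens)  # placeholder settings
    callback = WinRateCallback(
        judge=judge,                  # a BasePairwiseJudge whose judge() supports return_scores=True
        trainer=trainer,              # must have an eval_dataset, as the callback requires
        generation_config=generation_config,
        use_soft_judge=True,          # logs eval_avg_win_prob alongside the usual eval_win_rate
    )
    trainer.add_callback(callback)
    return callback
```

With `use_soft_judge=False` (the default) the callback behaves exactly as before and logs only `eval_win_rate`.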
49 changes: 47 additions & 2 deletions tests/test_callbacks.py
@@ -30,11 +30,13 @@


class HalfPairwiseJudge(BasePairwiseJudge):
"""Naive pairwise judge that always returns [1, 0]"""
"""Naive pairwise judge that always returns [1, 0] for two prompts"""

def judge(self, prompts, completions, shuffle_order=True):
def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
# just check that the batch size is 2
assert len(prompts) == 2
if return_scores:
return [0.3, 0.9]
return [1, 0]


@@ -132,6 +134,49 @@ def test_without_ref_model(self):
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
self.assertListEqual(winrate_history, self.expected_winrates)

def test_soft_judge(self):
"""Test that the soft judge functionality works correctly"""
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
eval_strategy="steps",
eval_steps=2, # evaluate every 2 steps
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
per_device_eval_batch_size=2,
report_to="none",
)
trainer = TrainerWithRefModel(
model=self.model,
ref_model=self.ref_model,
args=training_args,
train_dataset=self.dataset["train"],
eval_dataset=self.dataset["test"],
processing_class=self.tokenizer,
)
win_rate_callback = WinRateCallback(
judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True
)
trainer.add_callback(win_rate_callback)
trainer.train()

# Expected values based on judge returning [0.3, 0.9] for each pair
expected_soft_winrates = [
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
]

winrate_history = [
{k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]}
for h in trainer.state.log_history
if "eval_avg_win_prob" in h
]
self.assertListEqual(winrate_history, expected_soft_winrates)

@require_peft
def test_lora(self):
with tempfile.TemporaryDirectory() as tmp_dir:
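
The expected values in `test_soft_judge` above follow directly from the stub judge: the scores `[0.3, 0.9]` are read as win probabilities of the first (reference) completion, so the model's average win probability is `1 - mean(scores) = 0.4`, and thresholding each score at 0.5 yields winner indices `[1, 0]`, i.e. a hard win rate of 0.5. A standalone check of that arithmetic (illustrative only, mirroring the callback's formulas in the second file):

```python
# Reproduce the expected metrics of test_soft_judge by hand.
ref_win_probs = [0.3, 0.9]  # judge's probability that the reference completion wins, per prompt

# Soft metric: average probability that the trained model's completion wins.
avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs)  # ~0.4

# Hard metric: the model (index 1) wins whenever the reference's probability is <= 0.5.
winner_indices = [0 if p > 0.5 else 1 for p in ref_win_probs]  # [1, 0]
win_rate = sum(idx == 1 for idx in winner_indices) / len(winner_indices)  # 0.5

assert abs(avg_win_prob - 0.4) < 1e-9 and win_rate == 0.5
```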
32 changes: 28 additions & 4 deletions trl/trainer/callbacks.py
@@ -230,6 +230,9 @@ class WinRateCallback(TrainerCallback):
in the evaluation dataset.
shuffle_order (`bool`, *optional*, defaults to `True`):
Whether to shuffle the order of the completions before judging.
use_soft_judge (`bool`, *optional*, defaults to `False`):
Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the
second.
"""

def __init__(
@@ -239,12 +242,14 @@ def __init__(
generation_config: Optional[GenerationConfig] = None,
num_prompts: Optional[int] = None,
shuffle_order: bool = True,
use_soft_judge: bool = False,
):
self.judge = judge
self.trainer = trainer
self.shuffle_order = shuffle_order
self.generation_config = generation_config
self.ref_completions = []
self.use_soft_judge = use_soft_judge

if self.trainer.eval_dataset is None:
raise ValueError("Trainer must have an evaluation dataset to use the WinRateCallback.")
@@ -281,15 +286,24 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
)
# Compute initial win rate as a reference point
completions = list(zip(self.ref_completions, self.ref_completions))
winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
if self.use_soft_judge:
ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
ref_win_probs = gather_object(ref_win_probs)
else:
winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
prompts = gather_object(prompts)
completions = gather_object(completions)
winner_indices = gather_object(winner_indices)

# Logging
if self.trainer.accelerator.is_main_process:
win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices)
self.trainer.log({"eval_win_rate": win_rate})
if self.use_soft_judge:
avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs)
self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate})
else:
self.trainer.log({"eval_win_rate": win_rate})

if "wandb" in args.report_to:
import wandb
@@ -323,15 +337,25 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
)

completions = list(zip(self.ref_completions, completions))
winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)

if self.use_soft_judge:
ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
ref_win_probs = gather_object(ref_win_probs)
else:
winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
prompts = gather_object(prompts)
completions = gather_object(completions)
winner_indices = gather_object(winner_indices)

# Logging
if self.trainer.accelerator.is_main_process:
win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices)
self.trainer.log({"eval_win_rate": win_rate})
if self.use_soft_judge:
avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs)
self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate})
else:
self.trainer.log({"eval_win_rate": win_rate})

if "wandb" in args.report_to:
import wandb
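
The only contract the soft path relies on, as the calls above show, is that `judge()` accept a `return_scores` flag and return, per pair, the probability that the first completion (here the reference) wins. A minimal custom-judge sketch, modeled on the test's `HalfPairwiseJudge`; the constant probability is a stand-in for a real scoring model:

```python
# Illustrative sketch of a judge usable with use_soft_judge=True (not part of this commit).
from trl import BasePairwiseJudge


class ConstantSoftJudge(BasePairwiseJudge):
    """Toy judge: claims the first (reference) completion wins with probability 0.7."""

    def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
        # completions[i] is a (reference_completion, model_completion) pair for prompts[i].
        scores = [0.7 for _ in prompts]  # placeholder win probabilities for the first completion
        if return_scores:
            return scores
        # Otherwise fall back to hard winner indices, as the existing judges do.
        return [0 if s > 0.5 else 1 for s in scores]
```

Note that when `use_soft_judge=True` the callback derives `winner_indices` from the scores itself, so only the `return_scores=True` branch is exercised in that mode.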
