Use processing_class instead of tokenizer in `LogCompletionsCallback` (#2261)
qgallouedec authored Oct 22, 2024
1 parent 84dab85 commit d843b3d
Showing 2 changed files with 72 additions and 10 deletions.
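For context, here is a minimal sketch of the updated call pattern (adapted from the test added below; the dummy model and dataset are TRL's test fixtures, and the `processing_class` argument assumes a transformers release that supports it in place of `tokenizer`):

import tempfile

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
from trl import LogCompletionsCallback

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

def tokenize_function(examples):
    # Tokenize the prompts and reuse the input ids as labels for a plain LM loss
    out = tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
    out["labels"] = out["input_ids"].copy()
    return out

dataset = dataset.map(tokenize_function, batched=True)

with tempfile.TemporaryDirectory() as tmp_dir:
    args = TrainingArguments(output_dir=tmp_dir, eval_strategy="steps", eval_steps=2, report_to="wandb")
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,  # previously passed as tokenizer=tokenizer
    )
    trainer.add_callback(LogCompletionsCallback(trainer, GenerationConfig(max_length=32), num_prompts=2))
    trainer.train()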
76 changes: 69 additions & 7 deletions tests/test_callbacks.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+import os
 import tempfile
 import unittest
 
@@ -20,7 +22,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
 from transformers.testing_utils import require_wandb
 
-from trl import BasePairwiseJudge, WinRateCallback
+from trl import BasePairwiseJudge, LogCompletionsCallback, WinRateCallback
 
 
 class HalfPairwiseJudge(BasePairwiseJudge):
@@ -35,14 +37,17 @@ def judge(self, prompts, completions, shuffle_order=True):
 class TrainerWithRefModel(Trainer):
     # This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional
     # ref_model attribute
-    def __init__(self, model, ref_model, args, train_dataset, eval_dataset, tokenizer):
+    def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processing_class):
         super().__init__(
-            model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer
+            model=model,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
         )
         self.ref_model = ref_model
 
 
 @require_wandb
 class WinRateCallbackTester(unittest.TestCase):
     def setUp(self):
         self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
@@ -52,6 +57,7 @@ def setUp(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
         dataset["train"] = dataset["train"].select(range(8))
         self.expected_winrates = [
+            {"eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
             {"eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
             {"eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
             {"eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
@@ -86,7 +92,7 @@ def test_basic(self):
                 args=training_args,
                 train_dataset=self.dataset["train"],
                 eval_dataset=self.dataset["test"],
-                tokenizer=self.tokenizer,
+                processing_class=self.tokenizer,
             )
             win_rate_callback = WinRateCallback(
                 judge=self.judge, trainer=trainer, generation_config=self.generation_config
@@ -112,7 +118,7 @@ def test_without_ref_model(self):
                 args=training_args,
                 train_dataset=self.dataset["train"],
                 eval_dataset=self.dataset["test"],
-                tokenizer=self.tokenizer,
+                processing_class=self.tokenizer,
             )
             win_rate_callback = WinRateCallback(
                 judge=self.judge, trainer=trainer, generation_config=self.generation_config
@@ -145,7 +151,7 @@ def test_lora(self):
                 args=training_args,
                 train_dataset=self.dataset["train"],
                 eval_dataset=self.dataset["test"],
-                tokenizer=self.tokenizer,
+                processing_class=self.tokenizer,
             )
             win_rate_callback = WinRateCallback(
                 judge=self.judge, trainer=trainer, generation_config=self.generation_config
@@ -154,3 +160,59 @@ def test_lora(self):
             trainer.train()
             winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
             self.assertListEqual(winrate_history, self.expected_winrates)


+@require_wandb
+class LogCompletionsCallbackTester(unittest.TestCase):
+    def setUp(self):
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
+        dataset["train"] = dataset["train"].select(range(8))
+
+        def tokenize_function(examples):
+            out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
+            out["labels"] = out["input_ids"].copy()
+            return out
+
+        self.dataset = dataset.map(tokenize_function, batched=True)
+
+        self.generation_config = GenerationConfig(max_length=32)
+
+    def test_basic(self):
+        import wandb
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = TrainingArguments(
+                output_dir=tmp_dir,
+                eval_strategy="steps",
+                eval_steps=2,  # evaluate every 2 steps
+                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
+                per_device_eval_batch_size=2,
+                report_to="wandb",
+            )
+            trainer = Trainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dataset["train"],
+                eval_dataset=self.dataset["test"],
+                processing_class=self.tokenizer,
+            )
+            completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
+            trainer.add_callback(completions_callback)
+            trainer.train()
+
+            # Get the current run
+            completions_path = wandb.run.summary.completions["path"]
+            json_path = os.path.join(wandb.run.dir, completions_path)
+            with open(json_path) as f:
+                completions = json.load(f)
+
+            # Check that the columns are correct
+            self.assertIn("step", completions["columns"])
+            self.assertIn("prompt", completions["columns"])
+            self.assertIn("completion", completions["columns"])
+
+            # Check that the prompt is in the log
+            self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0])
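Note: the completions are logged through Weights & Biases; a wandb table is materialized under the run directory as a JSON payload with "columns" and "data" keys, which is exactly what the assertions above inspect.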
6 changes: 3 additions & 3 deletions trl/trainer/callbacks.py
@@ -255,7 +255,7 @@ def __init__(
 
     def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         # When the trainer is initialized, we generate completions for the reference model.
-        tokenizer = kwargs["tokenizer"]
+        tokenizer = kwargs["processing_class"]
         tokenizer.padding_side = "left"
         accelerator = self.trainer.accelerator
         # Use the reference model if available, otherwise use the initial model
@@ -307,7 +307,7 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
         # At every evaluation step, we generate completions for the model and compare them with the reference
         # completions that have been generated at the beginning of training. We then compute the win rate and log it to
         # the trainer.
-        tokenizer = kwargs["tokenizer"]
+        tokenizer = kwargs["processing_class"]
         tokenizer.padding_side = "left"
         accelerator = self.trainer.accelerator
         model = self.trainer.model_wrapped
@@ -401,7 +401,7 @@ def on_step_end(self, args, state, control, **kwargs):
         if state.global_step % freq != 0:
             return
 
-        tokenizer = kwargs["tokenizer"]
+        tokenizer = kwargs["processing_class"]
         tokenizer.padding_side = "left"
         accelerator = self.trainer.accelerator
         model = self.trainer.model_wrapped
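Downstream note: any custom TrainerCallback that read the tokenizer from the callback kwargs needs the same key change. A minimal sketch (the class name is hypothetical; assumes a transformers version that forwards processing_class to callbacks):

from transformers import TrainerCallback

class MyGenerationCallback(TrainerCallback):
    def on_evaluate(self, args, state, control, **kwargs):
        # was: tokenizer = kwargs["tokenizer"]
        tokenizer = kwargs["processing_class"]
        # left-pad prompts before batched generation, as the TRL callbacks do
        tokenizer.padding_side = "left"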
