From 179ba5367181d9bd4bdaec70d50789b09754d04a Mon Sep 17 00:00:00 2001 From: Gaetan LOPEZ LATOUCHE <66413927+gaetanlop@users.noreply.github.com> Date: Fri, 13 Dec 2024 09:56:10 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=BE=20Process-supervised=20RM=20Traine?= =?UTF-8?q?r=20(#2127)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial skeleton * tokenize fn * adding bos and eos to tokenization fn * prmtrainer * fixing small typo in tokenize * typo in input_ids and labels construction * numpy dimension * introduce the stepwise reward trainer * update markdown files * let user decide post step separator in config * doc post_step_separator * do not add post step_tokens to last step of the reasoning process * renaming prm to stepwisereward * formatting * fix tokenize kwargs * adapt test to the new post_token args * adding example script * fix small typo * add create_model_card and renaming * fixing booleans * Adding the new stepwise_preference instead of placeholders for datasets * formatting * Update docs/source/_toctree.yml Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update examples/scripts/stepwise_reward_modeling.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update trl/trainer/stepwise_reward_trainer.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update trl/trainer/stepwise_reward_trainer.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * update push to hub Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * step_separator can't be None Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * fix suggested typos * add citation * reformat doc * reordering init * push to hub prm800k * changing dataset in example * change dataset format to align with the sky is blue example * fix tokenization column names * fix num labels in openai example * add support for conversational dataset * remove training whitespace * replace tokenizer with processing class * Update docs/source/dataset_formats.mdx Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * remove openai_prm800k * Update trl/trainer/stepwise_reward_trainer.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update trl/trainer/stepwise_reward_trainer.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update docs/source/stepwise_reward_trainer.mdx Co-authored-by: lewtun * Update docs/source/stepwise_reward_trainer.mdx Co-authored-by: lewtun * renaming Co-authored-by: lewtun * renaming Co-authored-by: lewtun * minor renamings in docs * using prm800k instead of openai_prm800k * update num labels to 2 following the new format * changing doc examples to math examples * change reference to dataset_formats.mdx * changing dataset config in test * remove conversational dataset support * remove conv dataset support * fix bos token * fix scriptarguments in example * completion to completions * remove valuerror for step_separator inside steps * run precommit * remove conv dataset support Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * renaming zen dataset * remove unused printing * unknown label column * introduce the train on last step arg * _tokenize support train_on_last_step * incorporate train_on_last_step to tests * formatting * remove comments in trainer * Refactor `tokenize_row` * Update max_completion_length parameter in StepwiseRewardConfig * Collator * Update comment * Update type hint * fix table * Remove collator * don't need pad token id * add error back * max length args * use tokenizer arg * Update doc * label -> labels * fixing tokenization issues in tokenize row * correct labels for token classification * adding max_length to tokenize_row * reformat tests * adding tests for tokenize row * fixing typos in comments * update doc Co-authored-by: Kashif Rasul * Add math_shepherd.py script for dataset processing * split the dataset * formatting * same evaluation method for the two training methods * adding filtering to example script * formatting * Add features to avoid casting labels to bool in dataset tokenization * Update docs/source/stepwise_reward_trainer.mdx [ci skip] * Add learning_rate parameter to StepwiseRewardConfig class * update doc * Remove unused setup_chat_format function * Fix warning message in stepwise_reward_modeling.py * Update logging steps in stepwise_reward_trainer.mdx * little doc change [ci skip] * Fix copyrights * fix space after copyrights * Update dataset loading in stepwise_reward_modeling.py * refine compute_accuracy and proper test * fix tests * style * renamings * renaming in init * doc renaming * fix sorting and tag * experiemental [ci skip] * trigger CI * other doc fix --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Kashif Rasul Co-authored-by: lewtun Co-authored-by: Quentin Gallouédec --- docs/source/_toctree.yml | 2 + docs/source/dataset_formats.mdx | 1 + docs/source/prm_trainer.mdx | 123 +++++++++++ examples/datasets/math_shepherd.py | 131 ++++++++++++ examples/scripts/prm.py | 130 ++++++++++++ tests/test_judges.py | 1 - tests/test_prm_trainer.py | 329 ++++++++++++++++++++++++++++ tests/test_reward_trainer.py | 8 +- tests/test_utils.py | 71 ++++++- trl/__init__.py | 4 + trl/trainer/__init__.py | 4 + trl/trainer/prm_config.py | 51 +++++ trl/trainer/prm_trainer.py | 330 +++++++++++++++++++++++++++++ trl/trainer/utils.py | 41 +++- 14 files changed, 1207 insertions(+), 19 deletions(-) create mode 100644 docs/source/prm_trainer.mdx create mode 100644 examples/datasets/math_shepherd.py create mode 100644 examples/scripts/prm.py create mode 100644 tests/test_prm_trainer.py create mode 100644 trl/trainer/prm_config.py create mode 100644 trl/trainer/prm_trainer.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 3c5e1efd86..66fdc6b57b 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -42,6 +42,8 @@ title: ORPO - local: ppo_trainer title: PPO + - local: prm_trainer + title: PRM - local: reward_trainer title: Reward - local: rloo_trainer diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx index a6dbf11d04..c8ab321506 100644 --- a/docs/source/dataset_formats.mdx +++ b/docs/source/dataset_formats.mdx @@ -266,6 +266,7 @@ Choosing the right dataset type depends on the task you are working on and the s | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | | [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`PPOTrainer`] | Tokenized language modeling | +| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) | | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | | [`SFTTrainer`] | [Language modeling](#language-modeling) | | [`XPOTrainer`] | [Prompt-only](#prompt-only) | diff --git a/docs/source/prm_trainer.mdx b/docs/source/prm_trainer.mdx new file mode 100644 index 0000000000..012b8ec071 --- /dev/null +++ b/docs/source/prm_trainer.mdx @@ -0,0 +1,123 @@ +# PRM Trainer + + + +PRM Trainer is an experimental API which is subject to change at any time. + + + +## Overview + +Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins. + +The abstract from the paper is the following: + +> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use processbased supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions. + +This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss). + + +## Quick start + +This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model. We use the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_prm.py +from datasets import load_dataset +from trl import PRMConfig, PRMTrainer +from transformers import AutoModelForTokenClassification, AutoTokenizer + +model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B") +train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]") + +training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10) +trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_prm.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 hour. + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script. + + +```python +from datasets import load_dataset +from transformers import pipeline + +pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd") +dataset = load_dataset("trl-lib/math_shepherd") +example = { + "prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?", + "completions": [ + "Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.", + "Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.", + "Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20", + ], + "labels": [True, False, False], +} + + +separator = "\n" # It's important to use the same separator as the one used during training + +for idx in range(1, len(example["completions"]) + 1): + steps = example["completions"][0:idx] + text = separator.join((example["prompt"], *steps)) + separator # Add a separator between the prompt and each steps + pred_entity = pipe(text)[-1]["entity"] + pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity] + label = example["labels"][idx - 1] + print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}") +``` + +```text +Step 1 Predicted: True Label: True +Step 2 Predicted: False Label: False +Step 3 Predicted: False Label: False +``` + +It's a win! + +## Expected dataset type + +PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision). +The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step. + +The [`PRMTrainer`] only supports [standard](dataset_formats#standard) dataset format. + +## Example script + +We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) + +To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command: + +```bash +accelerate launch examples/scripts/prm.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/math_shepherd \ + --num_train_epochs 1 \ + --logging_steps 25 \ + --output_dir Qwen2-0.5B-Reward-Math-Sheperd +``` + +## PRMTrainer + +[[autodoc]] PRMTrainer + +## PRMConfig + +[[autodoc]] PRMConfig diff --git a/examples/datasets/math_shepherd.py b/examples/datasets/math_shepherd.py new file mode 100644 index 0000000000..c09e745ad5 --- /dev/null +++ b/examples/datasets/math_shepherd.py @@ -0,0 +1,131 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from dataclasses import dataclass +from itertools import chain +from typing import Optional + +from datasets import load_dataset +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = False + repo_id: str = "trl-lib/math_shepherd" + dataset_num_proc: Optional[int] = None + + +def process_example(example): + # Replace "ки" with "ⶻ" so that the size of the "input" matches the size of the "label" + inputs = example["input"].replace("ки", "ⶻ") + + # Find the indices of the "ⶻ" characters (that should match with the indexes of the "+" or "-" in the label) + indexes = [m.start() for m in re.finditer("ⶻ", inputs)] + + # Sanity that all indexes are either "+" or "-" + assert all(example["label"][idx] in ["+", "-"] for idx in indexes) + + # Get the labels + labels = [example["label"][idx] == "+" for idx in indexes] + + # Split the inputs into steps (caution, the first step is missing here, it is the prompt) + steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))] + + # Remove the last step (single ⶻ) + steps = steps[:-1] + + # Get the prompt (first part) and completions (rest) + prompt = steps[0] + completions = steps[1:] + + # Remove the heading "ⶻ" and the final whitespace from the completions + assert all(completion.startswith("ⶻ") for completion in completions) + completions = [completion[1:].strip() for completion in completions] + + # At this point, we need to retrieve the first step from the prompt. + # First, we handle particular cases (annotation error) where we have a first label before the end of the prompt. + if prompt.startswith( + ( + "Mr. Rocky", + "Parker", + "What is the smallest positive", + " The Myth", + "Let $\\mathbf{a}$", + "Find the arithmetic", + "Determine an ordered pair", + "Determine the ordered pair", + "At the Quill and Scroll stationery", + "Round to the nearest", + r"Calculate $\sqrt{10p}", + r"Simplify $\sqrt{28x}", + ) + ): + # Some spotted datasets errors where there is an annotation in the prompt: we remove it + labels = labels[1:] + + # Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or + # (less common) "?". + elif "Step 1:" in prompt: + prompt, first_step = prompt.split("Step 1:") + first_step = "Step 1:" + first_step + completions = [first_step.strip()] + completions + elif "step 1:" in prompt: + prompt, first_step = prompt.split("step 1:") + first_step = "step 1:" + first_step + completions = [first_step.strip()] + completions + elif "?" in prompt: + prompt, first_step = prompt.split("?") + prompt = prompt + "?" + completions = [first_step.strip()] + completions + else: + raise ValueError(f"Prompt can't be processed: {prompt}") + + # Strip the prompt + prompt = prompt.strip() + + # Sanity check that the length of the completions is the same as the length of the labels + assert len(completions) == len(labels) + + return {"prompt": prompt, "completions": completions, "labels": labels} + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset("peiyi9979/Math-Shepherd", split="train") + + dataset = dataset.map( + process_example, + remove_columns=["input", "label", "task"], + num_proc=script_args.dataset_num_proc, + ) + dataset = dataset.train_test_split(test_size=0.05, seed=42) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) diff --git a/examples/scripts/prm.py b/examples/scripts/prm.py new file mode 100644 index 0000000000..ba7f9ce415 --- /dev/null +++ b/examples/scripts/prm.py @@ -0,0 +1,130 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Full training: +python examples/scripts/prm.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/prm800k \ + --output_dir Qwen2-0.5B-Reward \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --gradient_checkpointing True \ + --learning_rate 1.0e-5 \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 + +LoRA: +python examples/scripts/prm.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/prm800k \ + --output_dir Qwen2-0.5B-Reward-LoRA \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --gradient_checkpointing True \ + --learning_rate 1.0e-4 \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +""" + +import warnings + +import torch +from datasets import load_dataset +from transformers import AutoModelForTokenClassification, AutoTokenizer, HfArgumentParser + +from trl import ( + ModelConfig, + PRMConfig, + PRMTrainer, + ScriptArguments, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, PRMConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_into_dataclasses() + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + + ################ + # Model & Tokenizer + ################ + torch_dtype = ( + model_config.torch_dtype + if model_config.torch_dtype in ["auto", None] + else getattr(torch, model_config.torch_dtype) + ) + quantization_config = get_quantization_config(model_config) + model_kwargs = dict( + revision=model_config.model_revision, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + use_cache=False if training_args.gradient_checkpointing else True, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True + ) + model = AutoModelForTokenClassification.from_pretrained( + model_config.model_name_or_path, num_labels=2, trust_remote_code=model_config.trust_remote_code, **model_kwargs + ) + # Align padding tokens between tokenizer and model + model.config.pad_token_id = tokenizer.pad_token_id + + if model_config.use_peft and model_config.lora_task_type != "TOKEN_CLS": + warnings.warn( + "You are using a `task_type` that is different than `TOKEN_CLS` for PEFT. This will lead to silent bugs" + " Make sure to pass --lora_task_type TOKEN_CLS when using this script with PEFT.", + UserWarning, + ) + + ############## + # Load dataset + ############## + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + dataset = dataset.filter(lambda x: len(x["completions"]) > 0) + + ########## + # Training + ########## + trainer = PRMTrainer( + model=model, + processing_class=tokenizer, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], + peft_config=get_peft_config(model_config), + ) + trainer.train() + + ############################ + # Save model and push to Hub + ############################ + trainer.save_model(training_args.output_dir) + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/tests/test_judges.py b/tests/test_judges.py index 4789d3bb3a..0f8b83d881 100644 --- a/tests/test_judges.py +++ b/tests/test_judges.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import time import unittest diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py new file mode 100644 index 0000000000..4f2c1c21c1 --- /dev/null +++ b/tests/test_prm_trainer.py @@ -0,0 +1,329 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +from unittest.mock import MagicMock + +import torch +from datasets import Dataset, load_dataset +from parameterized import parameterized +from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerBase +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import PRMConfig, PRMTrainer + + +if is_peft_available(): + from peft import LoraConfig, TaskType + + +class TestTokenizeRow(unittest.TestCase): + def setUp(self): + # Set up the mock tokenizer with specific behaviors + self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + self.tokenizer.bos_token_id = 0 + self.tokenizer.eos_token_id = 2 + + def mock_encode(text, add_special_tokens): + token_map = { + "Which number is larger, 9.8 or 9.11?": [465, 6766, 318, 298], + "11 is greater than 8.": [4, 322, 12], + "Hence, 9.11 > 9.8.": [4995, 11, 22], + "\n": [1030], + "\n\n": [1030, 1030], + } + + return token_map[text] + + def mock_tokenizer_call(text, add_special_tokens): + return {"input_ids": mock_encode(text, add_special_tokens)} + + self.tokenizer.encode.side_effect = mock_encode + self.tokenizer.side_effect = mock_tokenizer_call + + def test_tokenize_row_no_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with no truncation + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], + }, + ) + + def test_tokenize_row_train_on_last_step_only(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_completion_length=None, + train_on_last_step_only=True, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + }, + ) + + def test_tokenize_row_completion_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with truncation on the completion + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_completion_length=6, + train_on_last_step_only=False, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100], + }, + ) + + def test_tokenize_row_prompt_completion_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with truncation on the prompt and completion + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=9, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1], + }, + ) + + def test_tokenize_row_multi_token_separator(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method using multiple tokens as step_separator + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n\n", + max_length=None, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0], + }, + ) + + +class PRMTrainerTester(unittest.TestCase): + def setUp(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForTokenClassification.from_pretrained(model_id) + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + + @parameterized.expand([True, False]) + def test_train_full(self, train_on_last_step_only): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") + training_args = PRMConfig( + output_dir=tmp_dir, + report_to="none", + train_on_last_step_only=train_on_last_step_only, + ) + trainer = PRMTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + def test_train_full_pretokenized(self): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = Dataset.from_dict( + { + "labels": [ + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, 1, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, 0, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + ], + "input_ids": [ + [46518, 374, 2664, 1091, 11, 1077, 752, 1744, 1112, 198, 27261, 13, 198], + [98923, 374, 2664, 1091, 11, 315, 3308, 11, 198, 17995, 13, 198, 1576, 31273, 12850, 13, 198], + [16374, 374, 2664, 1091, 1112, 1077, 594, 2506, 432, 6770, 11, 198, 6351, 13, 198], + [31137, 374, 2664, 1091, 979, 4362, 11, 198, 16965, 13, 198], + [31019, 374, 2664, 1091, 304, 3793, 315, 5944, 11, 198, 24034, 13, 198], + [98491, 374, 2664, 1091, 1112, 5310, 369, 91494, 13, 198], + [4418, 2897, 14579, 5310, 979, 3800, 1349, 432, 13, 198], + [20366, 5048, 7629, 944, 3281, 3322, 11, 7241, 1112, 198, 807, 1795, 279, 5601, 13, 198], + [15802, 14976, 487, 33327, 1045, 31787, 63443, 11, 198, 52400, 13, 198], + [13877, 1265, 2581, 1494, 49394, 11, 198, 7241, 20975, 91681, 13, 198], + [641, 279, 3579, 315, 71768, 11, 25066, 279, 61361, 311, 7942, 13, 198], + [7039, 374, 2664, 1091, 2937, 13, 198], + [26155, 374, 3545, 2664, 1091, 34933, 26537, 13, 198], + [2679, 279, 8129, 374, 4135, 311, 10339, 11, 432, 2578, 387, 264, 1661, 2884, 13, 198], + ], + } + ) + + training_args = PRMConfig(output_dir=tmp_dir, report_to="none") + trainer = PRMTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + @require_peft + def test_train_lora(self): + peft_config = LoraConfig( + task_type=TaskType.TOKEN_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + ) + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") + training_args = PRMConfig(output_dir=tmp_dir, max_steps=3, report_to="none") + trainer = PRMTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + peft_config=peft_config, + ) + previous_trainable_params = {} + previous_non_trainable_params = {} + + # due to a change in the way the modules to save are dealt in PEFT. + trainable_params_name = ["lora", "modules_to_save"] + + # check gradients are not None + for n, param in trainer.model.named_parameters(): + if any(t in n for t in trainable_params_name): + previous_trainable_params[n] = param.clone() + else: + previous_non_trainable_params[n] = param.clone() + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + # check the non trainable params have not changed + for n, param in previous_non_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + def test_tags(self): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") + training_args = PRMConfig(output_dir=tmp_dir, report_to="none") + trainer = PRMTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + self.assertEqual(trainer.model.model_tags, trainer._tag_names) diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index d4466a0404..d977719765 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -17,12 +17,11 @@ import torch from datasets import Dataset, load_dataset -from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction +from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template -from trl.trainer import compute_accuracy from trl.trainer.reward_trainer import _tokenize @@ -37,11 +36,6 @@ def setUp(self): self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id) self.model.config.pad_token_id = self.tokenizer.pad_token_id - def test_accuracy_metrics(self): - dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0])) - accuracy = compute_accuracy(dummy_eval_predictions) - self.assertEqual(accuracy["accuracy"], 0.5) - def test_preprocessing_conversational(self): with tempfile.TemporaryDirectory() as tmp_dir: dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") diff --git a/tests/test_utils.py b/tests/test_utils.py index 210eae4306..a1cabcfc19 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,13 +14,15 @@ import unittest +import numpy as np import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers.testing_utils import require_peft from transformers.utils import is_peft_available -from trl.trainer.model_config import ModelConfig +from trl import ModelConfig +from trl.trainer import compute_accuracy from trl.trainer.utils import ( DataCollatorForChatML, batch_generation, @@ -332,3 +334,70 @@ def test_single_batch_generation(self): self.assertGreater(max_length_query, context_length) self.assertEqual(query_responses.shape, (bs, max_length_query)) self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size)) + + +class TestComputeAccuracy(unittest.TestCase): + def test_token_classification_task(self): + eval_pred = ( + np.array( + [ + [[0.1, 0.9], [0.8, 0.2]], # Batch 1 + [[0.3, 0.7], [0.6, 0.4]], # Batch 2 + ] + ), + np.array([[0, 1], [1, 0]]), + ) + expected_accuracy = 0.5 # 2 matches, 2 mismatches + result = compute_accuracy(eval_pred) + self.assertAlmostEqual(result["accuracy"], expected_accuracy) + + def test_token_classification_task_with_ignored_tokens_0(self): + eval_pred = ( + np.array( + [ + [[0.1, 0.9], [0.8, 0.2]], # Batch 1 + [[0.3, 0.7], [0.6, 0.4]], # Batch 2 + ] + ), + np.array([[1, 0], [1, -100]]), + ) + expected_accuracy = 1.0 # All non-ignored tokens match + result = compute_accuracy(eval_pred) + self.assertAlmostEqual(result["accuracy"], expected_accuracy) + + def test_token_classification_task_with_ignored_tokens_1(self): + eval_pred = ( + np.array( + [ + [[0.1, 0.9], [0.8, 0.2]], # Batch 1 + [[0.3, 0.7], [0.6, 0.4]], # Batch 2 + ] + ), + np.array([[1, 1], [0, -100]]), + ) + expected_accuracy = 1 / 3 # 1 match, 2 mismatch, 1 ignored + result = compute_accuracy(eval_pred) + self.assertAlmostEqual(result["accuracy"], expected_accuracy) + + def test_rewards_comparison_task(self): + eval_pred = ( + np.array( + [ + [0.9, 0.1], # Batch 1 + [0.6, 0.4], # Batch 2 + [0.5, 0.5], # Batch 3 (equal) + ] + ), + np.array([0, 1, 1]), + ) + expected_accuracy = 0.5 # 1 match, 1 mismatch, 1 equal (ignored) + + with self.assertWarns(UserWarning) as cm: + result = compute_accuracy(eval_pred) + + self.assertAlmostEqual(result["accuracy"], expected_accuracy) + expected_warning = ( + "There are 1 out of 3 instances where the predictions for both options are equal. " + "These instances are ignored in the accuracy computation." + ) + self.assertEqual(str(cm.warning), expected_warning) diff --git a/trl/__init__.py b/trl/__init__.py index 8976eb2603..b05f75cd8b 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -82,6 +82,8 @@ "PairRMJudge", "PPOConfig", "PPOTrainer", + "PRMConfig", + "PRMTrainer", "RewardConfig", "RewardTrainer", "RLOOConfig", @@ -172,6 +174,8 @@ PairRMJudge, PPOConfig, PPOTrainer, + PRMConfig, + PRMTrainer, RewardConfig, RewardTrainer, RLOOConfig, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index e5599756f7..85a2e4d57c 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -62,6 +62,8 @@ "ppo_trainer": ["PPOTrainer"], "ppov2_config": ["PPOv2Config"], "ppov2_trainer": ["PPOv2Trainer"], + "prm_config": ["PRMConfig"], + "prm_trainer": ["PRMTrainer"], "reward_config": ["RewardConfig"], "reward_trainer": ["RewardTrainer", "compute_accuracy"], "rloo_config": ["RLOOConfig"], @@ -130,6 +132,8 @@ from .orpo_trainer import ORPOTrainer from .ppo_config import PPOConfig from .ppo_trainer import PPOTrainer + from .prm_config import PRMConfig + from .prm_trainer import PRMTrainer from .reward_config import RewardConfig from .reward_trainer import RewardTrainer, compute_accuracy from .rloo_config import RLOOConfig diff --git a/trl/trainer/prm_config.py b/trl/trainer/prm_config.py new file mode 100644 index 0000000000..4558084572 --- /dev/null +++ b/trl/trainer/prm_config.py @@ -0,0 +1,51 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +from transformers import TrainingArguments + + +@dataclass +class PRMConfig(TrainingArguments): + r""" + Configuration class for the [`PRMTrainer`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + learning_rate (`float`, *optional*, defaults to `1e-5`): + Initial learning rate for [`AdamW`] optimizer. The default value replaces that of + [`~transformers.TrainingArguments`]. + max_length (`Optional[int]`, *optional*, defaults to `None`): + Maximum length of the sequences (prompt + completion) used for truncation. + max_completion_length (`Optional[int]`, *optional*, defaults to `None`): + Maximum length of the completion used for truncation. The completion is the concatenation of the steps. + step_separator (`str`, *optional*, defaults to `"\n"`): + Separator used to separate each step of the reasoning process. + train_on_last_step_only (`bool`, *optional*, defaults to `False`): + Whether to train only on the last step. + dataset_num_proc (`int`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + """ + + learning_rate: float = 1e-5 + max_length: Optional[int] = None + max_completion_length: Optional[int] = None + step_separator: str = "\n" + train_on_last_step_only: bool = False + dataset_num_proc: Optional[int] = None diff --git a/trl/trainer/prm_trainer.py b/trl/trainer/prm_trainer.py new file mode 100644 index 0000000000..dbb3558d57 --- /dev/null +++ b/trl/trainer/prm_trainer.py @@ -0,0 +1,330 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import textwrap +import warnings +from itertools import chain +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +from accelerate import PartialState +from datasets import Dataset, features +from transformers import ( + BaseImageProcessor, + DataCollator, + DataCollatorForTokenClassification, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + is_wandb_available, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from .prm_config import PRMConfig +from .utils import compute_accuracy, generate_model_card + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_wandb_available(): + import wandb + + +class PRMTrainer(Trainer): + """ + Initialize PRMTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForTokenClassification`. + args (`PRMConfig`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`): + The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + """ + + _tag_names = ["trl", "prm"] + + def __init__( + self, + model: Optional[Union[PreTrainedModel, nn.Module]] = None, + args: Optional[PRMConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional[dict] = None, + ): + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if not isinstance(model, PeftModel): + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False): + _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + warnings.warn( + "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. " + "please update to the latest version of peft to use `gradient_checkpointing_kwargs`." + ) + elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + + model = get_peft_model(model, peft_config) + + if compute_metrics is None: + compute_metrics = compute_accuracy + + if data_collator is None: + if processing_class is None: + raise ValueError( + "A processing_class must be specified when using the default DataCollatorForTokenClassification" + ) + data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length) + + if "input_ids" not in train_dataset.column_names: + with PartialState().local_main_process_first(): + fn_kwargs = { + "tokenizer": processing_class, + "step_separator": args.step_separator, + "max_length": args.max_length, + "max_completion_length": args.max_completion_length, + "train_on_last_step_only": args.train_on_last_step_only, + } + train_fn_kwargs = {**fn_kwargs, "is_eval": False} + train_dataset = train_dataset.map( + self.tokenize_row, + fn_kwargs=train_fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=train_dataset.features, + desc="Tokenizing train dataset", + features=features.Features( # needed to avoid map to cast labels to bool + { + "labels": features.Sequence(features.Value("int64")), + "input_ids": features.Sequence(features.Value("int64")), + } + ), + ) + + eval_fn_kwargs = {**fn_kwargs, "is_eval": True} + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + self.tokenize_row, + fn_kwargs=eval_fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=eval_dataset.features, + desc="Tokenizing eval dataset", + features=features.Features( # needed to avoid map to cast labels to bool + { + "labels": features.Sequence(features.Value("int64")), + "input_ids": features.Sequence(features.Value("int64")), + } + ), + ) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + @staticmethod + def tokenize_row( + features, tokenizer, step_separator, max_length, max_completion_length, train_on_last_step_only, is_eval + ): + r""" + Tokenize a row of the dataset. + + Args: + features (`dict[str, str]`): + Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`. + tokenizer (`PreTrainedTokenizerBase`): + Tokenizer used to process the data. + step_separator (`str`): + Separator between steps in the completion. + max_length (`int` or `None`): + Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated. + max_completion_length (`int` or `None`): + Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. + train_on_last_step_only (`bool`): + Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last + token of the completion. + is_eval (`bool`): + Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`. + + Returns: + `dict[str, list[int]]`: + Tokenized sequences with the keys `"input_ids"`, and `"labels". + + Example: + ```python + >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") + >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?", + ... "completions": ["11 is greater than 8.", + ... "Hence, 9.11 > 9.8."], + ... "labels": [True, False]} + >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False) + {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198], + 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]} + ``` + """ + # Tokenize the prompt and completions + prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] + completions_ids = [ + tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"] + ] + if train_on_last_step_only and not is_eval: + labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])] + else: + labels = [int(label) for label in features["labels"]] + + # Get the ID of the separator token and add it to the completions + separator_ids = tokenizer.encode(step_separator, add_special_tokens=False) + completions_ids = [completion + separator_ids for completion in completions_ids] + + # Create the label + labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)] + + # Join the completions and labels steps + completion_ids = list(chain(*completions_ids)) + labels = list(chain(*labels)) + + if max_completion_length is not None: + completion_ids = completion_ids[:max_completion_length] + labels = labels[:max_completion_length] + + if tokenizer.bos_token_id is not None: + prompt_ids = [tokenizer.bos_token_id] + prompt_ids + + input_ids = prompt_ids + completion_ids + labels = [-100] * len(prompt_ids) + labels + + if max_length is not None: + input_ids = input_ids[:max_length] + labels = labels[:max_length] + + return {"input_ids": input_ids, "labels": labels} + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or [] + if isinstance(tags, str): + tags = [tags] + + if hasattr(self.model.config, "unsloth_version"): + tags.append("unsloth") + + citation = textwrap.dedent("""\ + @article{uesato2022solving, + title = {Solving Math Word Problems With Process- and Outcome-Based Feedback}, + author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina}, + year = 2022, + journal = {arXiv preprint arXiv:2211.14275} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="PRM", + trainer_citation=citation, + paper_title="Solving math word problems with process-and outcome-based feedback", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 699a447bb8..fd30bea929 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -37,6 +37,7 @@ from transformers import ( BitsAndBytesConfig, DataCollatorForLanguageModeling, + EvalPrediction, GenerationConfig, PreTrainedTokenizerBase, TrainerState, @@ -757,18 +758,38 @@ def get_global_statistics( return global_mean.to(device), global_var.to(device), count.item() -def compute_accuracy(eval_pred) -> dict[str, float]: +def compute_accuracy(eval_pred: EvalPrediction) -> dict[str, float]: predictions, labels = eval_pred - # Here, predictions is rewards_chosen and rewards_rejected. - # We want to see how much of the time rewards_chosen > rewards_rejected. - equal_predictions_count = np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() - if equal_predictions_count > 0: - warnings.warn( - f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions for " - "both options are equal. As a consequence the accuracy can be misleading.", - UserWarning, + if predictions.ndim == 3: + # Token classification task. Shapes are (batch_size, seq_len, num_labels) and (batch_size, seq_len) + # Used to compute the accuracy in the prm_trainer. + predictions = np.argmax(predictions, axis=2) + + # Flatten the predictions and labels to remove the ignored tokens. + predictions = np.array( + [p for prediction, label in zip(predictions, labels) for (p, lbl) in zip(prediction, label) if lbl != -100] ) - predictions = np.argmax(predictions, axis=1) + labels = np.array([lbl for label in labels for lbl in label if lbl != -100]) + + else: + # Here, predictions is rewards_chosen and rewards_rejected. Shapes are (batch_size, 2) and (batch_size,) + # We want to see how much of the time rewards_chosen > rewards_rejected. + equal_mask = predictions[:, 0] == predictions[:, 1] + equal_predictions_count = int(equal_mask.sum()) + + if equal_predictions_count > 0: + warnings.warn( + f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions " + "for both options are equal. These instances are ignored in the accuracy computation.", + UserWarning, + ) + + # Filter out equal predictions + predictions = predictions[~equal_mask] + labels = labels[~equal_mask] + + # Use the remaining predictions for accuracy calculation + predictions = np.argmax(predictions, axis=1) accuracy = np.array(predictions == labels, dtype=float).mean().item() return {"accuracy": accuracy}