🐾 Process-supervised RM Trainer (#2127)
* initial skeleton

* tokenize fn

* adding bos and eos to tokenization fn

* prmtrainer

* fixing small typo in tokenize

* typo in input_ids and labels construction

* numpy dimension

* introduce the stepwise reward trainer

* update markdown files

* let user decide post step separator in config

* doc post_step_separator

* do not add post step_tokens to last step of the reasoning process

* renaming prm to stepwisereward

* formatting

* fix tokenize kwargs

* adapt test to the new post_token args

* adding example script

* fix small typo

* add create_model_card and renaming

* fixing booleans

* Adding the new stepwise_preference instead of placeholders for datasets

* formatting

* Update docs/source/_toctree.yml

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update examples/scripts/stepwise_reward_modeling.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* update push to hub

Co-authored-by: Quentin Gallouédec <[email protected]>

* step_separator can't be None

Co-authored-by: Quentin Gallouédec <[email protected]>

* fix suggested typos

* add citation

* reformat doc

* reordering init

* push to hub prm800k

* changing dataset in example

* change dataset format to align with the sky is blue example

* fix tokenization column names

* fix num labels in openai example

* add support for conversational dataset

* remove trailing whitespace

* replace tokenizer with processing class

* Update docs/source/dataset_formats.mdx

Co-authored-by: Quentin Gallouédec <[email protected]>

* remove openai_prm800k

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update docs/source/stepwise_reward_trainer.mdx

Co-authored-by: lewtun <[email protected]>

* Update docs/source/stepwise_reward_trainer.mdx

Co-authored-by: lewtun <[email protected]>

* renaming

Co-authored-by: lewtun <[email protected]>

* renaming

Co-authored-by: lewtun <[email protected]>

* minor renamings in docs

* using prm800k instead of openai_prm800k

* update num labels to 2 following the new format

* changing doc examples to math examples

* change reference to dataset_formats.mdx

* changing dataset config in test

* remove conversational dataset support

* remove conv dataset support

* fix bos token

* fix scriptarguments in example

* completion to completions

* remove valuerror for step_separator inside steps

* run precommit

* remove conv dataset support

Co-authored-by: Quentin Gallouédec <[email protected]>

* renaming zen dataset

* remove unused printing

* unknown label column

* introduce the train on last step arg

* _tokenize support train_on_last_step

* incorporate train_on_last_step to tests

* formatting

* remove comments in trainer

* Refactor `tokenize_row`

* Update max_completion_length parameter in StepwiseRewardConfig

* Collator

* Update comment

* Update type hint

* fix table

* Remove collator

* don't need pad token id

* add error back

* max length args

* use tokenizer arg

* Update doc

* label -> labels

* fixing tokenization issues in tokenize row

* correct labels for token classification

* adding max_length to tokenize_row

* reformat tests

* adding tests for tokenize row

* fixing typos in comments

* update doc

Co-authored-by: Kashif Rasul <[email protected]>

* Add math_shepherd.py script for dataset processing

* split the dataset

* formatting

* same evaluation method for the two training methods

* adding filtering to example script

* formatting

* Add features to avoid casting labels to bool in dataset tokenization

* Update docs/source/stepwise_reward_trainer.mdx [ci skip]

* Add learning_rate parameter to StepwiseRewardConfig class

* update doc

* Remove unused setup_chat_format function

* Fix warning message in stepwise_reward_modeling.py

* Update logging steps in stepwise_reward_trainer.mdx

* little doc change [ci skip]

* Fix copyrights

* fix space after copyrights

* Update dataset loading in stepwise_reward_modeling.py

* refine compute_accuracy and proper test

* fix tests

* style

* renamings

* renaming in init

* doc renaming

* fix sorting and tag

* experimental [ci skip]

* trigger CI

* other doc fix

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: lewtun <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
5 people authored Dec 13, 2024
1 parent e3e171a commit 179ba53
Showing 14 changed files with 1,207 additions and 19 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -42,6 +42,8 @@
    title: ORPO
  - local: ppo_trainer
    title: PPO
  - local: prm_trainer
    title: PRM
  - local: reward_trainer
    title: Reward
  - local: rloo_trainer
1 change: 1 addition & 0 deletions docs/source/dataset_formats.mdx
@@ -266,6 +266,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
123 changes: 123 additions & 0 deletions docs/source/prm_trainer.mdx
@@ -0,0 +1,123 @@
# PRM Trainer

<Tip warning={true}>

PRM Trainer is an experimental API which is subject to change at any time.

</Tip>

## Overview

Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins.

The abstract from the paper is the following:

> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use process-based supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.

This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).


## Quick start

This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here:

<iframe
  src="https://huggingface.co/datasets/trl-lib/math_shepherd/embed/viewer/default/train?row=0"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>

Below is the script to train the model:

```python
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")

training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_prm.py
```

Distributed across 8 GPUs, the training takes approximately 1 hour.

To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script.


```python
from datasets import load_dataset
from transformers import pipeline

pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd")
dataset = load_dataset("trl-lib/math_shepherd")
example = {
    "prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
    "completions": [
        "Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
        "Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
        "Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
    ],
    "labels": [True, False, False],
}

separator = "\n"  # It's important to use the same separator as the one used during training

for idx in range(1, len(example["completions"]) + 1):
    steps = example["completions"][0:idx]
    text = separator.join((example["prompt"], *steps)) + separator  # Add a separator between the prompt and each step
    pred_entity = pipe(text)[-1]["entity"]
    pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity]
    label = example["labels"][idx - 1]
    print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}")
```

```text
Step 1 Predicted: True Label: True
Step 2 Predicted: False Label: False
Step 3 Predicted: False Label: False
```

It's a win!

## Expected dataset type

PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision) dataset. The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.

The [`PRMTrainer`] only supports the [standard](dataset_formats#standard) dataset format.
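
For illustration, a single row of such a dataset might look like the following (a minimal, hypothetical example mirroring the Math Shepherd rows shown above):

```python
example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
        "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.",
    ],
    "labels": [True, False],  # the first step is correct, the second is not
}
```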

## Example script

We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py).

To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command:

```bash
accelerate launch examples/scripts/prm.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/math_shepherd \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-Reward-Math-Sheperd
```

## PRMTrainer

[[autodoc]] PRMTrainer

## PRMConfig

[[autodoc]] PRMConfig
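
As a rough sketch of how the options introduced in this PR fit together, the snippet below configures the trainer. The parameter names (`step_separator`, `max_length`, `train_on_last_step_only`) are inferred from the commit history above, so treat the autodoc as the authoritative reference:

```python
# Sketch only: parameter names inferred from the commit history above;
# consult the PRMConfig autodoc for the exact signature and defaults.
from trl import PRMConfig

training_args = PRMConfig(
    output_dir="Qwen2-0.5B-Reward-Math-Sheperd",
    learning_rate=1e-5,             # PRM-specific learning rate added in this PR
    max_length=1024,                # truncate the tokenized prompt + completions
    step_separator="\n",            # separator appended after each reasoning step
    train_on_last_step_only=False,  # if True, compute the loss only on the final step's label
    logging_steps=10,
)
```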
131 changes: 131 additions & 0 deletions examples/datasets/math_shepherd.py
@@ -0,0 +1,131 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass
from itertools import chain
from typing import Optional

from datasets import load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
            Number of workers to use for dataset processing.
    """

    push_to_hub: bool = False
    repo_id: str = "trl-lib/math_shepherd"
    dataset_num_proc: Optional[int] = None


def process_example(example):
    # Replace "ки" with "ⶻ" so that the size of the "input" matches the size of the "label"
    inputs = example["input"].replace("ки", "ⶻ")

    # Find the indices of the "ⶻ" characters (that should match with the indexes of the "+" or "-" in the label)
    indexes = [m.start() for m in re.finditer("ⶻ", inputs)]

    # Sanity check that the label characters at those indexes are either "+" or "-"
    assert all(example["label"][idx] in ["+", "-"] for idx in indexes)

    # Get the labels
    labels = [example["label"][idx] == "+" for idx in indexes]

    # Split the inputs into steps (caution, the first step is missing here, it is the prompt)
    steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))]

    # Remove the last step (single ⶻ)
    steps = steps[:-1]

    # Get the prompt (first part) and completions (rest)
    prompt = steps[0]
    completions = steps[1:]

    # Remove the heading "ⶻ" and the final whitespace from the completions
    assert all(completion.startswith("ⶻ") for completion in completions)
    completions = [completion[1:].strip() for completion in completions]

    # At this point, we need to retrieve the first step from the prompt.
    # First, we handle particular cases (annotation error) where we have a first label before the end of the prompt.
    if prompt.startswith(
        (
            "Mr. Rocky",
            "Parker",
            "What is the smallest positive",
            " The Myth",
            "Let $\\mathbf{a}$",
            "Find the arithmetic",
            "Determine an ordered pair",
            "Determine the ordered pair",
            "At the Quill and Scroll stationery",
            "Round to the nearest",
            r"Calculate $\sqrt{10p}",
            r"Simplify $\sqrt{28x}",
        )
    ):
        # Some spotted dataset errors have an annotation inside the prompt: we remove it
        labels = labels[1:]

    # Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or
    # (less common) "?".
    elif "Step 1:" in prompt:
        prompt, first_step = prompt.split("Step 1:")
        first_step = "Step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "step 1:" in prompt:
        prompt, first_step = prompt.split("step 1:")
        first_step = "step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "?" in prompt:
        prompt, first_step = prompt.split("?")
        prompt = prompt + "?"
        completions = [first_step.strip()] + completions
    else:
        raise ValueError(f"Prompt can't be processed: {prompt}")

    # Strip the prompt
    prompt = prompt.strip()

    # Sanity check that the length of the completions is the same as the length of the labels
    assert len(completions) == len(labels)

    return {"prompt": prompt, "completions": completions, "labels": labels}
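
# For illustration, a hypothetical row in the peiyi9979/Math-Shepherd format:
#   input: "Janet has 3 apples ... Step 1: ... ки Step 2: ... ки"
#   label: the same string with each "ки" replaced by "+" or "-"
# process_example would turn it into:
#   {"prompt": "Janet has 3 apples ...",
#    "completions": ["Step 1: ...", "Step 2: ..."],
#    "labels": [True, False]}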


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    dataset = load_dataset("peiyi9979/Math-Shepherd", split="train")

    dataset = dataset.map(
        process_example,
        remove_columns=["input", "label", "task"],
        num_proc=script_args.dataset_num_proc,
    )
    dataset = dataset.train_test_split(test_size=0.05, seed=42)

    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)
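
Since `HfArgumentParser` exposes the dataclass fields as command-line flags, the script can be run as follows (the repository ID is a placeholder):

```bash
python examples/datasets/math_shepherd.py --push_to_hub --repo_id <your-username>/math_shepherd
```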