Add OpenAI Summarize RLHF with trlX (#175)
* init add summarize openai tl;dr

* fix style

* add read me

* add precommit

* update readme

* reformat code precommit

* update config

* change shift-left PPO (make it work for summarize openai), add rollouts with best-of-n sampling option, fix typo in reward model training

* revert old ppo value state (t5) - reward scores, update data for sft training

* Small fixes

* change config to model_path rw model

* Cleanup naming

* update newest result and new examples link

* old sentiment eval prompts

* Replace links with CarperAI HuggingFace refs

* Remove unused `gen_kwargs`

* Add `requirements.txt` and update `README`

* Add accelerate config

* change accelerate file name

* refactor trlx ppo train file

* fix reward model download link

* fix reward model download link

* add blog post to readme

* Remove `best_of_n` sampling for now

* update accelerate command to run ppo

* fix reference link in readme

* Update README.md

* Cleanup example `README.md`

* Format `README.md`

* Update dataset comment

* Revert return from examples to avoid RayTune errors

* Final `README.md` format edit

* Fix sft summarize data path

Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: jon-tow <[email protected]>
Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: Louis Castricato <[email protected]>
7 people authored Jan 12, 2023
1 parent 0cb8438 commit 7ed923c
Showing 19 changed files with 1,231 additions and 26 deletions.
8 changes: 3 additions & 5 deletions examples/ppo_sentiments.py
@@ -34,7 +34,7 @@ def main(hparams={}):
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
-        batch_size=128,
+        batch_size=256,
        device=device,
    )

@@ -43,15 +43,13 @@ def reward_fn(samples: List[str]) -> List[float]:
        return sentiments

    # Take few words off of movies reviews as prompts
-    imdb = load_dataset("imdb", split="train")
+    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
-    imdb = load_dataset("imdb", split="test")
-    val_prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
-        eval_prompts=val_prompts[0:1000],
+        eval_prompts=["I don't know much about Hungarian underground"] * 64,
        config=config,
    )
6 changes: 6 additions & 0 deletions examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml
@@ -51,3 +51,9 @@ method:
  cliprange_reward: 10
  gen_kwargs:
    max_new_tokens: 100
+  gen_experience_kwargs:
+    max_new_tokens: 100
+    do_sample: True
+    temperature: 1.0
+    top_k: 50
+    top_p: 0.95
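A likely reading, judging from the key names rather than any documented guarantee: `gen_kwargs` controls generation at evaluation time, while the newly added `gen_experience_kwargs` controls sampling when rollouts are collected as PPO experience, which is why it enables stochastic decoding (`do_sample`, `temperature`, `top_k`, `top_p`).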
68 changes: 68 additions & 0 deletions examples/summarize_rlhf/README.md
@@ -0,0 +1,68 @@
## Learning to Summarize from Human Feedback using `trlx`

This example shows how to use `trlx` to train a summarization model with human feedback,
following the fine-tuning procedure described in Stiennon et al., "[Learning to Summarize from Human Feedback](https://arxiv.org/abs/2009.01325)".


Before running everything, we need a few extra packages that are not in the `trlx` dependency list: HuggingFace's [`evaluate`](https://huggingface.co/docs/evaluate/index) package and Google's re-implementation of ROUGE, [`rouge-score`](https://github.com/google-research/google-research/tree/master/rouge). To install them, run the following from this example's root directory:

```bash
pip install -r requirements.txt
```

### Training Process

For an in-depth description of the example, please refer to our [blog post](http://wandb.me/summarize-rlhf-trlx). The steps below give a quick overview of the fine-tuning process and the scripts to run.


1. Train SFT:
```bash
cd sft/ && deepspeed train_gptj_summarize.py
```
Checkpoint: [SFT](https://huggingface.co/CarperAI/openai_summarize_tldr_sft)

2. Train Reward Model:
```bash
cd reward_model/ && deepspeed train_reward_model_gptj.py
```
Download reward model checkpoint:
```bash
mkdir reward_model/rm_checkpoint
wget https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint/resolve/main/pytorch_model.bin -O reward_model/rm_checkpoint/pytorch_model.bin
```

3. PPO training:
```bash
accelerate launch --config_file configs/default_accelerate_config.yaml trlx_gptj_text_summarization.py
```
Checkpoint: [PPO](https://huggingface.co/CarperAI/openai_summarize_tldr_ppo)
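
For orientation, the PPO step above essentially wires the reward model from step 2 into `trlx.train` as a reward function, driven by `configs/ppo_config_summ_gptj.yml`. The snippet below is a minimal sketch only: config loading mirrors the other `trlx` examples, it assumes the `CarperAI/openai_summarize_tldr` dataset exposes a `prompt` column with `train`/`valid` splits, and it substitutes a trivial placeholder reward so the sketch stays self-contained. See `trlx_gptj_text_summarization.py` for the actual script.

```python
from typing import List

from datasets import load_dataset

import trlx
from trlx.data.configs import TRLConfig


def reward_fn(samples: List[str]) -> List[float]:
    # Placeholder reward: in the real example, the GPT-J reward model trained in
    # step 2 scores each (post + summary) string here. A trivial stand-in keeps
    # this sketch runnable without the 6B-parameter checkpoint.
    return [float(len(sample.split())) for sample in samples]


# Assumed dataset layout: a `prompt` column with train/valid splits.
dataset = load_dataset("CarperAI/openai_summarize_tldr")
train_prompts = dataset["train"]["prompt"]
val_prompts = dataset["valid"]["prompt"][:1000]

config = TRLConfig.load_yaml("configs/ppo_config_summ_gptj.yml")

trlx.train(
    reward_fn=reward_fn,
    prompts=train_prompts,
    eval_prompts=val_prompts,
    config=config,
)
```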


### Results

On 1,000 samples from the CNN/DailyMail test set:

1. SFT vs PPO

__ROUGE scores__

| Model | Rouge-1 | Rouge-2 | Rouge-L | Average |
| --- | --- | --- | --- | --- |
| SFT | 0.334 | 0.125 | 0.261 | 0.240 |
| PPO | 0.323 | 0.109 | 0.238 | 0.223 |

__Reward scores__

| Model | Average Reward | Reward $\Delta$ |
| --- | --- | --- |
| SFT | 2.729 | -0.181 |
| PPO | 3.291 | +0.411 |


2. Examples of generated summaries can be found [here](https://wandb.ai/carperai/summarize_RLHF/runs/2uirt89a).

3. See our [blog post](http://wandb.me/summarize-rlhf-trlx) for metric logs and additional results.

## References

1. Nisan Stiennon, Long Ouyang, Jeff Wu, Daniel M. Ziegler, Ryan Lowe, Chelsea Voss, Alec Radford, Dario Amodei, Paul Christiano, "[Learning to Summarize from Human Feedback](https://arxiv.org/abs/2009.01325)", Advances in Neural Information Processing Systems, 2020.
24 changes: 24 additions & 0 deletions examples/summarize_rlhf/configs/default_accelerate_config.yaml
@@ -0,0 +1,24 @@
command_file: null
commands: null
compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_config_file: ds_config_trlx_gptj_summarize.json
zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
gpu_ids: null
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_name: null
tpu_zone: null
use_cpu: false
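This Accelerate config selects DeepSpeed and delegates the DeepSpeed settings to `ds_config_trlx_gptj_summarize.json` below; it is consumed by the launch command from the PPO step, run from the example's root directory:

```bash
accelerate launch --config_file configs/default_accelerate_config.yaml trlx_gptj_text_summarization.py
```

For multi-GPU runs, `num_processes` (and, across nodes, `num_machines`) would be raised to match the available hardware.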
22 changes: 22 additions & 0 deletions examples/summarize_rlhf/configs/ds_config_trlx_gptj_summarize.json
@@ -0,0 +1,22 @@
{
"train_micro_batch_size_per_gpu": 2,
"gradient_accumulation_steps": 4,
"fp16": {
"enabled": true,
"min_loss_scale": 0.5,
"fp16_scale_tolerance": 0.25,
"opt_level": "O2"
},
"zero_optimization": {
"stage": 2,
"offload_param": {
"device": "cpu"
},
"offload_optimizer": {
"device": "cpu"
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
}
}
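Assuming these values are not overridden at launch time, DeepSpeed's effective batch per optimizer step is `train_micro_batch_size_per_gpu × gradient_accumulation_steps × num_processes`, i.e. 2 × 4 × 1 = 8 under the single-process Accelerate config above.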
51 changes: 51 additions & 0 deletions examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml
@@ -0,0 +1,51 @@
train:
seq_length: 550
epochs: 50
total_steps: 100000
batch_size: 4

checkpoint_interval: 10000
eval_interval: 200

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
model_path: "CarperAI/openai_summarize_tldr_sft"
tokenizer_path: "EleutherAI/gpt-j-6B"
num_layers_unfrozen: 8

optimizer:
name: "adamw"
kwargs:
lr: 5.0e-6
betas: [0.9, 0.999]
eps: 1.0e-8
weight_decay: 0.01

scheduler:
name: "cosine_annealing"
kwargs:
T_max: 100000
eta_min: 5.0e-6

method:
name: "ppoconfig"
num_rollouts: 128
chunk_size: 16
ppo_epochs: 4
init_kl_coef: 0.1
target: 6
horizon: 10000
gamma: 1
lam: 0.95
cliprange: 0.2
cliprange_value: 0.2
vf_coef: 0.2
scale_reward: False
ref_mean: null
ref_std: null
cliprange_reward: 10
gen_kwargs:
max_new_tokens: 50
3 changes: 3 additions & 0 deletions examples/summarize_rlhf/requirements.txt
@@ -0,0 +1,3 @@
evaluate>=0.4.0
nltk>=3.8.1
rouge-score>=0.1.2
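These packages back the ROUGE numbers reported in the `README`. A minimal sketch of how they are typically used together (illustrative, not the example's exact evaluation script):

```python
import evaluate

# `evaluate`'s "rouge" metric wraps Google's `rouge-score` package.
rouge = evaluate.load("rouge")

scores = rouge.compute(
    predictions=["the cat sat on the mat"],
    references=["a cat was sitting on the mat"],
)
print(scores)  # keys: rouge1, rouge2, rougeL, rougeLsum
```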
39 changes: 39 additions & 0 deletions examples/summarize_rlhf/reward_model/ds_config_gpt_j.json
@@ -0,0 +1,39 @@
{
"train_batch_size": 32,
"fp16": {
"enabled": true,
"min_loss_scale": 1,
"opt_level": "O2"
},
"zero_optimization": {
"stage": 2,
"offload_param": {
"device": "cpu"
},
"offload_optimizer": {
"device": "cpu"
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-5,
"betas": [
0.9,
0.999
],
"eps": 1e-08
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": 100
}
}
}
124 changes: 124 additions & 0 deletions examples/summarize_rlhf/reward_model/gptj_reward_test.py
@@ -0,0 +1,124 @@
import random

import numpy as np
import torch
from datasets import load_dataset
from reward_model import GPTRewardModel
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer


def set_seed(seed_val=42):
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)


def create_comparison_dataset(
path="CarperAI/openai_summarize_comparisons", split="train"
):
dataset = load_dataset(path, split=split)
if split == "test":
dataset = dataset.select(range(5000))

pairs = []
for sample in tqdm(dataset):
pair = {}
prompt = sample["prompt"]
chosen_summary = sample["chosen"]
rejected_summary = sample["rejected"]
if chosen_summary == rejected_summary:
continue
if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5:
continue
pair["chosen"] = prompt + "\n" + chosen_summary
pair["rejected"] = prompt + "\n" + rejected_summary
pairs.append(pair)
return pairs


class PairwiseDataset(Dataset):
def __init__(self, pairs, tokenizer, max_length):
self.chosen_input_ids = []
self.chosen_attn_masks = []
self.rejected_input_ids = []
self.rejected_attn_masks = []
for pair in pairs:
chosen, rejected = pair["chosen"], pair["rejected"]
chosen_encodings_dict = tokenizer(
"<|startoftext|>" + chosen + "<|endoftext|>",
truncation=True,
max_length=max_length,
padding="max_length",
return_tensors="pt",
)
rejected_encodings_dict = tokenizer(
"<|startoftext|>" + rejected + "<|endoftext|>",
truncation=True,
max_length=max_length,
padding="max_length",
return_tensors="pt",
)
self.chosen_input_ids.append(chosen_encodings_dict["input_ids"])
self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"])
self.rejected_input_ids.append(rejected_encodings_dict["input_ids"])
self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"])

def __len__(self):
return len(self.chosen_input_ids)

def __getitem__(self, idx):
return (
self.chosen_input_ids[idx],
self.chosen_attn_masks[idx],
self.rejected_input_ids[idx],
self.rejected_attn_masks[idx],
)


class DataCollatorReward:
    def __call__(self, data):
        # Concatenate chosen examples followed by rejected examples along the
        # batch dimension so the reward model scores both halves of each pair in
        # one forward pass; `labels` records which half each row came from.
        batch = {}
        batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data])
        batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data])
        batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data))
        return batch


if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token
PAD_ID = tokenizer(tokenizer.pad_token)["input_ids"][0]

model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft")
model.load_state_dict(torch.load("rm_checkpoint/pytorch_model.bin"))
max_length = 550
val_pairs = create_comparison_dataset(
"CarperAI/openai_summarize_comparisons", "test"
)
dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length)

dev_dataloader = DataLoader(
dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward()
)
model.cuda()
model.eval()
model.half()
correct = 0
chosen_list = []
reject_list = []
with torch.no_grad():
for step, batch in tqdm(enumerate(dev_dataloader), total=len(dev_dataloader)):
for x in batch:
batch[x] = batch[x].cuda()
outputs = model(**batch)
correct += sum(
outputs["chosen_end_scores"] > outputs["rejected_end_scores"]
)
chosen_list.append(outputs["chosen_end_scores"].cpu())
reject_list.append(outputs["rejected_end_scores"].cpu())
print("Total accuracy: ", correct / len(dev_dataset))