Add OpenAI Summarize RLHF with trlX (#175)
* init add summarize openai tl;dr
* fix style
* add read me
* add precommit
* update readme
* reformat code precommit
* update config
* change shift left PPO (make it work for summarize openai), add rollouts with best-n-sampling options, fix typo in reward model training
* revert old ppo value state (t5) - rewards scores, update data part train sft
* Small fixes
* change config to model_path rw model
* Cleanup naming
* update newest result and new examples link
* old sentiment eval prompts
* Replace links with CarperAI HuggingFace refs
* Remove unused `gen_kwargs`
* Add `requirements.txt` and update `README`
* Add accelerate config
* change accelerate file name
* refactor trlx ppo train file
* fix reward model download link
* fix reward model download link
* add blog post to readme
* Remove `best_of_n` sampling for now
* update accelerate command to run ppo
* fix reference link in readme
* Update README.md
* Cleanup example `README.md`
* Format `README.md`
* Update dataset comment
* Revert return from examples to avoid RayTune errors
* Final `README.md` format edit
* Fix sft summarize data path

Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: jon-tow <[email protected]>
Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: Louis Castricato <[email protected]>
1 parent 0cb8438 · commit 7ed923c · Showing 19 changed files with 1,231 additions and 26 deletions.
examples/summarize_rlhf/README.md
@@ -0,0 +1,68 @@
## Learning to summarize from Human Feedback using `trlx`

This example shows how to use `trlx` to train a summarization model with human feedback, following the fine-tuning procedure described in Stiennon et al., "[Learning to Summarize from human feedback](https://arxiv.org/abs/2009.01325)".

Before running anything, we need a few extra packages that are not in the `trlx` dependency list: HuggingFace's [`evaluate`](https://huggingface.co/docs/evaluate/index) package and Google's re-implementation of ROUGE, [`rouge-score`](https://github.com/google-research/google-research/tree/master/rouge). To install them, use the `requirements.txt` in this example's root directory:
```bash
pip install -r requirements.txt
```
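
For orientation, here is a minimal sketch of how these packages are used to score summaries — `evaluate.load("rouge")` wraps the `rouge-score` package under the hood, and the strings below are made-up stand-ins rather than model outputs:

```python
import evaluate

rouge = evaluate.load("rouge")  # backed by Google's `rouge-score`
scores = rouge.compute(
    predictions=["the cat sat on the mat"],       # model summaries
    references=["a cat was sitting on the mat"],  # reference summaries
)
print(scores)  # rouge1 / rouge2 / rougeL / rougeLsum
```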

### Training Process

For an in-depth description of the example, please refer to our [blog post](http://wandb.me/summarize-rlhf-trlx). The following is a quick overview of the fine-tuning process and the scripts to run; a sketch of loading the resulting checkpoints follows the list.
1. Train SFT:
   ```bash
   cd sft/ && deepspeed train_gptj_summarize.py
   ```
   Checkpoint: [SFT](https://huggingface.co/CarperAI/openai_summarize_tldr_sft)

2. Train Reward Model:
   ```bash
   cd reward_model/ && deepspeed train_reward_model_gptj.py
   ```
   Download reward model checkpoint:
   ```bash
   mkdir reward_model/rm_checkpoint
   wget https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint/resolve/main/pytorch_model.bin -O reward_model/rm_checkpoint/pytorch_model.bin
   ```

3. PPO training:
   ```bash
   accelerate launch --config_file configs/default_accelerate_config.yaml trlx_gptj_text_summarization.py
   ```
   Checkpoint: [PPO](https://huggingface.co/CarperAI/openai_summarize_tldr_ppo)
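
Once trained (or downloaded), the SFT and PPO checkpoints are GPT-J causal-LM weights on the HuggingFace Hub. A minimal loading sketch, assuming the Hub checkpoints load through the standard `transformers` API (the prompt and generation settings here are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("CarperAI/openai_summarize_tldr_ppo")

# TL;DR-style prompt, matching the dataset format used in this example
prompt = "SUBREDDIT: r/AskReddit\nPOST: <post body here>\nTL;DR:"
inputs = tokenizer(prompt, return_tensors="pt")
summary_ids = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```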

### Results

On 1,000 samples from the CNN/DailyMail test set:
1. SFT vs PPO

   __ROUGE scores__

   | Model | Rouge-1 | Rouge-2 | Rouge-L | Average |
   | --- | --- | --- | --- | --- |
   | SFT | 0.334 | 0.125 | 0.261 | 0.240 |
   | PPO | 0.323 | 0.109 | 0.238 | 0.223 |

   __Reward scores__

   | Model | Average Reward | Reward $\Delta$ |
   | --- | --- | --- |
   | SFT | 2.729 | -0.181 |
   | PPO | 3.291 | +0.411 |

2. Examples of generated summaries can be found [here](https://wandb.ai/carperai/summarize_RLHF/runs/2uirt89a).

3. See our [blog post](http://wandb.me/summarize-rlhf-trlx) for metric logs and other results.

## References

1. Nisan Stiennon, Long Ouyang, Jeff Wu, Daniel M. Ziegler, Ryan Lowe, Chelsea Voss, Alec Radford, Dario Amodei, Paul Christiano, "[Learning to Summarize from human feedback](https://arxiv.org/abs/2009.01325)", Neural Information Processing Systems, 2020.
24 changes: 24 additions & 0 deletions
examples/summarize_rlhf/configs/default_accelerate_config.yaml
@@ -0,0 +1,24 @@
command_file: null
commands: null
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: ds_config_trlx_gptj_summarize.json
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
gpu_ids: null
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_name: null
tpu_zone: null
use_cpu: false
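
This is the file passed to `accelerate launch` in step 3 of the README; `deepspeed_config_file` points at the DeepSpeed JSON shown next, and `num_processes` should match the number of GPUs you actually launch with. For example:

```bash
# Single-GPU launch, matching num_processes: 1 above
accelerate launch --config_file configs/default_accelerate_config.yaml \
    trlx_gptj_text_summarization.py
```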
22 changes: 22 additions & 0 deletions
examples/summarize_rlhf/configs/ds_config_trlx_gptj_summarize.json
@@ -0,0 +1,22 @@
{
  "train_micro_batch_size_per_gpu": 2,
  "gradient_accumulation_steps": 4,
  "fp16": {
    "enabled": true,
    "min_loss_scale": 0.5,
    "fp16_scale_tolerance": 0.25,
    "opt_level": "O2"
  },
  "zero_optimization": {
    "stage": 2,
    "offload_param": {
      "device": "cpu"
    },
    "offload_optimizer": {
      "device": "cpu"
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "contiguous_gradients": true
  }
}
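
A quick sanity check on the batch arithmetic this file implies: DeepSpeed's effective batch size per optimizer step is `micro_batch × grad_accum × world_size`, so with the single-process accelerate config above:

```python
# Effective batch size implied by the DeepSpeed config; world_size comes from
# the accelerate config's num_processes (adjust for multi-GPU launches).
micro_batch_per_gpu = 2
grad_accum_steps = 4
world_size = 1
print(micro_batch_per_gpu * grad_accum_steps * world_size)  # -> 8
```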
@@ -0,0 +1,51 @@
train:
  seq_length: 550
  epochs: 50
  total_steps: 100000
  batch_size: 4

  checkpoint_interval: 10000
  eval_interval: 200

  pipeline: "PromptPipeline"
  orchestrator: "PPOOrchestrator"
  trainer: "AcceleratePPOTrainer"

model:
  model_path: "CarperAI/openai_summarize_tldr_sft"
  tokenizer_path: "EleutherAI/gpt-j-6B"
  num_layers_unfrozen: 8

optimizer:
  name: "adamw"
  kwargs:
    lr: 5.0e-6
    betas: [0.9, 0.999]
    eps: 1.0e-8
    weight_decay: 0.01

scheduler:
  name: "cosine_annealing"
  kwargs:
    T_max: 100000
    eta_min: 5.0e-6

method:
  name: "ppoconfig"
  num_rollouts: 128
  chunk_size: 16
  ppo_epochs: 4
  init_kl_coef: 0.1
  target: 6
  horizon: 10000
  gamma: 1
  lam: 0.95
  cliprange: 0.2
  cliprange_value: 0.2
  vf_coef: 0.2
  scale_reward: False
  ref_mean: null
  ref_std: null
  cliprange_reward: 10
  gen_kwargs:
    max_new_tokens: 50
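
For orientation, this YAML is the style of config that `trlx` loads and hands to its training entry point. A minimal sketch, assuming the blog-era trlx API (`TRLConfig.load_yaml` plus a `reward_fn` callable); the config path and the toy reward function are illustrative stand-ins, not the example's real reward model:

```python
import trlx
from trlx.data.configs import TRLConfig

# Hypothetical path: the actual filename of this YAML is not shown in this view.
config = TRLConfig.load_yaml("configs/ppo_config_summ_gptj.yml")

def reward_fn(samples):
    # Stand-in scorer; the real example scores samples with the GPT-J reward model.
    return [float(len(s.split())) for s in samples]

trlx.train(
    reward_fn=reward_fn,
    prompts=["SUBREDDIT: r/AskReddit\nPOST: <post body here>\nTL;DR:"],
    config=config,
)
```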
examples/summarize_rlhf/requirements.txt
@@ -0,0 +1,3 @@
evaluate>=0.4.0
nltk>=3.8.1
rouge-score>=0.1.2
@@ -0,0 +1,39 @@
{
  "train_batch_size": 32,
  "fp16": {
    "enabled": true,
    "min_loss_scale": 1,
    "opt_level": "O2"
  },
  "zero_optimization": {
    "stage": 2,
    "offload_param": {
      "device": "cpu"
    },
    "offload_optimizer": {
      "device": "cpu"
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "contiguous_gradients": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-5,
      "betas": [0.9, 0.999],
      "eps": 1e-08
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": "auto",
      "warmup_num_steps": 100
    }
  }
}
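
The `"auto"` value for `warmup_max_lr` is a HuggingFace Trainer convention: when a DeepSpeed config is passed through `TrainingArguments`, the integration fills `"auto"` entries from the Trainer's own arguments. A minimal sketch of that wiring, assuming the reward-model script uses the standard `Trainer` (the filename and argument values here are illustrative):

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="rm_checkpoint",
    learning_rate=1e-5,                # resolves "warmup_max_lr": "auto"
    fp16=True,
    deepspeed="ds_config_gpt_j.json",  # hypothetical name for the JSON above
)
# trainer = Trainer(model=model, args=training_args, train_dataset=..., ...)
```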
124 changes: 124 additions & 0 deletions
examples/summarize_rlhf/reward_model/gptj_reward_test.py
@@ -0,0 +1,124 @@
import random

import numpy as np
import torch
from datasets import load_dataset
from reward_model import GPTRewardModel
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer


def set_seed(seed_val=42):
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)


def create_comparison_dataset(
    path="CarperAI/openai_summarize_comparisons", split="train"
):
    # Build (chosen, rejected) text pairs, skipping ties and degenerate
    # summaries shorter than five words.
    dataset = load_dataset(path, split=split)
    if split == "test":
        dataset = dataset.select(range(5000))

    pairs = []
    for sample in tqdm(dataset):
        pair = {}
        prompt = sample["prompt"]
        chosen_summary = sample["chosen"]
        rejected_summary = sample["rejected"]
        if chosen_summary == rejected_summary:
            continue
        if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5:
            continue
        pair["chosen"] = prompt + "\n" + chosen_summary
        pair["rejected"] = prompt + "\n" + rejected_summary
        pairs.append(pair)
    return pairs


class PairwiseDataset(Dataset):
    def __init__(self, pairs, tokenizer, max_length):
        self.chosen_input_ids = []
        self.chosen_attn_masks = []
        self.rejected_input_ids = []
        self.rejected_attn_masks = []
        for pair in pairs:
            chosen, rejected = pair["chosen"], pair["rejected"]
            chosen_encodings_dict = tokenizer(
                "<|startoftext|>" + chosen + "<|endoftext|>",
                truncation=True,
                max_length=max_length,
                padding="max_length",
                return_tensors="pt",
            )
            rejected_encodings_dict = tokenizer(
                "<|startoftext|>" + rejected + "<|endoftext|>",
                truncation=True,
                max_length=max_length,
                padding="max_length",
                return_tensors="pt",
            )
            self.chosen_input_ids.append(chosen_encodings_dict["input_ids"])
            self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"])
            self.rejected_input_ids.append(rejected_encodings_dict["input_ids"])
            self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"])

    def __len__(self):
        return len(self.chosen_input_ids)

    def __getitem__(self, idx):
        return (
            self.chosen_input_ids[idx],
            self.chosen_attn_masks[idx],
            self.rejected_input_ids[idx],
            self.rejected_attn_masks[idx],
        )


class DataCollatorReward:
    def __call__(self, data):
        # Stack all chosen sequences first, then all rejected ones, so the
        # model can score both halves of each pair in a single forward pass.
        batch = {}
        batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data])
        batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data])
        batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data))
        return batch


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    tokenizer.pad_token = tokenizer.eos_token
    PAD_ID = tokenizer(tokenizer.pad_token)["input_ids"][0]

    # Reward model built on the SFT checkpoint, with trained weights loaded
    # from the downloaded rm_checkpoint.
    model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft")
    model.load_state_dict(torch.load("rm_checkpoint/pytorch_model.bin"))
    max_length = 550
    val_pairs = create_comparison_dataset(
        "CarperAI/openai_summarize_comparisons", "test"
    )
    dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length)

    dev_dataloader = DataLoader(
        dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward()
    )
    model.cuda()
    model.eval()
    model.half()
    correct = 0
    chosen_list = []
    reject_list = []
    with torch.no_grad():
        for step, batch in tqdm(enumerate(dev_dataloader), total=len(dev_dataloader)):
            for x in batch:
                batch[x] = batch[x].cuda()
            outputs = model(**batch)
            # A comparison counts as correct when the chosen summary outscores
            # the rejected one.
            correct += sum(
                outputs["chosen_end_scores"] > outputs["rejected_end_scores"]
            )
            chosen_list.append(outputs["chosen_end_scores"].cpu())
            reject_list.append(outputs["rejected_end_scores"].cpu())
    print("Total accuracy: ", correct / len(dev_dataset))
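
The script above only measures ranking accuracy; the training objective behind those `chosen_end_scores`/`rejected_end_scores` is, per Stiennon et al., a pairwise loss that pushes the chosen summary's score above the rejected one's. A minimal sketch of that loss (the reward model's exact handling of end-of-summary positions is omitted):

```python
import torch
import torch.nn.functional as F

def pairwise_loss(chosen_scores: torch.Tensor, rejected_scores: torch.Tensor) -> torch.Tensor:
    # -log(sigmoid(r_chosen - r_rejected)), averaged over the batch:
    # minimized when chosen summaries consistently outscore rejected ones.
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()

loss = pairwise_loss(torch.tensor([2.5, 1.0]), torch.tensor([1.5, 1.2]))
```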