diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py
index 8b9abea25..32ee98ff8 100644
--- a/examples/ppo_sentiments.py
+++ b/examples/ppo_sentiments.py
@@ -34,7 +34,7 @@ def main(hparams={}):
         "lvwerra/distilbert-imdb",
         top_k=2,
         truncation=True,
-        batch_size=128,
+        batch_size=256,
         device=device,
     )

@@ -43,15 +43,13 @@ def reward_fn(samples: List[str]) -> List[float]:
         return sentiments

     # Take few words off of movies reviews as prompts
-    imdb = load_dataset("imdb", split="train")
+    imdb = load_dataset("imdb", split="train+test")
     prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
-    imdb = load_dataset("imdb", split="test")
-    val_prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

     trlx.train(
         reward_fn=reward_fn,
         prompts=prompts,
-        eval_prompts=val_prompts[0:1000],
+        eval_prompts=["I don't know much about Hungarian underground"] * 64,
         config=config,
     )
diff --git a/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml b/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml
index 08191422c..87438ab12 100755
--- a/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml
+++ b/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml
@@ -51,3 +51,9 @@ method:
   cliprange_reward: 10
   gen_kwargs:
     max_new_tokens: 100
+  gen_experience_kwargs:
+    max_new_tokens: 100
+    do_sample: True
+    temperature: 1.0
+    top_k: 50
+    top_p: 0.95
diff --git a/examples/summarize_rlhf/README.md b/examples/summarize_rlhf/README.md
new file mode 100644
index 000000000..7f0dfb8f1
--- /dev/null
+++ b/examples/summarize_rlhf/README.md
@@ -0,0 +1,68 @@
+## Learning to summarize from Human Feedback using `trlx`
+
+This example shows how to use `trlx` to train a summarization model with human feedback,
+following the fine-tuning procedure described in Stiennon et al., "[Learning to Summarize from human feedback](https://arxiv.org/abs/2009.01325)".
+
+
+Before running everything, we need some extra packages not included in the `trlx` dependency list. Specifically, we need HuggingFace's [`evaluate`](https://huggingface.co/docs/evaluate/index) package and Google's re-implementation of ROUGE, [`rouge-score`](https://github.com/google-research/google-research/tree/master/rouge). To install them, run the following from this example's root directory:
+
+```bash
+pip install -r requirements.txt
+```
+
+### Training Process
+
+For an in-depth description of the example, please refer to our [blog post](http://wandb.me/summarize-rlhf-trlx). The following gives a quick overview of the fine-tuning process and the scripts to run.
+
+
+1. Train SFT:
+   ```bash
+   cd sft/ && deepspeed train_gptj_summarize.py
+   ```
+   Checkpoint: [SFT](https://huggingface.co/CarperAI/openai_summarize_tldr_sft)
+
+2. Train Reward Model:
+   ```bash
+   cd reward_model/ && deepspeed train_reward_model_gptj.py
+   ```
+   Download reward model checkpoint:
+   ```bash
+   mkdir reward_model/rm_checkpoint
+   wget https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint/resolve/main/pytorch_model.bin -O reward_model/rm_checkpoint/pytorch_model.bin
+   ```
+
+3. PPO training:
+   ```bash
+   accelerate launch --config_file configs/default_accelerate_config.yaml trlx_gptj_text_summarization.py
+   ```
+   Checkpoint: [PPO](https://huggingface.co/CarperAI/openai_summarize_tldr_ppo)
+
+
+### Results
+
+On 1,000 samples from the CNN/DailyMail test dataset:
+
+1. 
SFT vs PPO + + __ROUGE scores__ + + | Model | Rouge-1 | Rouge-2 | Rouge-L | Average | + | --- | --- | --- | --- | --- | + | SFT | 0.334 | 0.125 | 0.261 | 0.240 | + | PPO | 0.323 | 0.109 | 0.238 | 0.223 | + + __Reward scores__ + + | Model | Average Reward | Reward $\Delta$ | + | --- | --- | --- | + | SFT | 2.729 | -0.181 | + | PPO | 3.291 | +0.411 | + + +2. Examples of generated summaries can be found [here](https://wandb.ai/carperai/summarize_RLHF/runs/2uirt89a). + +3. Check our blog post for metric logs and other results [here](http://wandb.me/summarize-rlhf-trlx). + +## References + +1. Nisan Stiennon, Long Ouyang, Jeff Wu, Daniel M. Ziegler, Ryan Lowe, Chelsea Voss, Alec Radford, Dario Amodei, Paul Christiano, "[Learning to Summarize from human feedback](https://arxiv.org/abs/2009.01325)", Neural Information Processing Systems, 2020. diff --git a/examples/summarize_rlhf/configs/default_accelerate_config.yaml b/examples/summarize_rlhf/configs/default_accelerate_config.yaml new file mode 100644 index 000000000..d9806ac48 --- /dev/null +++ b/examples/summarize_rlhf/configs/default_accelerate_config.yaml @@ -0,0 +1,24 @@ +command_file: null +commands: null +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_config_file: ds_config_trlx_gptj_summarize.json + zero3_init_flag: false +distributed_type: DEEPSPEED +downcast_bf16: 'no' +dynamo_backend: 'NO' +fsdp_config: {} +gpu_ids: null +machine_rank: 0 +main_process_ip: null +main_process_port: null +main_training_function: main +megatron_lm_config: {} +mixed_precision: 'no' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_name: null +tpu_zone: null +use_cpu: false diff --git a/examples/summarize_rlhf/configs/ds_config_trlx_gptj_summarize.json b/examples/summarize_rlhf/configs/ds_config_trlx_gptj_summarize.json new file mode 100644 index 000000000..f1ef786a9 --- /dev/null +++ b/examples/summarize_rlhf/configs/ds_config_trlx_gptj_summarize.json @@ -0,0 +1,22 @@ +{ + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 4, + "fp16": { + "enabled": true, + "min_loss_scale": 0.5, + "fp16_scale_tolerance": 0.25, + "opt_level": "O2" + }, + "zero_optimization": { + "stage": 2, + "offload_param": { + "device": "cpu" + }, + "offload_optimizer": { + "device": "cpu" + }, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "contiguous_gradients": true + } +} diff --git a/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml b/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml new file mode 100755 index 000000000..1f101cd81 --- /dev/null +++ b/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml @@ -0,0 +1,51 @@ +train: + seq_length: 550 + epochs: 50 + total_steps: 100000 + batch_size: 4 + + checkpoint_interval: 10000 + eval_interval: 200 + + pipeline: "PromptPipeline" + orchestrator: "PPOOrchestrator" + trainer: "AcceleratePPOTrainer" + +model: + model_path: "CarperAI/openai_summarize_tldr_sft" + tokenizer_path: "EleutherAI/gpt-j-6B" + num_layers_unfrozen: 8 + +optimizer: + name: "adamw" + kwargs: + lr: 5.0e-6 + betas: [0.9, 0.999] + eps: 1.0e-8 + weight_decay: 0.01 + +scheduler: + name: "cosine_annealing" + kwargs: + T_max: 100000 + eta_min: 5.0e-6 + +method: + name: "ppoconfig" + num_rollouts: 128 + chunk_size: 16 + ppo_epochs: 4 + init_kl_coef: 0.1 + target: 6 + horizon: 10000 + gamma: 1 + lam: 0.95 + cliprange: 0.2 + cliprange_value: 0.2 + vf_coef: 0.2 + scale_reward: False + ref_mean: null + ref_std: null + cliprange_reward: 10 + gen_kwargs: + max_new_tokens: 50 
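For orientation, here is a minimal sketch of how a config such as `ppo_config_summ_gptj.yml` above is consumed (assuming the working directory is `examples/summarize_rlhf/`; the placeholder `reward_fn` stands in for the GPT-J reward model that the full `trlx_gptj_text_summarization.py` script below uses):

```python
from typing import List

import trlx
from trlx.data.configs import TRLConfig

# Load the PPO hyperparameters defined in the YAML above
config = TRLConfig.load_yaml("configs/ppo_config_summ_gptj.yml")


def reward_fn(samples: List[str]) -> List[float]:
    # Placeholder reward: the real script scores each sample with the trained
    # GPT-J reward model and subtracts the score of the reference summary
    return [float(len(sample.split())) for sample in samples]


trainer = trlx.train(
    config.model.model_path,  # "CarperAI/openai_summarize_tldr_sft"
    reward_fn=reward_fn,
    prompts=["SUBREDDIT: r/tldr\nTITLE: ...\nPOST: ...\nTL;DR:"],
    eval_prompts=["SUBREDDIT: r/tldr\nTITLE: ...\nPOST: ...\nTL;DR:"],
    config=config,
)
```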
diff --git a/examples/summarize_rlhf/requirements.txt b/examples/summarize_rlhf/requirements.txt new file mode 100644 index 000000000..019c668dc --- /dev/null +++ b/examples/summarize_rlhf/requirements.txt @@ -0,0 +1,3 @@ +evaluate>=0.4.0 +nltk>=3.8.1 +rouge-score>=0.1.2 diff --git a/examples/summarize_rlhf/reward_model/ds_config_gpt_j.json b/examples/summarize_rlhf/reward_model/ds_config_gpt_j.json new file mode 100644 index 000000000..db6b601b2 --- /dev/null +++ b/examples/summarize_rlhf/reward_model/ds_config_gpt_j.json @@ -0,0 +1,39 @@ +{ + "train_batch_size": 32, + "fp16": { + "enabled": true, + "min_loss_scale": 1, + "opt_level": "O2" + }, + "zero_optimization": { + "stage": 2, + "offload_param": { + "device": "cpu" + }, + "offload_optimizer": { + "device": "cpu" + }, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "contiguous_gradients": true + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-5, + "betas": [ + 0.9, + 0.999 + ], + "eps": 1e-08 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": "auto", + "warmup_num_steps": 100 + } + } +} diff --git a/examples/summarize_rlhf/reward_model/gptj_reward_test.py b/examples/summarize_rlhf/reward_model/gptj_reward_test.py new file mode 100644 index 000000000..63a80da06 --- /dev/null +++ b/examples/summarize_rlhf/reward_model/gptj_reward_test.py @@ -0,0 +1,124 @@ +import random + +import numpy as np +import torch +from datasets import load_dataset +from reward_model import GPTRewardModel +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import AutoTokenizer + + +def set_seed(seed_val=42): + random.seed(seed_val) + np.random.seed(seed_val) + torch.manual_seed(seed_val) + torch.cuda.manual_seed_all(seed_val) + + +def create_comparison_dataset( + path="CarperAI/openai_summarize_comparisons", split="train" +): + dataset = load_dataset(path, split=split) + if split == "test": + dataset = dataset.select(range(5000)) + + pairs = [] + for sample in tqdm(dataset): + pair = {} + prompt = sample["prompt"] + chosen_summary = sample["chosen"] + rejected_summary = sample["rejected"] + if chosen_summary == rejected_summary: + continue + if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5: + continue + pair["chosen"] = prompt + "\n" + chosen_summary + pair["rejected"] = prompt + "\n" + rejected_summary + pairs.append(pair) + return pairs + + +class PairwiseDataset(Dataset): + def __init__(self, pairs, tokenizer, max_length): + self.chosen_input_ids = [] + self.chosen_attn_masks = [] + self.rejected_input_ids = [] + self.rejected_attn_masks = [] + for pair in pairs: + chosen, rejected = pair["chosen"], pair["rejected"] + chosen_encodings_dict = tokenizer( + "<|startoftext|>" + chosen + "<|endoftext|>", + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + rejected_encodings_dict = tokenizer( + "<|startoftext|>" + rejected + "<|endoftext|>", + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) + self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) + self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) + self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) + + def __len__(self): + return len(self.chosen_input_ids) + + def __getitem__(self, idx): + return ( + self.chosen_input_ids[idx], + self.chosen_attn_masks[idx], + 
self.rejected_input_ids[idx], + self.rejected_attn_masks[idx], + ) + + +class DataCollatorReward: + def __call__(self, data): + batch = {} + batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data]) + batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data]) + batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data)) + return batch + + +if __name__ == "__main__": + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + tokenizer.pad_token = tokenizer.eos_token + PAD_ID = tokenizer(tokenizer.pad_token)["input_ids"][0] + + model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft") + model.load_state_dict(torch.load("rm_checkpoint/pytorch_model.bin")) + max_length = 550 + val_pairs = create_comparison_dataset( + "CarperAI/openai_summarize_comparisons", "test" + ) + dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length) + + from torch.utils.data import DataLoader + + dev_dataloader = DataLoader( + dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward() + ) + model.cuda() + model.eval() + model.half() + correct = 0 + chosen_list = [] + reject_list = [] + with torch.no_grad(): + for step, batch in tqdm(enumerate(dev_dataloader), total=len(dev_dataloader)): + for x in batch: + batch[x] = batch[x].cuda() + outputs = model(**batch) + correct += sum( + outputs["chosen_end_scores"] > outputs["rejected_end_scores"] + ) + chosen_list.append(outputs["chosen_end_scores"].cpu()) + reject_list.append(outputs["rejected_end_scores"].cpu()) + print("Total accuracy: ", correct / len(dev_dataset)) diff --git a/examples/summarize_rlhf/reward_model/reward_model.py b/examples/summarize_rlhf/reward_model/reward_model.py new file mode 100644 index 000000000..e422fe9b6 --- /dev/null +++ b/examples/summarize_rlhf/reward_model/reward_model.py @@ -0,0 +1,111 @@ +import torch +from torch import nn +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class GPTRewardModel(nn.Module): + def __init__(self, model_path): + super().__init__() + model = AutoModelForCausalLM.from_pretrained(model_path) + self.config = model.config + # `gpt-neo(x)` models use `hidden_size` attribute names instead of `n_embd`` + self.config.n_embd = ( + self.config.hidden_size + if hasattr(self.config, "hidden_size") + else self.config.n_embd + ) + self.transformer = model.transformer + self.v_head = nn.Linear(self.config.n_embd, 1, bias=False) + self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + self.tokenizer.pad_token = self.tokenizer.eos_token + self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0] + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + mc_token_ids=None, + labels=None, + return_dict=False, + output_attentions=False, + output_hidden_states=False, + ): + loss = None + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + ) + + hidden_states = transformer_outputs[0] + + rewards = self.v_head(hidden_states).squeeze(-1) + chosen_end_scores = [] + rejected_end_scores = [] + + # Split the inputs and rewards into two parts, chosen and rejected + assert len(input_ids.shape) == 2 + bs = input_ids.shape[0] // 2 + chosen = input_ids[:bs] + rejected = input_ids[bs:] + chosen_rewards = 
rewards[:bs] + rejected_rewards = rewards[bs:] + + # Compute pairwise loss. Only backprop on the last value before padding + loss = 0 + inference = False + for i in range(bs): + if torch.all(torch.eq(chosen[i], rejected[i])).item(): + c_inds = (chosen[i] == self.PAD_ID).nonzero() + c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1] + chosen_end_scores.append(chosen_rewards[i, c_ind - 1]) + inference = True + continue + + # Check if there is any padding otherwise take length of sequence + c_inds = (chosen[i] == self.PAD_ID).nonzero() + c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1] + r_inds = (rejected[i] == self.PAD_ID).nonzero() + r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected.shape[1] + end_ind = max(c_ind, r_ind) + + # Retrieve first index where trajectories diverge + divergence_ind = (chosen[i] != rejected[i]).nonzero()[0] + assert divergence_ind > 0 + + # Index into the correct rewards + c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind] + r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind] + + # Append the last rewards to the list of end scores + chosen_end_scores.append(c_truncated_reward[-1]) + rejected_end_scores.append(r_truncated_reward[-1]) + + # Compute loss + loss += -torch.log( + torch.sigmoid(c_truncated_reward - r_truncated_reward) + ).mean() + loss = loss / bs + + if not inference: + chosen_end_scores = torch.stack(chosen_end_scores) + rejected_end_scores = torch.stack(rejected_end_scores) + + if inference: + chosen_end_scores = torch.stack(chosen_end_scores) + return {"chosen_end_scores": chosen_end_scores} + + return { + "loss": loss, + "chosen_end_scores": chosen_end_scores, + "rejected_end_scores": rejected_end_scores, + } diff --git a/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py b/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py new file mode 100644 index 000000000..2897ca48b --- /dev/null +++ b/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py @@ -0,0 +1,148 @@ +import os + +import torch +from datasets import load_dataset +from reward_model import GPTRewardModel +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import AutoTokenizer, Trainer, TrainingArguments + + +def create_comparison_dataset( + path="CarperAI/openai_summarize_comparisons", split="train" +): + dataset = load_dataset(path, split=split) + pairs = [] + for sample in tqdm(dataset): + pair = {} + prompt = sample["prompt"] + chosen_summary = sample["chosen"] + rejected_summary = sample["rejected"] + if chosen_summary == rejected_summary: + continue + if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5: + continue + pair["chosen"] = prompt + "\n" + chosen_summary + pair["rejected"] = prompt + "\n" + rejected_summary + pairs.append(pair) + return pairs + + +class PairwiseDataset(Dataset): + def __init__(self, pairs, tokenizer, max_length): + self.chosen_input_ids = [] + self.chosen_attn_masks = [] + self.rejected_input_ids = [] + self.rejected_attn_masks = [] + for pair in tqdm(pairs): + chosen, rejected = pair["chosen"], pair["rejected"] + chosen_encodings_dict = tokenizer( + "<|startoftext|>" + chosen + "<|endoftext|>", + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + rejected_encodings_dict = tokenizer( + "<|startoftext|>" + rejected + "<|endoftext|>", + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + 
self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) + self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) + self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) + self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) + + def __len__(self): + return len(self.chosen_input_ids) + + def __getitem__(self, idx): + return ( + self.chosen_input_ids[idx], + self.chosen_attn_masks[idx], + self.rejected_input_ids[idx], + self.rejected_attn_masks[idx], + ) + + +class DataCollatorReward: + def __call__(self, data): + batch = {} + batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data]) + batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data]) + batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data)) + return batch + + +def compute_metrics(eval_preds): + chosen_end_scores = eval_preds.predictions[0] # chosen scores + rejected_end_scores = eval_preds.predictions[1] # rejected scores + + result = {} + acc = sum(chosen_end_scores > rejected_end_scores) / len(rejected_end_scores) + result["accuracy"] = acc + + return result + + +if __name__ == "__main__": + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + tokenizer.pad_token = tokenizer.eos_token + + if not os.path.exists("rm_checkpoint"): + os.mkdir("rm_checkpoint") + + training_args = TrainingArguments( + output_dir="rm_checkpoint/", + num_train_epochs=5, + logging_steps=10, + gradient_accumulation_steps=4, + save_strategy="steps", + evaluation_strategy="steps", + per_device_train_batch_size=1, + per_device_eval_batch_size=1, + eval_accumulation_steps=1, + eval_steps=500, + save_steps=500, + warmup_steps=100, + logging_dir="./logs", + fp16=True, + bf16=False, + learning_rate=1e-5, + deepspeed="ds_config_gpt_j.json", + save_total_limit=1, + ) + + # Initialize the reward model from the (supervised) fine-tuned GPT-J + model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft") + + # Freeze the first 70% of the hidden layers of the reward model backbone + layers = model.transformer.h + num_layers = len(layers) + num_unfrozen = int(0.3 * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + # Create the comparisons datasets + data_path = "CarperAI/openai_summarize_comparisons" + train_pairs = create_comparison_dataset(data_path, "train") + val_pairs = create_comparison_dataset(data_path, "test") + + # Make pairwise datasets for training + max_length = 550 + train_dataset = PairwiseDataset(train_pairs, tokenizer, max_length=max_length) + val_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length) + + # Create the collator to gather batches of pairwise comparisons + data_collator = DataCollatorReward() + + Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + compute_metrics=compute_metrics, + eval_dataset=val_dataset, + data_collator=data_collator, + ).train() diff --git a/examples/summarize_rlhf/sft/ds_config_gptj.json b/examples/summarize_rlhf/sft/ds_config_gptj.json new file mode 100644 index 000000000..c06a05dd3 --- /dev/null +++ b/examples/summarize_rlhf/sft/ds_config_gptj.json @@ -0,0 +1,39 @@ +{ + "train_batch_size": 128, + "fp16": { + "enabled": true, + "min_loss_scale": 1, + "opt_level": "O2" + }, + "zero_optimization": { + "stage": 2, + "offload_param": { + "device": "cpu" + }, + "offload_optimizer": { + "device": "cpu" + }, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "contiguous_gradients": true + }, + "optimizer": { + 
"type": "AdamW", + "params": { + "lr": 1e-05, + "betas": [ + 0.9, + 0.95 + ], + "eps": 1e-08 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 1e-05, + "warmup_num_steps": "auto" + } + } +} diff --git a/examples/summarize_rlhf/sft/summarize_dataset.py b/examples/summarize_rlhf/sft/summarize_dataset.py new file mode 100755 index 000000000..c81c9a7c9 --- /dev/null +++ b/examples/summarize_rlhf/sft/summarize_dataset.py @@ -0,0 +1,138 @@ +import json + +import pandas as pd +import torch +from datasets import load_dataset +from torch.utils.data import Dataset + + +def get_dataset_from_jsonl(jsonl_file, return_summary=True): + # if return_summary is True, return a list of posts with summary concatenated + # if return_summary is False, return a list of posts and a list of summaries + with open(jsonl_file, "r") as f: + dataset = [json.loads(line) for line in f] + post_list = [] + summary_list = [] + for d in dataset: + if return_summary: + post = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: {d['summary']}" + else: + post = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: " + summary_list.append(d["summary"]) + post_list.append(post) + if not return_summary: + return post_list, summary_list + return post_list + + +class TLDRDataset(Dataset): + def __init__(self, train_path, tokenizer, split, max_length=550): + self.post_list = [] + dataset = load_dataset(train_path, split=split) + for sample in dataset: + self.post_list.append(sample["prompt"] + sample["label"]) + if "valid" in train_path: + self.post_list = self.post_list[0:2000] + self.tokenizer = tokenizer + self.max_length = max_length + self.input_ids = [] + self.attn_masks = [] + + def __len__(self): + return len(self.post_list) + + def __getitem__(self, idx): + txt = self.post_list[idx] + encodings_dict = self.tokenizer( + txt, truncation=True, max_length=self.max_length, padding="max_length" + ) + input_ids = torch.tensor(encodings_dict["input_ids"]) + attn_masks = torch.tensor(encodings_dict["attention_mask"]) + + return { + "input_ids": input_ids, + "attention_mask": attn_masks, + "labels": input_ids, + } + + +class ComparisonDataset(Dataset): + def __init__(self, comparison_path, tokenizer, max_length=550): + with open(comparison_path, "r") as f: + dataset = [json.loads(line) for line in f] + + self.tokenizer = tokenizer + self.post_list = [] + self.summaries_0 = [] + self.summaries_1 = [] + self.labels = [] + self.max_length = max_length + + def make_text(post, summarize): + return f"SUBREDDIT: r/{post['subreddit']}\nTITLE: {post['title']}\nPOST: {post['post']}\nTL;DR: {summarize}" + + for sample in dataset: # chosen summary is always the first one + self.post_list.append(sample["info"]["post"]) + # NOTE: The chosen summary is always the first one, i.e. 
`sample["summaries"][0]` + if sample["choice"] == 0: + self.summaries_0.append( + make_text(sample["info"], sample["summaries"][0]["text"]) + ) + self.summaries_1.append( + make_text(sample["info"], sample["summaries"][1]["text"]) + ) + else: + self.summaries_0.append( + make_text(sample["info"], sample["summaries"][1]["text"]) + ) + self.summaries_1.append( + make_text(sample["info"], sample["summaries"][0]["text"]) + ) + self.labels.append(0) + + def __len__(self): + return len(self.post_list) + + def __getitem__(self, idx): + summ0 = self.summaries_0[idx] + summ1 = self.summaries_1[idx] + encodings_dict = self.tokenizer( + [summ0, summ1], + truncation=True, + max_length=self.max_length, + padding="max_length", + ) + input_ids = torch.tensor(encodings_dict["input_ids"]) + attention_mask = torch.tensor(encodings_dict["attention_mask"]) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +class AllSummDataset(Dataset): + def __init__(self, train_path, tokenizer, split, max_length=1024): + df = pd.read_parquet(train_path) + if split == "valid": + df = df.sample(n=5000) + self.summarizes = [] + for (i, row) in df.iterrows(): + self.summarizes.append(f"Summarize: {row['text']}. TL;DR: {row['summary']}") + self.tokenizer = tokenizer + self.max_length = max_length + self.input_ids = [] + self.attn_masks = [] + + def __len__(self): + return len(self.summarizes) + + def __getitem__(self, idx): + txt = self.summarizes[idx] + encodings_dict = self.tokenizer( + txt, truncation=True, max_length=self.max_length, padding="max_length" + ) + input_ids = torch.tensor(encodings_dict["input_ids"]) + attn_masks = torch.tensor(encodings_dict["attention_mask"]) + + return { + "input_ids": input_ids, + "attention_mask": attn_masks, + "labels": input_ids, + } diff --git a/examples/summarize_rlhf/sft/train_gptj_summarize.py b/examples/summarize_rlhf/sft/train_gptj_summarize.py new file mode 100755 index 000000000..b0da7d525 --- /dev/null +++ b/examples/summarize_rlhf/sft/train_gptj_summarize.py @@ -0,0 +1,108 @@ +import random + +import evaluate +import numpy as np +import torch +from summarize_dataset import TLDRDataset +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + Trainer, + TrainingArguments, + default_data_collator, +) + + +def set_seed(seed_val=42): + random.seed(seed_val) + np.random.seed(seed_val) + torch.manual_seed(seed_val) + torch.cuda.manual_seed_all(seed_val) + + +if __name__ == "__main__": + output_dir = "gptj-supervised-summarize-checkpoint" + train_batch_size = 16 + gradient_accumulation_steps = 1 + learning_rate = 1e-5 + eval_batch_size = 1 + eval_steps = 500 + max_input_length = 550 + save_steps = 1000 + num_train_epochs = 5 + random.seed(42) + + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", use_cache=False) + tokenizer.pad_token = tokenizer.eos_token + model.resize_token_embeddings(len(tokenizer)) + tokenizer.pad_token_id = tokenizer.eos_token_id + model.config.end_token_id = tokenizer.eos_token_id + model.config.pad_token_id = model.config.eos_token_id + + # Set up the datasets + data_path = "CarperAI/openai_summarize_tldr" + train_dataset = TLDRDataset( + data_path, + tokenizer, + "train", + max_length=max_input_length, + ) + dev_dataset = TLDRDataset( + data_path, + tokenizer, + "valid", + max_length=max_input_length, + ) + + # Set up the metric + rouge = evaluate.load("rouge") + + def compute_metrics(eval_preds): + labels_ids = eval_preds.label_ids + pred_ids 
= eval_preds.predictions + pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True) + result = rouge.compute(predictions=pred_str, references=label_str) + return result + + # Create a preprocessing function to extract out the proper logits from the model output + def preprocess_logits_for_metrics(logits, labels): + if isinstance(logits, tuple): + logits = logits[0] + return logits.argmax(dim=-1) + + # Prepare the trainer and start training + training_args = TrainingArguments( + output_dir=output_dir, + evaluation_strategy="steps", + eval_accumulation_steps=1, + learning_rate=learning_rate, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=eval_batch_size, + gradient_checkpointing=True, + half_precision_backend=True, + fp16=True, + adam_beta1=0.9, + adam_beta2=0.95, + gradient_accumulation_steps=gradient_accumulation_steps, + num_train_epochs=num_train_epochs, + warmup_steps=100, + eval_steps=eval_steps, + save_steps=save_steps, + load_best_model_at_end=True, + logging_steps=50, + deepspeed="./ds_config_gptj.json", + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=dev_dataset, + compute_metrics=compute_metrics, + data_collator=default_data_collator, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + trainer.train() + trainer.save_model(output_dir) diff --git a/examples/summarize_rlhf/trlx_gptj_text_summarization.py b/examples/summarize_rlhf/trlx_gptj_text_summarization.py new file mode 100755 index 000000000..60d2aa052 --- /dev/null +++ b/examples/summarize_rlhf/trlx_gptj_text_summarization.py @@ -0,0 +1,130 @@ +import os +from typing import List + +import torch +from datasets import load_dataset +from reward_model.reward_model import GPTRewardModel +from tqdm import tqdm +from transformers import AutoTokenizer + +import trlx +from trlx.data.configs import TRLConfig + +REWARD_CHECKPOINT_PATH = "reward_model/rm_checkpoint/pytorch_model.bin" +if not os.path.exists(REWARD_CHECKPOINT_PATH): + os.makedirs("reward_model/rm_checkpoint", exist_ok=True) + os.system( + f"wget -O {REWARD_CHECKPOINT_PATH} \ + https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint/resolve/main/pytorch_model.bin" + ) +SFT_MODEL_PATH = "CarperAI/openai_summarize_tldr_sft" + + +if __name__ == "__main__": + + # Load the pre-trained reward model + rw_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + rw_tokenizer.pad_token = rw_tokenizer.eos_token + rw_model = GPTRewardModel(SFT_MODEL_PATH) + rw_model.load_state_dict(torch.load(REWARD_CHECKPOINT_PATH)) + rw_model.half() + rw_model.eval() + rw_device = torch.device("cuda:{}".format(1)) # set reward model device + rw_model.to(rw_device) + + def get_scores(samples: List[str]): + scores_list = [] + batch_size = 2 + for i in range(0, len(samples), batch_size): + sub_samples = samples[i : i + batch_size] + sub_samples = [ + "<|startoftext|>" + chosen + "<|endoftext|>" for chosen in sub_samples + ] + encodings_dict = rw_tokenizer( + sub_samples, + truncation=True, + max_length=config.train.seq_length, + padding="max_length", + return_tensors="pt", + ) + input_ids = encodings_dict["input_ids"].to(rw_device) + attn_masks = encodings_dict["attention_mask"].to(rw_device) + input_ids = input_ids.repeat(2, 1) + attn_masks = attn_masks.repeat(2, 1) + with torch.no_grad(): + sub_scores = rw_model(input_ids=input_ids, attention_mask=attn_masks) + 
scores_list.append(sub_scores["chosen_end_scores"]) + scores = torch.cat(scores_list, dim=0) + return scores + + def get_prompt_dataset(prompts, max_length): + """ + Get the prompt after T5 decoding to make sure dictionary + of prompts and summaries is consistent decode prompt from trlX pipeline + """ + formatted_prompts = [] + for i in tqdm(range(len(prompts))): + tmp = tokenizer.decode( + tokenizer( + prompts[i].split("TL;DR:")[0], + truncation=True, + max_length=max_length + - 5, # to make sure "TL;DR" dont get truncated + )["input_ids"], + skip_special_tokens=True, + ).strip() + tmp = tmp + "\nTL;DR:" + tmp = tokenizer.decode( + tokenizer(tmp, truncation=True, max_length=max_length)["input_ids"], + skip_special_tokens=True, + ).strip() + formatted_prompts.append(tmp) + return formatted_prompts + + def reward_fn(samples: List[str]): + original_samples = [text.split("TL;DR:")[0] + "TL;DR: " for text in samples] + original_samples = [ + text + post_summary_dict[text.strip()] for text in original_samples + ] + original_scores = get_scores(original_samples) + scores = get_scores(samples) + norms_scores = scores - original_scores + return norms_scores + + config = TRLConfig.load_yaml("configs/ppo_config_summ_gptj.yml") + + tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + max_length_input = ( + config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] + ) + + dataset = load_dataset("CarperAI/openai_summarize_tldr") + + # Store data into prompt and label pairs + train_set = [(sample["prompt"], sample["label"]) for sample in dataset["train"]] + val_set = [(sample["prompt"], sample["label"]) for sample in dataset["valid"]] + + # Split contents into summaries and labels + train_posts, train_summaries = zip(*train_set) + val_posts, val_summaries = zip(*val_set) + + # Get the OpenAI summaries + post_summary_dict = {} + train_prompts = get_prompt_dataset(train_posts, max_length_input) + for i in range(len(train_prompts)): + post_summary_dict[train_prompts[i]] = train_summaries[i] + val_prompts = get_prompt_dataset(val_posts, max_length_input) + for i in range(len(val_prompts)): + post_summary_dict[val_prompts[i]] = val_summaries[i] + + trainer = trlx.train( + config.model.model_path, + reward_fn=reward_fn, + prompts=train_prompts, + eval_prompts=val_prompts[ + 0:1000 + ], # sampling 1000 validation prompts for evaluation speed in training + config=config, + ) diff --git a/examples/summarize_rlhf/trlx_inference_gptj.py b/examples/summarize_rlhf/trlx_inference_gptj.py new file mode 100644 index 000000000..254d89915 --- /dev/null +++ b/examples/summarize_rlhf/trlx_inference_gptj.py @@ -0,0 +1,192 @@ +import evaluate +import pandas as pd +import torch +from datasets import load_dataset +from reward_model.reward_model import GPTRewardModel +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def load_model(path): + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + model = AutoModelForCausalLM.from_pretrained(path) + model.config.pad_token_id = tokenizer.bos_token_id + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.bos_token_id + tokenizer.padding_side = "left" + return model, tokenizer + + +rw_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") +rw_tokenizer.pad_token = rw_tokenizer.eos_token +rw_model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft") 
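+# The weights loaded below are the GPT-J reward-model checkpoint produced by
+# reward_model/train_reward_model_gptj.py (a pretrained copy can be downloaded
+# from CarperAI/openai_summarize_tldr_rm_checkpoint, as described in the README)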
+rw_model.load_state_dict(torch.load("reward_model/rm_checkpoint/pytorch_model.bin")) +rw_model.half() +rw_model.eval() +rw_device = torch.device("cuda:{}".format(1)) +rw_model.to(rw_device) + + +def reward_fn(samples): + scores_list = [] + batch_size = 2 + for i in range(0, len(samples), batch_size): + sub_samples = samples[i : i + batch_size] + sub_samples = [ + "<|startoftext|>" + chosen + "<|endoftext|>" for chosen in sub_samples + ] + encodings_dict = rw_tokenizer( + sub_samples, + truncation=True, + max_length=550, + padding="max_length", + return_tensors="pt", + ) + input_ids = encodings_dict["input_ids"].to(rw_device) + attn_masks = encodings_dict["attention_mask"].to(rw_device) + input_ids = input_ids.repeat(2, 1) + attn_masks = attn_masks.repeat(2, 1) + with torch.no_grad(): + sub_scores = rw_model(input_ids=input_ids, attention_mask=attn_masks) + scores_list.append(sub_scores["chosen_end_scores"]) + scores = torch.cat(scores_list, dim=0) + return scores + + +def inference(model, tokenizer): + model.to("cuda") + model.eval() + + pred_list = [] + summarize_list = [] + post_list = [] + rouge = evaluate.load("rouge") + count = 0 + for post, summarize in tqdm( + zip(test_post_list, test_summ_list), total=len(test_post_list) + ): + encode_dict = tokenizer( + post, return_tensors="pt", padding=False, truncation=True + ) + txt_tokens = encode_dict["input_ids"].cuda() + attention_mask = encode_dict["attention_mask"].cuda() + kwargs = {"max_new_tokens": 50, "eos_token_id": 50256, "pad_token_id": 50256} + summ_tokens = model.generate( + txt_tokens, attention_mask=attention_mask, **kwargs + ) + pred = tokenizer.batch_decode(summ_tokens)[0] + pred = pred.split("TL;DR:")[1].replace("<|endoftext|>", "") + pred_list.append(pred) + summarize_list.append(summarize) + post_list.append(post) + if count % 10 == 0: + result = rouge.compute(predictions=pred_list, references=summarize_list) + print(result) + count += 1 + df = pd.DataFrame.from_dict( + {"pred": pred_list, "truth": summarize_list, "post": post_list} + ) + result = rouge.compute(predictions=pred_list, references=summarize_list) + print(result) + return df + + +def inference_batches(model, tokenizer, test_post_list, test_summ_list, batch_size=16): + model.to("cuda") + model.eval() + + pred_list = [] + summarize_list = [] + post_list = [] + rouge = evaluate.load("rouge") + + # Iterate over the input data in mini-batches + for i in tqdm(range(0, len(test_post_list), batch_size)): + batch_post_list = test_post_list[i : i + batch_size] + batch_summ_list = test_summ_list[i : i + batch_size] + + # Convert input data to tensors + encode_dict = tokenizer( + batch_post_list, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ) + txt_tokens = encode_dict["input_ids"].cuda() + attention_mask = encode_dict["attention_mask"].cuda() + + # Perform inference on the batch + kwargs = {"max_new_tokens": 50, "eos_token_id": 50256, "pad_token_id": 50256} + summ_tokens = model.generate( + txt_tokens, attention_mask=attention_mask, **kwargs + ) + + # Decode output tokens + preds = tokenizer.batch_decode(summ_tokens) + + # Add predictions, truths, and input posts to lists + pred_list += preds + summarize_list += batch_summ_list + post_list += batch_post_list + + # Compute rouge scores every 10 mini-batches + result = rouge.compute(predictions=pred_list, references=summarize_list) + print(result) + + # Compute final rouge scores and create a dataframe + result = rouge.compute(predictions=pred_list, references=summarize_list) + 
print(result) + df = pd.DataFrame.from_dict( + {"pred": pred_list, "truth": summarize_list, "post": post_list} + ) + return df + + +if __name__ == "__main__": + + model, tokenizer = load_model("CarperAI/openai_summarize_tldr_sft") + + test_post_list = [ + sample["prompt"] + for sample in load_dataset("CarperAI/openai_summarize_tldr", split="test") + ] + test_summ_list = [ + sample["label"] + for sample in load_dataset("CarperAI/openai_summarize_tldr", split="test") + ] + + df_result = inference(model, tokenizer) + sup_pred = df_result["pred"].values + truth = df_result["truth"].values + + scores_pred = [] + scores_truth = [] + preds_list = [] + truth_list = [] + post_list = [] + batch_size = 16 + for i in range(0, len(df_result), batch_size): + predicts = df_result["pred"].values[i : i + batch_size] + labels = df_result["truth"].values[i : i + batch_size] + posts = df_result["post"].values[i : i + batch_size] + data_pred = [posts[i] + predicts[i] for i in range(len(predicts))] + data_truth = [posts[i] + labels[i] for i in range(len(labels))] + preds_list.extend(list(predicts)) + truth_list.extend(list(labels)) + post_list.extend(list(posts)) + scores_pred.extend(list(reward_fn(data_pred).cpu().numpy())) + scores_truth.extend(list(reward_fn(data_truth).cpu().numpy())) + + df = pd.DataFrame.from_dict( + { + "pred": preds_list, + "truth": truth_list, + "post": post_list, + "score_pred": scores_pred, + "score_truth": scores_truth, + } + ) + df.to_csv("ppo_with_reward_scores.csv", index=False) + print("Reward score pred: ", df.score_pred.values.mean()) + print("Reward score truth: ", df.score_truth.values.mean()) diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py index 1da373365..5729a89f6 100644 --- a/trlx/orchestrator/ppo_orchestrator.py +++ b/trlx/orchestrator/ppo_orchestrator.py @@ -65,6 +65,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq stats = {} clock = Clock() while len(ppo_rl_elements) < num_rollouts: + if self.trainer.accelerator.is_main_process: + print(f"Making experience {len(ppo_rl_elements)} / {num_rollouts}") + # Get next batch in prompt dataset and refresh if exhausted try: batch: PromptBatch = next(self.pipeline_iterator) @@ -182,10 +185,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ref_logits[:, :-1, :], response_tensors[:, 1:] ) else: - logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:]) - ref_logprobs = logprobs_from_logits( - ref_logits[:, :-1, :], all_tokens[:, 1:] - ) + logprobs = logprobs_from_logits(logits, all_tokens) + ref_logprobs = logprobs_from_logits(ref_logits, all_tokens) n = samples.shape[0] logprobs = logprobs.cpu() @@ -206,21 +207,23 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq for ix in range(n) ] else: + logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:]) + ref_logprobs = logprobs_from_logits( + ref_logits[:, :-1, :], all_tokens[:, 1:] + ) + n = samples.shape[0] - values = values.cpu() + values = values.cpu()[:, :-1] logprobs = logprobs.cpu() ref_logprobs = ref_logprobs.cpu() query_tensors = query_tensors.cpu() response_tensors = response_tensors.cpu() - start = ( - query_tensors.shape[1] - 1 - ) # left shift by 1 ref: https://github.com/lvwerra/trl/blob/main/trl/trainer/ppo_trainer.py#L425 - ends = start + attention_mask[:, start:].sum(1) - 1 - for ix in range(n): - if ends[ix] == all_tokens.shape[1]: - ends[ix] = ends[ix] - 1 - all_values = [values[ix, start - 1 : 
ends[ix] - 1] for ix in range(n)] + + start = query_tensors.shape[1] - 1 + ends = start + attention_mask[:, start:].sum(1) + all_values = [values[ix, start : ends[ix]] for ix in range(n)] all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n)] + rewards = -self.trainer.kl_ctl.value * (logprobs - ref_logprobs) rewards = [rs[start : ends[ix]] for ix, rs in enumerate(rewards)] diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index d6d858f7d..5e24eec14 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -219,7 +219,7 @@ def evaluate(self): # noqa: C901 stats = {} all_samples = [] prompts_sizes = [] - lst_prompts = [] + prompts_list = [] generate_time = time() for prompts in self.eval_dataloader: if isinstance(prompts, torch.Tensor): @@ -244,7 +244,7 @@ def evaluate(self): # noqa: C901 len(prompts.input_ids) ) prompts_sizes.append(sizes.to(samples.device)) - lst_prompts.extend(prompts.input_ids) + prompts_list.extend(prompts.input_ids) stats["time/generate"] = time() - generate_time @@ -258,7 +258,7 @@ def evaluate(self): # noqa: C901 if self.tokenizer: prompts, responses = [], [] if self.config.model.model_arch_type == "seq2seq": - prompts = lst_prompts + prompts = prompts_list responses = all_samples else: for sample, prompt_size in zip(samples, prompts_sizes): diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 030e481cc..85da0bd16 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -148,16 +148,17 @@ def loss(self, batch: PPORLBatch): attention_mask = ( tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device) ) - logits, *_, values_pred = self.model( - tokens, - attention_mask=attention_mask, - ) + outputs = self.model(tokens, attention_mask, return_dict=True) + logits = outputs.logits + values_pred = outputs.value + values_pred = values_pred[:, :-1] logprobs = logprobs_from_logits(logits[:, :-1, :], tokens[:, 1:]) + start = query_tensors.shape[1] - 1 end = start + response_length logprobs, values_pred, mask = ( logprobs[:, start:end], - values_pred[:, start - 1 : end - 1], + values_pred[:, start:end], attention_mask[:, start:end], )
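As a sanity check on the indexing changes above, a small self-contained sketch with toy shapes (re-implementing the `logprobs_from_logits` helper rather than importing it): the logit at position t scores token t + 1, so both the log-probabilities and the value predictions are truncated to length T - 1, and the response slice then starts at `query_length - 1`.

```python
import torch
import torch.nn.functional as F


def logprobs_from_logits(logits, labels):
    # Log-probability of each label token under the logits at the same position
    logp = F.log_softmax(logits, dim=-1)
    return torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)


batch, query_len, response_len, vocab = 1, 4, 3, 10
seq_len = query_len + response_len
tokens = torch.randint(0, vocab, (batch, seq_len))   # query followed by response
logits = torch.randn(batch, seq_len, vocab)          # one distribution per input position
values = torch.randn(batch, seq_len)                 # value head output, one per position

# Position t predicts token t + 1: drop the last logit and the first token,
# and truncate the values the same way so everything has length seq_len - 1
logprobs = logprobs_from_logits(logits[:, :-1, :], tokens[:, 1:])
values = values[:, :-1]

start = query_len - 1            # log-prob of the first response token lives here
end = start + response_len
response_logprobs = logprobs[:, start:end]
response_values = values[:, start:end]
assert response_logprobs.shape == (batch, response_len)
assert response_values.shape == (batch, response_len)
```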