From f04446d9bcff37a1215279b1394fc8aff93b47d0 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Mon, 5 Jun 2023 05:29:32 -0700 Subject: [PATCH 01/27] Implementing support for dense rewards --- examples/ppo_redemption.py | 83 +++++++++++++++++++++++++ trlx/trainer/accelerate_base_trainer.py | 10 +-- trlx/trainer/accelerate_ppo_trainer.py | 46 ++++++++++---- 3 files changed, 120 insertions(+), 19 deletions(-) create mode 100644 examples/ppo_redemption.py diff --git a/examples/ppo_redemption.py b/examples/ppo_redemption.py new file mode 100644 index 000000000..b930b2dc7 --- /dev/null +++ b/examples/ppo_redemption.py @@ -0,0 +1,83 @@ +# Generates positive movie reviews by tuning a pretrained model on IMDB dataset +# with a sentiment reward function +import json +import os +import sys +from typing import List + +import torch +from datasets import load_dataset +from transformers import pipeline, AutoTokenizer + +import trlx +from trlx.data.default_configs import TRLConfig, default_ppo_config + + +def get_positive_score(scores): + "Extract value associated with a positive sentiment from pipeline's output" + return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + +def get_negative_score(scores): + return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"] + + +def main(hparams={}): + # Merge sweep config with default config if given + config = TRLConfig.update(default_ppo_config().to_dict(), hparams) + config.method.cliprange_reward = False + config.method.gen_kwargs["max_new_tokens"] = 70 + config.method.gen_kwargs["temperature"] = 0.3 + config.train.total_steps = 20000 + config.train.checkpoint_interval = 10000000 + #config.method.init_kl_coef = 0 + + if torch.cuda.is_available(): + device = int(os.environ.get("LOCAL_RANK", 0)) + else: + device = -1 + + sentiment_fn = pipeline( + "sentiment-analysis", + "lvwerra/distilbert-imdb", + top_k=2, + truncation=True, + batch_size=256, + device=device, + ) + + def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], model_tok, **kwargs) -> List[float]: + # Reward positively for initially negative then positive review + # Reward functions should never receive padded text except for a singel EOS at the end + # Reward function should return token rewards for just the response + # Note: To get trajectory length, the reward fn should not tokenize the samples but should instead separately tokenizer prompts and outputs and then combine them + # Also note outputs has a single EOS at end of each + first_halves = [".".join(sample.split(".")[:len(sample.split(".")) // 2]) for sample in samples] + negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves))) + second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2:]) for sample in samples] + positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves))) + text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)] + tok_scores = [] + for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores): + toks = model_tok(response).input_ids + tok_score = [0] * len(toks) + # Hacky way of assigning intermediate score + tok_score[len(tok_score) // 2] = text_score[0] + tok_score[-1] = text_score[1] + tok_scores.append(tok_score) + return tok_scores + + # Take few words off of movies reviews as prompts + imdb = load_dataset("imdb", split="train+test") + prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] + + trlx.train( + reward_fn=dense_reward_fn, + prompts=prompts, + 
eval_prompts=["I don't know much about Hungarian underground"] * 256, + config=config, + ) + + +if __name__ == "__main__": + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 5c82335c0..18af1333d 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -232,9 +232,7 @@ def decode( # or add one if it was trimmed with `self.stop_sequences`. # When a generation ended due to `max_new_tokens` exhaustion, # only then or token would not be present in the original sample at the end - if append_eos_token and ( - trimmed or sample[-1] == self.tokenizer.eos_token_id or sample[-1] == self.tokenizer.pad_token_id - ): + if append_eos_token: str_output += self.tokenizer.eos_token str_prompts.append(str_prompt) @@ -427,10 +425,8 @@ def evaluate(self): # noqa: C901 # in online setting, compute the reward for validation if self.reward_fn: logger.info("Computing rewards") - rewards = torch.tensor( - self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata), - dtype=float, - ) + rewards = self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) + rewards = torch.tensor([sum(r) if type(r) is list else r for r in rewards], dtype=float) mean_reward = rewards.mean().item() columns.append("reward") if not isinstance(rewards, list): diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index a3af9aa3f..985b79c21 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence import transformers from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -297,21 +298,24 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ) rollout_score_time = time() - all_scores = torch.tensor( - self.reward_fn( - samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata - ), - dtype=torch.float, - device=device, - ) + # reward_fn should return list of rewards at each token per sample + # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed) + all_scores = self.reward_fn(samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, model_tok=self.tokenizer, **metadata) + all_scores = [torch.tensor(score, dtype=torch.float, device=device).view(-1,) for score in all_scores] + # Pad 0 reward on the ends + all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-1) + max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device) + stats["time/rollout_score"] = time() - rollout_score_time - all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind()) + all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind()) else: all_scores = None + max_len = torch.tensor(0, dtype=torch.long, device=device) if torch.distributed.is_initialized(): - scores = torch.empty(len(samples), device=device) + torch.distributed.broadcast(max_len, 0) + scores = torch.empty((len(samples), max_len), device=device) torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() @@ -342,7 +346,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # 
noq # store statistics of the initial rollout as reference if self.ref_mean is None: - self.ref_mean, self.ref_std = scores.mean(), scores.std() + self.ref_mean, self.ref_std = scores.sum(dim=1).mean(), scores.sum(dim=1).std() all_scores_mean, all_scores_std = self.running_moments.update(scores) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() @@ -415,6 +419,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) else: + # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) @@ -425,6 +430,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq attention_mask = sample_outputs != self.tokenizer.pad_token_id start = 0 else: + # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response start = prompt_tensors.shape[1] - 1 log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] @@ -436,12 +442,16 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ref_logprobs = ref_logprobs.cpu() prompt_tensors = prompt_tensors.cpu() sample_outputs = sample_outputs.cpu() + # TODO(dahoas): Why [:, :-1]? Redudant with clipping via start : ends[ix]? + # Actually I think it's just wrong? values = values.cpu()[:, :-1] # Get the logprobs and values, for tokens that are not padding, - # from the start of the prompt up to the token, while also including the latter + # from the end of the prompt up to the token, while also including the latter # (these are taken from the student model and not the reference model) ends = start + attention_mask[:, start:].sum(1) + 1 + # NOTE: values[i] is the value of the state after response token i + # TODO(dahoas): Does it actually make sense to get the rewards one step early? 
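# Illustrative sketch (not from this patch; numbers invented) of how a dense,
# per-token score combines with the KL penalty when the loop further below
# builds `rewards` for one sample. pad_sequence(..., padding_value=-1) is what
# introduces the trailing -1 slots that the dense branch strips off again.
import torch

kl_penalty = torch.tensor([-0.02, -0.01, -0.03, -0.01])  # one entry per response token
score = torch.tensor([0.0, 0.4, 0.9, -1.0])              # trailing -1 is padding, not a reward
score = score[: int((score != -1).sum())]                # drop the right-padding
dense = torch.zeros_like(kl_penalty)
dense[: score.shape[0]] += score
rewards = kl_penalty + dense                             # sparse case instead does rewards[-1] += score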
all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] @@ -451,8 +461,20 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_count = 0 for sample_idx in range(n_samples): + # To compute per token reward first add in kl penalties over trajectory + # NOTE: kl_penalty[i] is kl_diff at token i+1 in the output (w/o EOS) rewards = kl_penalty[sample_idx] - rewards[-1] += scores[sample_idx].cpu() + # Then add in rewards + if scores.shape[1] == 1: + # NOTE: Final reward given at EOS token following HHH practice + rewards[-1] += scores[sample_idx][0].cpu() + else: + score = scores[sample_idx] + score_right_padding = torch.sum(score != -1) + score = score[:score_right_padding].cpu() + p_score = torch.zeros_like(rewards) + p_score[:score.shape[0]] += score + rewards += p_score ppo_rl_elements.append( PPORLElement( From 13a01fc6f986e36c77572d7d6732ceadc213b098 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Fri, 16 Jun 2023 06:42:41 -0700 Subject: [PATCH 02/27] added "num_return_sequences" param which corresponds to n in Best-of-N sampling --- trlx/data/default_configs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 5277d7010..9f82a5ba3 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -54,6 +54,7 @@ def default_ppo_config(): top_k=0, top_p=1.0, do_sample=True, + num_return_sequences=16, ), ), ) From 5421a73bd680cc328db5b96ce1a3768243da8682 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Fri, 16 Jun 2023 07:20:21 -0700 Subject: [PATCH 03/27] updates to "num_return_sequences" param --- trlx/data/default_configs.py | 2 +- trlx/models/modeling_ppo.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 9f82a5ba3..2b9b67b52 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,12 +49,12 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, + num_return_sequences=10, gen_kwargs=dict( max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True, - num_return_sequences=16, ), ), ) diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 82d3ec637..eba137802 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -130,6 +130,7 @@ class PPOConfig(MethodConfig): ref_std: Optional[float] cliprange_reward: float gen_kwargs: dict + num_return_sequences: int gen_experience_kwargs: Optional[dict] = None def get_advantages_and_returns( From 2f3ac2816e60af5aeb9f2b8eac5e16a8465e9616 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Fri, 16 Jun 2023 07:30:25 -0700 Subject: [PATCH 04/27] BoN implementation --- trlx/trainer/accelerate_ppo_trainer.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 985b79c21..32e6a860c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -274,10 +274,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"]) + samples = self.generate(batch["input_ids"], batch["attention_mask"], 
num_return_sequences=self.config.method.num_return_sequences) stats["time/rollout_generate"] = time() - rollout_generate_time - prompt_tensors = batch.input_ids + prompt_tensors = batch.input_ids.repeat_interleave(self.config.method.num_return_sequences, dim=0) # TODO: It is hard-coded to 10 here. Change it to a variable device = samples.device prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) @@ -319,6 +319,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() + # Best-of-N Sampling. + max_score_indices = self.get_max_indices(scores, self.config.method.num_return_sequences, device) + scores = scores.index_select(0, max_score_indices) + samples = samples.index_select(0, max_score_indices) + prompt_tensors = prompt_tensors.index_select(0, max_score_indices) str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) @@ -507,3 +512,16 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # Push samples and rewards to trainer's rollout storage self.push_to_store(ppo_rl_elements) + + @staticmethod + def get_max_indices(input_tensor, window_size, device): + # Use unfold to create the sliding windows + unfolded = input_tensor.unfold(0, window_size, window_size) + + # Find the max values and indices along the unfolded dimension + values, indices = unfolded.max(dim=2) + + # Adjust indices to be relative to original tensor + indices += torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1) + + return indices.squeeze() From 2f1dace62a637ded875ff7955e16d77e64ac0419 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 19 Jun 2023 03:13:37 -0700 Subject: [PATCH 05/27] Changed back to default. 
--- trlx/data/default_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 2b9b67b52..57adeea8b 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,7 +49,7 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, - num_return_sequences=10, + num_return_sequences=1, gen_kwargs=dict( max_new_tokens=40, top_k=0, From f58170dc3022f1c21f7bd53c5c88882984240751 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 19 Jun 2023 05:25:32 -0700 Subject: [PATCH 06/27] TopK sampling instead of Top1 --- trlx/data/default_configs.py | 1 + trlx/models/modeling_ppo.py | 1 + trlx/trainer/accelerate_ppo_trainer.py | 21 +++++++++------------ 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 57adeea8b..b29202628 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -50,6 +50,7 @@ def default_ppo_config(): ref_std=None, cliprange_reward=10, num_return_sequences=1, + num_train_sequences=1, gen_kwargs=dict( max_new_tokens=40, top_k=0, diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index eba137802..45c7780d6 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -131,6 +131,7 @@ class PPOConfig(MethodConfig): cliprange_reward: float gen_kwargs: dict num_return_sequences: int + num_train_sequences: int gen_experience_kwargs: Optional[dict] = None def get_advantages_and_returns( diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 32e6a860c..ec3594174 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -320,10 +320,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq else: scores = all_scores[0].clone().detach() # Best-of-N Sampling. 
- max_score_indices = self.get_max_indices(scores, self.config.method.num_return_sequences, device) - scores = scores.index_select(0, max_score_indices) - samples = samples.index_select(0, max_score_indices) - prompt_tensors = prompt_tensors.index_select(0, max_score_indices) + train_indices = self.get_topk_indices(input_tensor=scores, window_size=self.config.method.num_return_sequences,k=self.config.method.num_train_sequences, device=device) + scores = scores.index_select(0, train_indices) + samples = samples.index_select(0, train_indices) + prompt_tensors = prompt_tensors.index_select(0, train_indices) str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) @@ -514,14 +514,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq self.push_to_store(ppo_rl_elements) @staticmethod - def get_max_indices(input_tensor, window_size, device): + def get_topk_indices(input_tensor, window_size: int, k: int, device): # Use unfold to create the sliding windows unfolded = input_tensor.unfold(0, window_size, window_size) - - # Find the max values and indices along the unfolded dimension - values, indices = unfolded.max(dim=2) - + # Find the topk values and indices along the unfolded dimension + _, indices = torch.topk(unfolded, k, dim=2) # Adjust indices to be relative to original tensor - indices += torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1) - - return indices.squeeze() + indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1) + return indices.reshape(-1) From be8bc1a27929157dfd50c2f7a053ee91083e0f76 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 03:10:16 -0700 Subject: [PATCH 07/27] summed along dim=1 --- trlx/trainer/accelerate_ppo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index ec3594174..cd7fd90a2 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -515,6 +515,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq @staticmethod def get_topk_indices(input_tensor, window_size: int, k: int, device): + # Sum the scores along dim 1 + input_tensor = input_tensor.sum(1) # Use unfold to create the sliding windows unfolded = input_tensor.unfold(0, window_size, window_size) # Find the topk values and indices along the unfolded dimension From 608d812478bb6e38a2e86296de604572bddcb3cc Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 06:45:56 -0700 Subject: [PATCH 08/27] Generating samples in chunks --- trlx/trainer/accelerate_base_trainer.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 18af1333d..60e6233f7 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -247,8 +247,9 @@ def decode( return str_samples, str_prompts, str_outputs - def generate(self, input_ids, attention_mask=None, **kwargs): + def generate(self, input_ids, attention_mask=None, chunk_size=4, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" + # Decide into chunk sizes and generate saples input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) 
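# Standalone toy (values invented) mirroring the windowed top-k selection that the
# PPO trainer's get_topk_indices helper above performs: completions are laid out
# prompt-by-prompt, so row i*window + j is the j-th return sequence of prompt i,
# and only the k best-scored completions inside each window are kept for training.
import torch

scores = torch.tensor([0.1, 0.9, 0.4, 0.3,   # prompt 0: four sampled completions
                       0.8, 0.2, 0.7, 0.6])  # prompt 1
window, k = 4, 2
windows = scores.unfold(0, window, window)                        # (num_prompts, window)
_, idx = windows.topk(k, dim=1)                                   # best k inside each window
idx = idx + torch.arange(0, scores.numel(), window).unsqueeze(1)  # back to flat row indices
print(idx.reshape(-1))                                            # tensor([1, 2, 4, 6])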
@@ -256,11 +257,23 @@ def generate(self, input_ids, attention_mask=None, **kwargs): kwargs = dict(self.generate_experience_kwargs, **kwargs) else: kwargs = dict(self.generate_kwargs, **kwargs) - - with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs - ) + # Chunk input_ids and attention_mask + + input_ids = input_ids.chunk(chunk_size, 0) + if attention_mask is not None: + attention_mask = attention_mask.chunk(chunk_size, 0) + + samples = [] + for chunk_idx in range(chunk_size): + with torch.no_grad(): + sample = self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs + ) + samples.append(sample) + # Concat samples + samples = torch.cat(samples, 0) + return samples + def generate_eval(self, input_ids, attention_mask=None, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" From d8557e73002de89764d813753dacac27f5afee82 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 08:09:10 -0700 Subject: [PATCH 09/27] added gen_chunk_size parameter --- trlx/data/default_configs.py | 5 +++-- trlx/models/modeling_ppo.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index b29202628..e49b46f65 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,8 +49,9 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, - num_return_sequences=1, - num_train_sequences=1, + num_return_sequences=10, + num_train_sequences=10, + gen_chunk_size=4, gen_kwargs=dict( max_new_tokens=40, top_k=0, diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 45c7780d6..5bb808b41 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -132,6 +132,7 @@ class PPOConfig(MethodConfig): gen_kwargs: dict num_return_sequences: int num_train_sequences: int + gen_chunk_size: int gen_experience_kwargs: Optional[dict] = None def get_advantages_and_returns( From 8ef9c36622cab21bab46dd1e3a60250bf587113c Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 08:09:30 -0700 Subject: [PATCH 10/27] chunking in forward prop --- trlx/trainer/accelerate_ppo_trainer.py | 86 ++++++++++++++++---------- 1 file changed, 55 insertions(+), 31 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index cd7fd90a2..2ac901f03 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -274,7 +274,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"], num_return_sequences=self.config.method.num_return_sequences) + samples = self.generate(batch["input_ids"], batch["attention_mask"], chunk_size=self.config.method.gen_chunk_size, num_return_sequences=self.config.method.num_return_sequences) stats["time/rollout_generate"] = time() - rollout_generate_time prompt_tensors = batch.input_ids.repeat_interleave(self.config.method.num_return_sequences, dim=0) # TODO: It is hard-coded to 10 here. 
Change it to a variable @@ -395,39 +395,63 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq return_dict=True, ).logits else: + values_chunks = [] + logits_chunks = [] + ref_logits_chunks = [] + log_probs_chunks = [] + ref_logprobs_chunks = [] all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - with torch.no_grad(): - logits, *_, values = self.model( - all_tokens, attention_mask=attention_mask, position_ids=position_ids - ) - # TODO(dahoas): When hydra model works need to also support generation on hydra head - if hasattr(self.model, "frozen_head") or self.model.peft_type: - ref_logits = self.model.forward_hydra( - all_tokens, - attention_mask=attention_mask, + all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.gen_chunk_size, dim=0) + attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) + position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) + for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip(all_tokens_chunks, attention_mask_chunks, position_ids_chunks): + with torch.no_grad(): + logits, *_, values = self.model( + all_tokens_chunk, + attention_mask=attention_mask_chunk, position_ids=position_ids, - return_dict=True, - ).logits + ) + # TODO(dahoas): When hydra model works need to also support generation on hydra head + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids, + return_dict=True, + ).logits + elif hasattr(self, "ref_model"): + ref_logits = self.ref_model( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids, + return_dict=True, + ).logits + ref_logits = ref_logits.to(device) + else: + ref_logits = logits.clone().detach() + if self.config.model.model_arch_type == "seq2seq": + logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) else: - ref_logits = self.ref_model( - all_tokens, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - ).logits - ref_logits = ref_logits.to(device) - - if self.config.model.model_arch_type == "seq2seq": - logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) - else: - # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled - logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) - + # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled + logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens_chunk[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens_chunk[:, 1:]) + + values_chunks.append(values.cpu()) + logits_chunks.append(logits.cpu()) + ref_logits_chunks.append(ref_logits.cpu()) + log_probs_chunks.append(logprobs.cpu()) + ref_logprobs_chunks.append(ref_logprobs.cpu()) + + values = torch.cat(values_chunks, dim=0) + logits = torch.cat(logits_chunks, dim=0) + ref_logits = torch.cat(ref_logits_chunks, dim=0) + logprobs = torch.cat(log_probs_chunks, dim=0) + 
ref_logprobs = torch.cat(ref_logprobs_chunks, dim=0) + n_samples: int = samples.shape[0] # Estimate the KL divergence between the model and reference model @@ -437,7 +461,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq else: # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response start = prompt_tensors.shape[1] - 1 - + attention_mask = attention_mask.cpu() log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] kl = log_ratio.exp() - 1 - log_ratio mean_kl_per_token = kl.mean() @@ -494,7 +518,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_count += 1 if torch.distributed.is_initialized(): - torch.distributed.all_reduce(mean_kl, torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce(mean_kl.to(self.accelerator.device), torch.distributed.ReduceOp.AVG) stats["time/rollout_time"] = clock.tick() stats["policy/sqrt_kl"] = torch.sqrt(mean_kl).item() @@ -516,7 +540,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq @staticmethod def get_topk_indices(input_tensor, window_size: int, k: int, device): # Sum the scores along dim 1 - input_tensor = input_tensor.sum(1) + input_tensor = input_tensor.sum(1).unsqueeze(1) # Use unfold to create the sliding windows unfolded = input_tensor.unfold(0, window_size, window_size) # Find the topk values and indices along the unfolded dimension From 4c1d82df50884a3d731d43053da0c63b15ef4508 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 08:13:52 -0700 Subject: [PATCH 11/27] chunking generations in train and eval --- trlx/trainer/accelerate_base_trainer.py | 58 ++++++++++++++++--------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 60e6233f7..1efdb8195 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -257,25 +257,28 @@ def generate(self, input_ids, attention_mask=None, chunk_size=4, **kwargs): kwargs = dict(self.generate_experience_kwargs, **kwargs) else: kwargs = dict(self.generate_kwargs, **kwargs) - # Chunk input_ids and attention_mask - - input_ids = input_ids.chunk(chunk_size, 0) - if attention_mask is not None: - attention_mask = attention_mask.chunk(chunk_size, 0) - - samples = [] - for chunk_idx in range(chunk_size): + if chunk_size is not None: + # Chunk input_ids and attention_mask + input_ids = input_ids.chunk(chunk_size, 0) + if attention_mask is not None: + attention_mask = attention_mask.chunk(chunk_size, 0) + samples = [] + for chunk_idx in range(chunk_size): + with torch.no_grad(): + sample = self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs + ) + samples.append(sample) + # Concat samples + return torch.cat(samples, 0) + else: with torch.no_grad(): - sample = self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs + return self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids, attention_mask=attention_mask, **kwargs ) - samples.append(sample) - # Concat samples - samples = torch.cat(samples, 0) - return samples - + - def generate_eval(self, input_ids, attention_mask=None, **kwargs): + def generate_eval(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): """Wraps hf's `generate` adding some specific method's 
defaults""" input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: @@ -283,10 +286,25 @@ def generate_eval(self, input_ids, attention_mask=None, **kwargs): kwargs = dict(self.generate_kwargs, **kwargs) - with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs - ) + if chunk_size is not None: + # Chunk input_ids and attention_mask + input_ids = input_ids.chunk(chunk_size, 0) + if attention_mask is not None: + attention_mask = attention_mask.chunk(chunk_size, 0) + samples = [] + for chunk_idx in range(chunk_size): + with torch.no_grad(): + sample = self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs + ) + samples.append(sample) + # Concat samples + return torch.cat(samples, 0) + else: + with torch.no_grad(): + return self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids, attention_mask=attention_mask, **kwargs + ) def save_pretrained(self, directory: Optional[str] = None, **kwargs): """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for From ecd5107e3f119d6f84951cdc1f59ced2d819862b Mon Sep 17 00:00:00 2001 From: Dahoas Date: Mon, 5 Jun 2023 05:29:32 -0700 Subject: [PATCH 12/27] Implementing support for dense rewards --- trlx/trainer/accelerate_base_trainer.py | 7 ++++++- trlx/trainer/accelerate_ppo_trainer.py | 9 ++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 1efdb8195..3e11d0c53 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -457,7 +457,12 @@ def evaluate(self): # noqa: C901 if self.reward_fn: logger.info("Computing rewards") rewards = self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) - rewards = torch.tensor([sum(r) if type(r) is list else r for r in rewards], dtype=float) + if type(rewards[0]) is torch.Tensor: + rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float) + elif type(rewards[0]) is list: + rewards = torch.tensor([sum(reward) for reward in rewards]) + else: + rewards = torch.tensor(rewards) mean_reward = rewards.mean().item() columns.append("reward") if not isinstance(rewards, list): diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 2ac901f03..ad2d5167c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -319,11 +319,14 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() - # Best-of-N Sampling. + # Best-of-N Sampling. 
+ scores_mask = scores != -1 train_indices = self.get_topk_indices(input_tensor=scores, window_size=self.config.method.num_return_sequences,k=self.config.method.num_train_sequences, device=device) scores = scores.index_select(0, train_indices) samples = samples.index_select(0, train_indices) prompt_tensors = prompt_tensors.index_select(0, train_indices) + scores_mask = scores_mask[train_indices] + str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) @@ -351,7 +354,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # store statistics of the initial rollout as reference if self.ref_mean is None: - self.ref_mean, self.ref_std = scores.sum(dim=1).mean(), scores.sum(dim=1).std() + self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum(dim=1).std() all_scores_mean, all_scores_std = self.running_moments.update(scores) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() @@ -499,7 +502,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rewards[-1] += scores[sample_idx][0].cpu() else: score = scores[sample_idx] - score_right_padding = torch.sum(score != -1) + score_right_padding = torch.sum(scores_mask[sample_idx]) score = score[:score_right_padding].cpu() p_score = torch.zeros_like(rewards) p_score[:score.shape[0]] += score From 4071604cbff3983d6d447ee9f64ecc947341f502 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Thu, 15 Jun 2023 08:20:11 -0700 Subject: [PATCH 13/27] Fix distributed ref_mean, ref_var bug for dense rewards --- trlx/trainer/accelerate_ppo_trainer.py | 2 +- trlx/utils/modeling.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index ad2d5167c..d772a8653 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -355,7 +355,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # store statistics of the initial rollout as reference if self.ref_mean is None: self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum(dim=1).std() - all_scores_mean, all_scores_std = self.running_moments.update(scores) + all_scores_mean, all_scores_std = self.running_moments.update(scores, scores_mask) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() stats["rollout_scores/running_mean"] = self.running_moments.mean.item() diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 47688f553..c6f3dd8ee 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -1,5 +1,5 @@ import functools -from typing import Dict, MutableMapping, Tuple, Union +from typing import Any, Dict, List, MutableMapping, Tuple, Union, Optional import accelerate import numpy as np @@ -276,8 +276,11 @@ def __init__(self): self.var = 1 self.count = 1e-24 - def update(self, xs: torch.Tensor) -> Tuple[float, float]: + def update(self, xs: torch.Tensor, xs_mask: Optional[torch.Tensor] = None) -> Tuple[float, float]: """Updates running moments from batch's moments computed across ranks""" + if xs_mask is None: + xs_mask = torch.ones_like(xs) + xs = torch.sum(xs * xs_mask, dim=1) if dist.is_initialized(): xs_mean, xs_var, xs_count = get_global_statistics(xs) else: From 5f41413bb3b1d8e788f13a76bfac75fe5355c4f9 Mon Sep 17 00:00:00 2001 From: Dahoas 
Date: Fri, 23 Jun 2023 07:57:15 -0700 Subject: [PATCH 14/27] Make generation respect max seq length --- trlx/trainer/accelerate_base_trainer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 3e11d0c53..780759422 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -253,10 +253,16 @@ def generate(self, input_ids, attention_mask=None, chunk_size=4, **kwargs): input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) + # Update max_new_tokens to respect max_seq_length + prompt_length = input_ids.shape[1] if self.generate_experience_kwargs is not None: kwargs = dict(self.generate_experience_kwargs, **kwargs) else: kwargs = dict(self.generate_kwargs, **kwargs) + if kwargs.get("max_new_tokens") is not None: + kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), kwargs["max_new_tokens"]) + else: + kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) if chunk_size is not None: # Chunk input_ids and attention_mask input_ids = input_ids.chunk(chunk_size, 0) @@ -286,6 +292,11 @@ def generate_eval(self, input_ids, attention_mask=None, chunk_size=None, **kwarg kwargs = dict(self.generate_kwargs, **kwargs) + if kwargs.get("max_new_tokens") is not None: + kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), kwargs["max_new_tokens"]) + else: + kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) + if chunk_size is not None: # Chunk input_ids and attention_mask input_ids = input_ids.chunk(chunk_size, 0) From 22ae83f5e1ffd96a0afb8eddf440bf4f6340d13c Mon Sep 17 00:00:00 2001 From: Dahoas Date: Fri, 23 Jun 2023 08:26:37 -0700 Subject: [PATCH 15/27] Make experience before first round of training --- trlx/trainer/accelerate_ppo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index d772a8653..2f0cc8dd7 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -229,6 +229,8 @@ def prepare_learning(self): self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False) + self.make_experience(self.config.method.num_rollouts) + self.n_updates_per_batch = self.config.method.ppo_epochs self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) From 7d0a4be143530167f2e9f7061dfbe30d838b2d12 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Tue, 27 Jun 2023 04:33:44 -0700 Subject: [PATCH 16/27] Refactoring .generate/.generate_eval --- trlx/data/default_configs.py | 5 +- trlx/models/modeling_ppo.py | 10 ++- trlx/trainer/accelerate_base_trainer.py | 89 +++++++++---------------- trlx/trainer/accelerate_ppo_trainer.py | 25 ++++--- 4 files changed, 56 insertions(+), 73 deletions(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index e49b46f65..3acee97ab 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,14 +49,13 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, - num_return_sequences=10, - num_train_sequences=10, - gen_chunk_size=4, + num_train_sequences=1, gen_kwargs=dict( max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True, + num_return_sequences=1, ), ), ) diff --git 
a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 5bb808b41..bd0f57f88 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -112,6 +112,12 @@ class PPOConfig(MethodConfig): :param gen_experience_kwargs: if this is not None, then the experience is generated using this :type gen_experience_kwargs: Dict[str, Any] + + :param num_train_sequences: top_k of n sampled sequences from prompt + :type num_train_sequences: int + + :param mix_sft: if this is True, then SFT gradients will be mixed into PPO traininig + :type mix_sft: bool """ ppo_epochs: int @@ -130,10 +136,8 @@ class PPOConfig(MethodConfig): ref_std: Optional[float] cliprange_reward: float gen_kwargs: dict - num_return_sequences: int - num_train_sequences: int - gen_chunk_size: int gen_experience_kwargs: Optional[dict] = None + num_train_sequences: int = 1 def get_advantages_and_returns( self, diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 780759422..1b7ad8423 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -6,6 +6,7 @@ from contextlib import contextmanager from time import time from typing import Dict, List, Optional, Tuple +from copy import copy import ray import torch @@ -247,75 +248,49 @@ def decode( return str_samples, str_prompts, str_outputs - def generate(self, input_ids, attention_mask=None, chunk_size=4, **kwargs): + def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" # Decide into chunk sizes and generate saples input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) + + generate_kwargs = copy(self.generate_kwargs) + generate_kwargs.update(kwargs) + # Update max_new_tokens to respect max_seq_length prompt_length = input_ids.shape[1] - if self.generate_experience_kwargs is not None: - kwargs = dict(self.generate_experience_kwargs, **kwargs) - else: - kwargs = dict(self.generate_kwargs, **kwargs) - if kwargs.get("max_new_tokens") is not None: - kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), kwargs["max_new_tokens"]) - else: - kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) - if chunk_size is not None: - # Chunk input_ids and attention_mask - input_ids = input_ids.chunk(chunk_size, 0) - if attention_mask is not None: - attention_mask = attention_mask.chunk(chunk_size, 0) - samples = [] - for chunk_idx in range(chunk_size): - with torch.no_grad(): - sample = self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs - ) - samples.append(sample) - # Concat samples - return torch.cat(samples, 0) + if generate_kwargs.get("max_new_tokens") is not None: + generate_kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"]) else: - with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs - ) - + generate_kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) - def generate_eval(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): - """Wraps hf's `generate` adding some specific method's defaults""" - input_ids = input_ids.to(self.accelerator.device) + # Repeat prompts, attention_masks for chunking if returning multiple sequences + if 
generate_kwargs.get("num_return_sequences") is None: + generate_kwargs["num_return_sequences"] = 1 + + num_return_sequences = generate_kwargs.pop("num_return_sequences") # Pop to hide from model.generate call + input_ids = input_ids.repeat_interleave(num_return_sequences, dim=0) if attention_mask is not None: - attention_mask = attention_mask.to(self.accelerator.device) + attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) - kwargs = dict(self.generate_kwargs, **kwargs) + if chunk_size is None: + chunk_size = input_ids.shape[0] - if kwargs.get("max_new_tokens") is not None: - kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), kwargs["max_new_tokens"]) - else: - kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) - - if chunk_size is not None: - # Chunk input_ids and attention_mask - input_ids = input_ids.chunk(chunk_size, 0) - if attention_mask is not None: - attention_mask = attention_mask.chunk(chunk_size, 0) - samples = [] - for chunk_idx in range(chunk_size): - with torch.no_grad(): - sample = self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs - ) - samples.append(sample) - # Concat samples - return torch.cat(samples, 0) - else: + # Chunk input_ids and attention_mask + input_ids = input_ids.split(chunk_size, dim=0) + if attention_mask is not None: + attention_mask = attention_mask.split(chunk_size, dim=0) + samples = [] + for chunk_idx in range(len(input_ids)): with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs + sample = self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **generate_kwargs ) + samples.append(sample) + # Concat samples + samples = torch.cat(samples, 0) + return samples def save_pretrained(self, directory: Optional[str] = None, **kwargs): """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for @@ -417,11 +392,11 @@ def evaluate(self): # noqa: C901 for i_prompt, prompts in enumerate(self.eval_dataloader): metadata = {k: v for k, v in prompts.items() if k != "input_ids" and k != "attention_mask"} if self.generate_sweep_kwarg: - samples = self.generate_eval( + samples = self.generate( prompts["input_ids"], prompts["attention_mask"], **{gen_sweep_arg: gen_sweep_value} ) else: - samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"]) + samples = self.generate(prompts["input_ids"], prompts["attention_mask"]) # TODO(reciprocated): this should be moved into `decode` # but that needs to be synced with indexing in `make_experience` diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 2f0cc8dd7..cdbe12472 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -3,6 +3,7 @@ import uuid from time import time from typing import Callable, List +from copy import copy import torch import torch.nn.functional as F @@ -268,6 +269,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ppo_rl_elements = [] accumulated_stats = [] + # Require chunk_size * num_train_sequences divides num_rollouts + assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_train_sequences) == 0 + while len(ppo_rl_elements) < num_rollouts: stats = {} # Get next batch in prompt dataset @@ -276,10 
+280,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"], chunk_size=self.config.method.gen_chunk_size, num_return_sequences=self.config.method.num_return_sequences) + samples = self.generate(batch["input_ids"], batch["attention_mask"], chunk_size=self.config.method.chunk_size, **self.generate_experience_kwargs) stats["time/rollout_generate"] = time() - rollout_generate_time - prompt_tensors = batch.input_ids.repeat_interleave(self.config.method.num_return_sequences, dim=0) # TODO: It is hard-coded to 10 here. Change it to a variable + num_return_sequences = self.generate_experience_kwargs["num_return_sequences"] if self.generate_experience_kwargs.get("num_return_sequences") is not None else 1 + prompt_tensors = batch.input_ids.repeat_interleave(num_return_sequences, dim=0) device = samples.device prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) @@ -323,12 +328,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq scores = all_scores[0].clone().detach() # Best-of-N Sampling. scores_mask = scores != -1 - train_indices = self.get_topk_indices(input_tensor=scores, window_size=self.config.method.num_return_sequences,k=self.config.method.num_train_sequences, device=device) - scores = scores.index_select(0, train_indices) - samples = samples.index_select(0, train_indices) - prompt_tensors = prompt_tensors.index_select(0, train_indices) + train_indices = self.get_topk_indices(input_tensor=scores_mask*scores, window_size=num_return_sequences, k=self.config.method.num_train_sequences, device=device) + scores = scores[train_indices] scores_mask = scores_mask[train_indices] - + samples = samples[train_indices] + prompt_tensors = prompt_tensors[train_indices] str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) @@ -417,21 +421,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq logits, *_, values = self.model( all_tokens_chunk, attention_mask=attention_mask_chunk, - position_ids=position_ids, + position_ids=position_ids_chunk, ) # TODO(dahoas): When hydra model works need to also support generation on hydra head if hasattr(self.model, "frozen_head"): ref_logits = self.model.forward_hydra( all_tokens_chunk, attention_mask=attention_mask_chunk, - position_ids=position_ids, + position_ids=position_ids_chunk, return_dict=True, ).logits elif hasattr(self, "ref_model"): ref_logits = self.ref_model( all_tokens_chunk, attention_mask=attention_mask_chunk, - position_ids=position_ids, + position_ids=position_ids_chunk, return_dict=True, ).logits ref_logits = ref_logits.to(device) @@ -466,6 +470,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq else: # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response start = prompt_tensors.shape[1] - 1 + attention_mask = attention_mask.cpu() log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] kl = log_ratio.exp() - 1 - log_ratio From b79dd19915cb93fadb752d8f7740166bed303091 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Thu, 29 Jun 2023 10:03:10 -0700 Subject: [PATCH 17/27] Fix BoN metric support --- trlx/trainer/accelerate_base_trainer.py | 19 ++++++++++++++++++- trlx/trainer/accelerate_ppo_trainer.py | 2 +- 2 files 
changed, 19 insertions(+), 2 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 1b7ad8423..9695ebb04 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -398,6 +398,14 @@ def evaluate(self): # noqa: C901 else: samples = self.generate(prompts["input_ids"], prompts["attention_mask"]) + # Repeat prompts, metadata num_return_sequence times + num_return_sequences = 1 + if self.generate_kwargs.get("num_return_sequences") is not None: + num_return_sequences = self.generate_kwargs["num_return_sequences"] + prompts["input_ids"] = prompts["input_ids"].repeat_interleave(num_return_sequences, dim=0) + prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(num_return_sequences, dim=0) + metadata = {k: self.repeat_interleave(v, num_return_sequences) for k, v in metadata.items()} + # TODO(reciprocated): this should be moved into `decode` # but that needs to be synced with indexing in `make_experience` if self.config.model.model_arch_type == "seq2seq": @@ -460,7 +468,7 @@ def evaluate(self): # noqa: C901 if self.metric_fn: logger.info("Computing metrics") metric_time = time() - metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata) + metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) stats["time/metric"] = time() - metric_time mean_metrics = { @@ -651,6 +659,15 @@ def learn(self): # noqa: C901 self.post_epoch_callback() tbar.close() + @staticmethod + def repeat_interleave(l, n): + if type(l) is torch.Tensor: + l = l.repeat_interleave(n, dim=0) + elif type(l) is list: + l = [[s]*n for s in l] + l = [item for sublist in l for item in sublist] + return l + @abstractmethod def get_arch(self, config: TRLConfig): """Returns a specific wrapper of the decoder architecture""" diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index cdbe12472..a14c0752d 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -297,7 +297,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq gathered_samples = self.accelerator.gather(padded_samples) gathered_prompts = self.accelerator.gather(padded_prompts) gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) - metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) + metadata = gather_dict({k: self.repeat_interleave(v, num_return_sequences) for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) if self.accelerator.is_main_process: all_str_samples, all_str_prompts, all_str_outputs = self.decode( From cb49dc538c592b78651157947d601741e4967247 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Mon, 3 Jul 2023 04:27:55 -0700 Subject: [PATCH 18/27] Enforce chunk_size param for eval generation when present --- trlx/trainer/accelerate_base_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 9695ebb04..60c2ae622 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -396,7 +396,8 @@ def evaluate(self): # noqa: C901 prompts["input_ids"], prompts["attention_mask"], **{gen_sweep_arg: gen_sweep_value} ) else: - samples = self.generate(prompts["input_ids"], prompts["attention_mask"]) + chunk_size = 
self.config.method.chunk_size if hasattr(self.config.method, "chunk_size") else None + samples = self.generate(prompts["input_ids"], prompts["attention_mask"], chunk_size=chunk_size) # Repeat prompts, metadata num_return_sequence times num_return_sequences = 1 From e290412541409206d51ece8f81309c28143af44f Mon Sep 17 00:00:00 2001 From: Dahoas Date: Tue, 4 Jul 2023 07:22:50 -0700 Subject: [PATCH 19/27] Fix: Don't shuffle prompt dataset --- trlx/trainer/accelerate_ppo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index a14c0752d..cb553432f 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -238,7 +238,7 @@ def prepare_learning(self): def add_prompt_pipeline(self, pipeline: PromptPipeline): """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" - prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) + prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=False) prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) self.prompt_iterator = infinite_dataloader(prompt_dataloader) From 391d04cd51a1ba3d63d4b4421fe4f6295c4be654 Mon Sep 17 00:00:00 2001 From: dahoas Date: Tue, 18 Jul 2023 11:59:36 +0000 Subject: [PATCH 20/27] Move inputs to device --- trlx/trainer/accelerate_ppo_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index cb553432f..dec81d90c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -417,6 +417,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip(all_tokens_chunks, attention_mask_chunks, position_ids_chunks): + all_tokens_chunk = all_tokens_chunk.to(device) + attention_mask_chunk = attention_mask_chunk.to(device) + position_ids_chunk = position_ids_chunk.to(device) with torch.no_grad(): logits, *_, values = self.model( all_tokens_chunk, From 8de84e42e572721bb5e1a08b0f61a6c0583b6463 Mon Sep 17 00:00:00 2001 From: dahoas Date: Tue, 18 Jul 2023 12:19:47 +0000 Subject: [PATCH 21/27] Fix style --- examples/ppo_redemption.py | 11 ++-- trlx/trainer/accelerate_base_trainer.py | 30 +++++++---- trlx/trainer/accelerate_ppo_trainer.py | 70 +++++++++++++++++++------ trlx/utils/modeling.py | 2 +- 4 files changed, 80 insertions(+), 33 deletions(-) diff --git a/examples/ppo_redemption.py b/examples/ppo_redemption.py index b930b2dc7..84435b225 100644 --- a/examples/ppo_redemption.py +++ b/examples/ppo_redemption.py @@ -7,7 +7,7 @@ import torch from datasets import load_dataset -from transformers import pipeline, AutoTokenizer +from transformers import pipeline import trlx from trlx.data.default_configs import TRLConfig, default_ppo_config @@ -17,6 +17,7 @@ def get_positive_score(scores): "Extract value associated with a positive sentiment from pipeline's output" return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + def get_negative_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"] @@ -29,7 +30,7 @@ def main(hparams={}): 
config.method.gen_kwargs["temperature"] = 0.3 config.train.total_steps = 20000 config.train.checkpoint_interval = 10000000 - #config.method.init_kl_coef = 0 + # config.method.init_kl_coef = 0 if torch.cuda.is_available(): device = int(os.environ.get("LOCAL_RANK", 0)) @@ -49,11 +50,9 @@ def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], # Reward positively for initially negative then positive review # Reward functions should never receive padded text except for a singel EOS at the end # Reward function should return token rewards for just the response - # Note: To get trajectory length, the reward fn should not tokenize the samples but should instead separately tokenizer prompts and outputs and then combine them - # Also note outputs has a single EOS at end of each - first_halves = [".".join(sample.split(".")[:len(sample.split(".")) // 2]) for sample in samples] + first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples] negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves))) - second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2:]) for sample in samples] + second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples] positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves))) text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)] tok_scores = [] diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 60c2ae622..ff19a4288 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -4,9 +4,9 @@ import sys from abc import abstractmethod from contextlib import contextmanager +from copy import copy from time import time from typing import Dict, List, Optional, Tuple -from copy import copy import ray import torch @@ -221,13 +221,11 @@ def decode( str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True) str_output = self.tokenizer.decode(sample[output_start_ix:], skip_special_tokens=True) # Trim outputs up to `self.stop_sequences` if any are present - trimmed = False if self.stop_sequences: for stop in self.stop_sequences: stop_ix = str_output.find(stop) if stop_ix >= 0: str_output = str_output[:stop_ix].rstrip() - trimmed = True # Recover the last if it was present in the original sample # or add one if it was trimmed with `self.stop_sequences`. 
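As an aside, the chunk_size plumbing in the generate wrapper that follows can be approximated by a standalone helper; this is only an illustrative sketch, and generate_in_chunks, generate_fn and pad_token_id are assumed names rather than trlx APIs:

    from torch.nn.utils.rnn import pad_sequence

    def generate_in_chunks(generate_fn, input_ids, attention_mask, chunk_size, pad_token_id):
        # Run generation over at most `chunk_size` prompts at a time, then
        # right-pad the pieces back into one batch of samples.
        if chunk_size is None:
            return generate_fn(input_ids, attention_mask)
        rows = []
        for ids, mask in zip(input_ids.split(chunk_size), attention_mask.split(chunk_size)):
            rows.extend(list(generate_fn(ids, mask)))
        return pad_sequence(rows, batch_first=True, padding_value=pad_token_id)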
@@ -250,18 +248,20 @@ def decode( def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" - # Decide into chunk sizes and generate saples + # Decide into chunk sizes and generate saples input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) generate_kwargs = copy(self.generate_kwargs) generate_kwargs.update(kwargs) - + # Update max_new_tokens to respect max_seq_length prompt_length = input_ids.shape[1] if generate_kwargs.get("max_new_tokens") is not None: - generate_kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"]) + generate_kwargs["max_new_tokens"] = min( + max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"] + ) else: generate_kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) @@ -451,7 +451,13 @@ def evaluate(self): # noqa: C901 # in online setting, compute the reward for validation if self.reward_fn: logger.info("Computing rewards") - rewards = self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) + rewards = self.reward_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + model_tok=self.tokenizer, + **metadata, + ) if type(rewards[0]) is torch.Tensor: rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float) elif type(rewards[0]) is list: @@ -469,7 +475,13 @@ def evaluate(self): # noqa: C901 if self.metric_fn: logger.info("Computing metrics") metric_time = time() - metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) + metrics = self.metric_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + model_tok=self.tokenizer, + **metadata, + ) stats["time/metric"] = time() - metric_time mean_metrics = { @@ -665,7 +677,7 @@ def repeat_interleave(l, n): if type(l) is torch.Tensor: l = l.repeat_interleave(n, dim=0) elif type(l) is list: - l = [[s]*n for s in l] + l = [[s] * n for s in l] l = [item for sublist in l for item in sublist] return l diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index dec81d90c..2fbb2068c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -3,12 +3,11 @@ import uuid from time import time from typing import Callable, List -from copy import copy import torch import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence import transformers +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -280,10 +279,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"], chunk_size=self.config.method.chunk_size, **self.generate_experience_kwargs) + samples = self.generate( + batch["input_ids"], + batch["attention_mask"], + chunk_size=self.config.method.chunk_size, + **self.generate_experience_kwargs, + ) stats["time/rollout_generate"] = time() - rollout_generate_time - num_return_sequences = self.generate_experience_kwargs["num_return_sequences"] if self.generate_experience_kwargs.get("num_return_sequences") is not 
None else 1 + num_return_sequences = ( + self.generate_experience_kwargs["num_return_sequences"] + if self.generate_experience_kwargs.get("num_return_sequences") is not None + else 1 + ) prompt_tensors = batch.input_ids.repeat_interleave(num_return_sequences, dim=0) device = samples.device @@ -297,7 +305,13 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq gathered_samples = self.accelerator.gather(padded_samples) gathered_prompts = self.accelerator.gather(padded_prompts) gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) - metadata = gather_dict({k: self.repeat_interleave(v, num_return_sequences) for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) + metadata = gather_dict( + { + k: self.repeat_interleave(v, num_return_sequences) + for k, v in batch.items() + if k != "input_ids" and k != "attention_mask" + } + ) if self.accelerator.is_main_process: all_str_samples, all_str_prompts, all_str_outputs = self.decode( @@ -307,8 +321,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_score_time = time() # reward_fn should return list of rewards at each token per sample # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed) - all_scores = self.reward_fn(samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, model_tok=self.tokenizer, **metadata) - all_scores = [torch.tensor(score, dtype=torch.float, device=device).view(-1,) for score in all_scores] + all_scores = self.reward_fn( + samples=all_str_samples, + prompts=all_str_prompts, + outputs=all_str_outputs, + model_tok=self.tokenizer, + **metadata, + ) + all_scores = [ + torch.tensor(score, dtype=torch.float, device=device).view( + -1, + ) + for score in all_scores + ] # Pad 0 reward on the ends all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-1) max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device) @@ -326,9 +351,14 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() - # Best-of-N Sampling. + # Best-of-N Sampling. 
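The Best-of-N selection in this hunk keeps only the top-k generations per prompt; a minimal sketch of the same idea, assuming scores have already been summed to one value per sequence (best_of_n_indices and seq_scores are illustrative names, not trlx code):

    import torch

    def best_of_n_indices(seq_scores, num_return_sequences, k):
        # seq_scores: [batch * num_return_sequences] sequence-level scores, with all
        # generations for a given prompt stored contiguously; keep the k best per prompt.
        grouped = seq_scores.view(-1, num_return_sequences)
        topk = grouped.topk(k, dim=1).indices
        offsets = torch.arange(grouped.shape[0], device=seq_scores.device).unsqueeze(1)
        return (topk + offsets * num_return_sequences).reshape(-1)

    # e.g. keep the best 1 of 4 samples per prompt:
    # train_idx = best_of_n_indices(per_token_scores.sum(dim=1), num_return_sequences=4, k=1)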
scores_mask = scores != -1 - train_indices = self.get_topk_indices(input_tensor=scores_mask*scores, window_size=num_return_sequences, k=self.config.method.num_train_sequences, device=device) + train_indices = self.get_topk_indices( + input_tensor=scores_mask * scores, + window_size=num_return_sequences, + k=self.config.method.num_train_sequences, + device=device, + ) scores = scores[train_indices] scores_mask = scores_mask[train_indices] samples = samples[train_indices] @@ -360,7 +390,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # store statistics of the initial rollout as reference if self.ref_mean is None: - self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum(dim=1).std() + self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum( + dim=1 + ).std() all_scores_mean, all_scores_std = self.running_moments.update(scores, scores_mask) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() @@ -416,7 +448,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.gen_chunk_size, dim=0) attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) - for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip(all_tokens_chunks, attention_mask_chunks, position_ids_chunks): + for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip( + all_tokens_chunks, attention_mask_chunks, position_ids_chunks + ): all_tokens_chunk = all_tokens_chunk.to(device) attention_mask_chunk = attention_mask_chunk.to(device) position_ids_chunk = position_ids_chunk.to(device) @@ -451,19 +485,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens_chunk[:, 1:]) ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens_chunk[:, 1:]) - + values_chunks.append(values.cpu()) logits_chunks.append(logits.cpu()) ref_logits_chunks.append(ref_logits.cpu()) log_probs_chunks.append(logprobs.cpu()) ref_logprobs_chunks.append(ref_logprobs.cpu()) - + values = torch.cat(values_chunks, dim=0) logits = torch.cat(logits_chunks, dim=0) ref_logits = torch.cat(ref_logits_chunks, dim=0) logprobs = torch.cat(log_probs_chunks, dim=0) ref_logprobs = torch.cat(ref_logprobs_chunks, dim=0) - + n_samples: int = samples.shape[0] # Estimate the KL divergence between the model and reference model @@ -515,7 +549,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq score_right_padding = torch.sum(scores_mask[sample_idx]) score = score[:score_right_padding].cpu() p_score = torch.zeros_like(rewards) - p_score[:score.shape[0]] += score + p_score[: score.shape[0]] += score rewards += p_score ppo_rl_elements.append( @@ -549,7 +583,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # Push samples and rewards to trainer's rollout storage self.push_to_store(ppo_rl_elements) - + @staticmethod def get_topk_indices(input_tensor, window_size: int, k: int, device): # Sum the scores along dim 1 @@ -559,5 +593,7 @@ def get_topk_indices(input_tensor, window_size: int, k: int, device): # Find the topk values and indices 
along the unfolded dimension _, indices = torch.topk(unfolded, k, dim=2) # Adjust indices to be relative to original tensor - indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1) + indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to( + device + ).unsqueeze(1) return indices.reshape(-1) diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index c6f3dd8ee..b0036b3f6 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, List, MutableMapping, Tuple, Union, Optional +from typing import Dict, MutableMapping, Optional, Tuple, Union import accelerate import numpy as np From 404ef1476c963a11ab93546596fe96534101bc2e Mon Sep 17 00:00:00 2001 From: Dahoas Date: Fri, 23 Jun 2023 08:19:13 -0700 Subject: [PATCH 22/27] Fix: Do not shuffle empty experience dataloader --- trlx/trainer/accelerate_ppo_trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 2fbb2068c..6561dce34 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -229,8 +229,6 @@ def prepare_learning(self): self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False) - self.make_experience(self.config.method.num_rollouts) - self.n_updates_per_batch = self.config.method.ppo_epochs self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) From 67b711a3ec2715cb7744a237ef9737aae03daf54 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Fri, 23 Jun 2023 08:26:37 -0700 Subject: [PATCH 23/27] Make experience before first round of training --- trlx/trainer/accelerate_ppo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 6561dce34..2fbb2068c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -229,6 +229,8 @@ def prepare_learning(self): self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False) + self.make_experience(self.config.method.num_rollouts) + self.n_updates_per_batch = self.config.method.ppo_epochs self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) From 34e185a8eb88bdc9477e190f92d1104be6c41d2d Mon Sep 17 00:00:00 2001 From: Dahoas Date: Tue, 27 Jun 2023 04:33:44 -0700 Subject: [PATCH 24/27] Refactoring .generate/.generate_eval --- trlx/trainer/accelerate_base_trainer.py | 1 + trlx/trainer/accelerate_ppo_trainer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index ff19a4288..a055879a4 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -7,6 +7,7 @@ from copy import copy from time import time from typing import Dict, List, Optional, Tuple +from copy import copy import ray import torch diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 2fbb2068c..b7a090054 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -3,6 +3,7 @@ import uuid from time 
import time from typing import Callable, List +from copy import copy import torch import torch.nn.functional as F From 11e1e95df02bd64b8c3350571b29bfb20ee18e4a Mon Sep 17 00:00:00 2001 From: dahoas Date: Fri, 14 Jul 2023 09:49:48 +0000 Subject: [PATCH 25/27] Refactored decode, make_experience and added support for external reference models --- trlx/models/modeling_ppo.py | 4 +- trlx/trainer/accelerate_base_trainer.py | 89 ++++++-- trlx/trainer/accelerate_ppo_trainer.py | 271 +++++++++++++----------- 3 files changed, 216 insertions(+), 148 deletions(-) diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index bd0f57f88..7856f6fa0 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -115,9 +115,6 @@ class PPOConfig(MethodConfig): :param num_train_sequences: top_k of n sampled sequences from prompt :type num_train_sequences: int - - :param mix_sft: if this is True, then SFT gradients will be mixed into PPO traininig - :type mix_sft: bool """ ppo_epochs: int @@ -138,6 +135,7 @@ class PPOConfig(MethodConfig): gen_kwargs: dict gen_experience_kwargs: Optional[dict] = None num_train_sequences: int = 1 + dist_ref_model: bool = False def get_advantages_and_returns( self, diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index a055879a4..75c226fb2 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -203,23 +203,29 @@ def decode( prompts: List[torch.LongTensor], samples: List[torch.LongTensor], prompt_sizes: torch.LongTensor = None, - append_eos_token: bool = False, - ) -> Tuple[List[str], List[str], List[str]]: + append_eos_token: bool = True, + ) -> Tuple[List[str], List[str], List[str], List[torch.LongTensor], List[torch.LongTensor], List[torch.LongTensor]]: """ - Decode tensor generations into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) + Decode tensor generations with stopping criteria into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) and + Note prompts maybe sompetimes be right padded, as well as samples """ if prompt_sizes is None: # Assuming prompts were left-padded prompt_sizes = [prompts.shape[1]] * len(prompts) str_samples, str_prompts, str_outputs = [], [], [] + tok_samples, tok_prompts, tok_outputs = [], [], [] for prompt, sample, prompt_size in zip(prompts, samples, prompt_sizes): if self.config.model.model_arch_type == "seq2seq": output_start_ix = 0 else: output_start_ix = prompt_size + # We must decode by skipping padding in the middle with skip_special_tokens str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True) + # Return the prompt tensor (with the exact padding) used to generate the sample + tok_prompt = prompt[:prompt_size].cpu() + str_output = self.tokenizer.decode(sample[output_start_ix:], skip_special_tokens=True) # Trim outputs up to `self.stop_sequences` if any are present if self.stop_sequences: @@ -227,25 +233,63 @@ def decode( stop_ix = str_output.find(stop) if stop_ix >= 0: str_output = str_output[:stop_ix].rstrip() + + # Recover sequence of tokens corresponding to string + # NOTE: Cast to torch.long in the case the input is empty + tok_output = self.tokenizer(str_output, return_tensors="pt").input_ids[0].long() + # Remove bos from tokenized output (if present) + if hasattr(self.tokenizer, "bos_token") and len(tok_output) > 0 and tok_output[0].item() == self.tokenizer.bos_token_id: + tok_output = tok_output[1:] # Recover the last 
if it was present in the original sample # or add one if it was trimmed with `self.stop_sequences`. # When a generation ended due to `max_new_tokens` exhaustion, # only then or token would not be present in the original sample at the end + # Note we append to both separately because encoding does not result in tokenizer.eos_token_id if append_eos_token: str_output += self.tokenizer.eos_token + tok_output = torch.cat([tok_output, torch.tensor(self.tokenizer.eos_token_id, dtype=torch.long).view((1,))]) + + '''if len(tok_output) > 91: + print("#####") + print("Sample: ", sample.shape) + print(sample) + print("output: ", tok_output.shape) + print(tok_output) + print(str_output) + print("prompt size: ", output_start_ix) + print("tok_prompt: ", tok_prompt) + print("tok prompt shape: ", tok_prompt.shape) + print("prompt: ", prompt) + print("str prompt: ", str_prompt) + print("prompt: ", prompt.shape) + exit()''' str_prompts.append(str_prompt) str_outputs.append(str_output) + tok_prompts.append(tok_prompt) + tok_outputs.append(tok_output) if self.config.model.model_arch_type == "seq2seq": sample = str_prompt + self.tokenizer.sep_token + str_output + tok_sample = torch.cat([tok_prompt, torch.tensor(self.tokenizer.sep_token_id, dtype=torch.long).view((1,)), tok_output]) else: sample = str_prompt + str_output + tok_sample = torch.cat([tok_prompt, tok_output]) str_samples.append(sample) + tok_samples.append(tok_sample) + + if tok_sample.dtype == torch.float32: + print("tok_sample ", tok_sample.dtype) + print(tok_sample) + print("tok prompt: ", tok_prompt.dtype) + print(tok_prompt) + print("tok output: ", tok_output.dtype) + print(tok_output) + exit() - return str_samples, str_prompts, str_outputs + return str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" @@ -421,9 +465,9 @@ def evaluate(self): # noqa: C901 pad_index=self.tokenizer.pad_token_id, ) ) - all_samples.extend(samples.tolist()) - all_prompts.extend(prompts.tolist()) - all_prompt_sizes.extend(prompt_sizes.tolist()) + all_samples.extend(samples) + all_prompts.extend(prompts) + all_prompt_sizes.extend(prompt_sizes) metadata = gather_dict(metadata, self.accelerator.gradient_state) all_metadata.append(metadata) @@ -439,7 +483,7 @@ def evaluate(self): # noqa: C901 stats["time/generate"] = time() - generate_time if self.accelerator.is_main_process: - str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes) + str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes) columns = ["prompt", "output"] columns_data = [str_prompts, str_outputs] @@ -453,12 +497,16 @@ def evaluate(self): # noqa: C901 if self.reward_fn: logger.info("Computing rewards") rewards = self.reward_fn( - samples=str_samples, - prompts=str_prompts, - outputs=str_outputs, - model_tok=self.tokenizer, - **metadata, - ) + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + tok_samples=tok_samples, + tok_prompts=tok_prompts, + tok_outputs=tok_outputs, + model_tok=self.tokenizer, **metadata) + # Remove kl terms from reward + if hasattr(self, "dist_ref_model") and self.dist_ref_model: + rewards = [[r[0] for r in reward] for reward in rewards] if type(rewards[0]) is torch.Tensor: rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float) elif type(rewards[0]) is list: @@ 
-477,12 +525,13 @@ def evaluate(self): # noqa: C901 logger.info("Computing metrics") metric_time = time() metrics = self.metric_fn( - samples=str_samples, - prompts=str_prompts, - outputs=str_outputs, - model_tok=self.tokenizer, - **metadata, - ) + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + tok_samples=tok_samples, + tok_prompts=tok_prompts, + tok_outputs=tok_outputs, + model_tok=self.tokenizer, **metadata) stats["time/metric"] = time() - metric_time mean_metrics = { diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index b7a090054..d278821b8 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -11,6 +11,7 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from transformers import AutoTokenizer +import numpy as np import trlx.utils.logging as logging from trlx.data.accelerate_base_datatypes import PromptBatch @@ -69,8 +70,9 @@ def __init__(self, config: TRLConfig, **kwargs): self.store.clear_history() # Clear the rollout store - # Set up a reference model when hydra heads are not used - if not hasattr(self.model, "frozen_head") and not self.model.peft_type: + # Setup a reference model when hydra heads and distributed ref model are not used + self.dist_ref_model = config.method.dist_ref_model + if not hasattr(self.model, "frozen_head") and not self.model.peft_type and not self.dist_ref_model: self.ref_model = self.get_arch(self.config) self.ref_model.to(self.accelerator.device) self.ref_model.eval() @@ -315,76 +317,59 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ) if self.accelerator.is_main_process: - all_str_samples, all_str_prompts, all_str_outputs = self.decode( + all_str_samples, all_str_prompts, all_str_outputs, all_tok_samples, all_tok_prompts, all_tok_outputs = self.decode( gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True ) rollout_score_time = time() # reward_fn should return list of rewards at each token per sample # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed) - all_scores = self.reward_fn( - samples=all_str_samples, - prompts=all_str_prompts, - outputs=all_str_outputs, - model_tok=self.tokenizer, - **metadata, - ) - all_scores = [ - torch.tensor(score, dtype=torch.float, device=device).view( - -1, - ) - for score in all_scores - ] - # Pad 0 reward on the ends - all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-1) - max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device) + # NOTE: reward_fn can optionally also compute the ref_logits. 
In this case size will be [batch_size, response_length, 2] + all_scores = self.reward_fn(samples=all_str_samples, + prompts=all_str_prompts, + outputs=all_str_outputs, + tok_samples=all_tok_samples, + tok_prompts=all_tok_prompts, + tok_outputs=all_tok_outputs, + model_tok=self.tokenizer, + **metadata) + all_scores = [torch.tensor(score, dtype=torch.float, device=device).view(-1,) for score in all_scores] + # Pad -np.inf reward on the ends + all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-np.inf) + max_len = torch.tensor(len(all_scores[0])/(1+int(self.dist_ref_model)), dtype=torch.long, device=device) stats["time/rollout_score"] = time() - rollout_score_time - all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind()) + all_scores = list(all_scores.reshape(self.accelerator.num_processes, len(samples), max_len, -1).unbind()) else: all_scores = None max_len = torch.tensor(0, dtype=torch.long, device=device) if torch.distributed.is_initialized(): torch.distributed.broadcast(max_len, 0) - scores = torch.empty((len(samples), max_len), device=device) + # Allocate extra space if scores include ref_logits + scores = torch.empty((len(samples), max_len, 1+int(self.dist_ref_model)), device=device) torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() - # Best-of-N Sampling. - scores_mask = scores != -1 - train_indices = self.get_topk_indices( - input_tensor=scores_mask * scores, - window_size=num_return_sequences, - k=self.config.method.num_train_sequences, - device=device, - ) + + # Remove ref_logits from scores if present + if self.dist_ref_model: + all_ref_logprobs = scores[:, :, 1] + scores = scores[:, :, 0] + else: + all_ref_logprobs = None + scores = scores.squeeze(-1) + + # Best-of-N Sampling. + scores_mask = scores != -np.inf + train_indices = self.get_topk_indices(input_tensor=scores_mask*scores, window_size=num_return_sequences, k=self.config.method.num_train_sequences, device=device) scores = scores[train_indices] scores_mask = scores_mask[train_indices] samples = samples[train_indices] prompt_tensors = prompt_tensors[train_indices] - - str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) - - # Pad the sample outputs - outputs = self.tokenizer(str_outputs).input_ids - if self.config.model.model_arch_type == "seq2seq": - # add to the start of the output - for i in range(len(outputs)): - outputs[i] = [self.tokenizer.pad_token_id] + outputs[i] - - outputs = list(map(torch.LongTensor, outputs)) - maxsize = max(map(len, outputs)) - outputs = [ - F.pad( - output, - (0, maxsize - len(output)), - value=self.tokenizer.pad_token_id, - ) - for output in outputs - ] - sample_outputs = torch.vstack(outputs).to(device) + if all_ref_logprobs is not None: + all_ref_logprobs = all_ref_logprobs[train_indices] if self.config.method.cliprange_reward: scores = torch.clip(scores, -self.config.method.cliprange_reward, self.config.method.cliprange_reward) @@ -405,7 +390,31 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq elif self.config.method.scale_reward == "ref": scores /= self.ref_std + # Only use these samples, prompts, outputs to compute ppo stats + _, _, _, tok_samples, tok_prompts, tok_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + + # Pad the sample outputs + #outputs = self.tokenizer(str_outputs).input_ids + # TODO: Why is this here? Should this be a sep token? 
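The -np.inf padding used above makes the padding positions recoverable as a mask; a toy illustration (the values are made up, not trlx data):

    import numpy as np
    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Variable-length per-token rewards, right-padded with -inf so that real reward
    # positions can later be recovered with a mask and excluded from arithmetic.
    rewards = [torch.tensor([0.1, 0.3]), torch.tensor([0.2, 0.0, 0.5])]
    padded = pad_sequence(rewards, batch_first=True, padding_value=-np.inf)
    mask = padded != -np.inf
    per_sequence_total = torch.where(mask, padded, torch.zeros_like(padded)).sum(dim=1)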
+ if self.config.model.model_arch_type == "seq2seq": + # add to the start of the output + for i in range(len(tok_outputs)): + tok_outputs[i] = [self.tokenizer.pad_token_id] + outputs[i].tolist() + + padded_tok_outputs = pad_sequence(tok_outputs, batch_first=True, padding_value=self.tokenizer.pad_token_id) + tok_prompts = torch.stack(tok_prompts, dim=0) + padded_tok_samples = pad_sequence(tok_samples, batch_first=True, padding_value=self.tokenizer.pad_token_id) + attention_mask = padded_tok_samples.not_equal(self.tokenizer.pad_token_id).long() + + if self.config.model.model_arch_type == "seq2seq": + attention_mask = sample_outputs != self.tokenizer.pad_token_id + start = 0 + else: + # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response + start = tok_prompts.shape[1] - 1 + # Precompute logprobs, values + # TODO: Come back to seq2seq if self.config.model.model_arch_type == "seq2seq": attention_mask = batch.attention_mask.to(device) prompt_tensors = batch.input_ids.to(device) @@ -438,8 +447,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ).logits else: values_chunks = [] - logits_chunks = [] - ref_logits_chunks = [] log_probs_chunks = [] ref_logprobs_chunks = [] all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) @@ -449,9 +456,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.gen_chunk_size, dim=0) attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) - for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip( - all_tokens_chunks, attention_mask_chunks, position_ids_chunks - ): + for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip(all_tokens_chunks, attention_mask_chunks, position_ids_chunks): all_tokens_chunk = all_tokens_chunk.to(device) attention_mask_chunk = attention_mask_chunk.to(device) position_ids_chunk = position_ids_chunk.to(device) @@ -461,104 +466,120 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq attention_mask=attention_mask_chunk, position_ids=position_ids_chunk, ) - # TODO(dahoas): When hydra model works need to also support generation on hydra head - if hasattr(self.model, "frozen_head"): - ref_logits = self.model.forward_hydra( - all_tokens_chunk, - attention_mask=attention_mask_chunk, - position_ids=position_ids_chunk, - return_dict=True, - ).logits - elif hasattr(self, "ref_model"): - ref_logits = self.ref_model( - all_tokens_chunk, - attention_mask=attention_mask_chunk, - position_ids=position_ids_chunk, - return_dict=True, - ).logits - ref_logits = ref_logits.to(device) - else: - ref_logits = logits.clone().detach() + # If all_ref_logits is not None they have already been generated during call to reward_fn + if all_ref_logprobs is None: + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids_chunk, + return_dict=True, + ).logits + elif hasattr(self, "ref_model"): + ref_logits = self.ref_model( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids_chunk, + return_dict=True, + ).logits + ref_logits = ref_logits.to(device) + # If no ref model is provided then we compute no kl penalty + else: + ref_logits = logits.clone() + if 
self.config.model.model_arch_type == "seq2seq": - logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) + logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start+1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, start:-1, :], all_tokens_chunk[:, start+1:]) else: # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled - logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens_chunk[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens_chunk[:, 1:]) - + # So need to index at start = prompt_tensors.shape[1] - 1 which is the logprob corresponding to the first sampled token + # Indexing ends at -1 because the last logprob corresponds to an unsampled token + logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start+1:]) + if all_ref_logprobs is None: + ref_logprobs = logprobs_of_labels(ref_logits[:, start:-1, :], all_tokens_chunk[:, start+1:]) + values_chunks.append(values.cpu()) - logits_chunks.append(logits.cpu()) - ref_logits_chunks.append(ref_logits.cpu()) log_probs_chunks.append(logprobs.cpu()) - ref_logprobs_chunks.append(ref_logprobs.cpu()) - - values = torch.cat(values_chunks, dim=0) - logits = torch.cat(logits_chunks, dim=0) - ref_logits = torch.cat(ref_logits_chunks, dim=0) + if all_ref_logprobs is None: + ref_logprobs_chunks.append(ref_logprobs.cpu()) + + # Remove values before v[start] (this is the value of the state before any tokens are sampled) + # and remove the last value v[-1] (this is a terminal state after all tokens have been generated with value 0) + values = torch.cat(values_chunks, dim=0)[:, start:-1] logprobs = torch.cat(log_probs_chunks, dim=0) - ref_logprobs = torch.cat(ref_logprobs_chunks, dim=0) - - n_samples: int = samples.shape[0] + attention_mask = attention_mask[:, start:].cpu() - # Estimate the KL divergence between the model and reference model - if self.config.model.model_arch_type == "seq2seq": - attention_mask = sample_outputs != self.tokenizer.pad_token_id - start = 0 + if all_ref_logprobs is None: + ref_logprobs = torch.cat(ref_logprobs_chunks, dim=0) + # all_ref_logprobs returned from reward already has prompt prefix removed else: - # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response - start = prompt_tensors.shape[1] - 1 + # Remove (some) padding from distributed communication + # So arithmetic with logprobs can be done + ref_logprobs = all_ref_logprobs[:, :logprobs.shape[1]].cpu() - attention_mask = attention_mask.cpu() - log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] + # Estimate the KL divergence between the model and reference model + # NOTE: nan is interfering with kl estimates since 0 * nan = 0 + # Convert inf padding terms in ref_logprobs to number removable with attention mask mult + log_ratio = (logprobs - torch.nan_to_num(ref_logprobs)) * attention_mask[:, :-1] kl = log_ratio.exp() - 1 - log_ratio mean_kl_per_token = kl.mean() mean_kl = kl.sum(1).mean() + kl_penalties = self.kl_ctl.value * -log_ratio.cpu() - logprobs = logprobs.cpu() - ref_logprobs = ref_logprobs.cpu() - prompt_tensors = prompt_tensors.cpu() - sample_outputs = sample_outputs.cpu() - # TODO(dahoas): Why [:, :-1]? Redudant with clipping via start : ends[ix]? - # Actually I think it's just wrong? 
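The start = prompt_len - 1 indexing described in the notes above can be made concrete with a short sketch; response_logprobs is an assumed helper mirroring what logprobs_of_labels computes over the response span, not the trlx implementation:

    import torch
    import torch.nn.functional as F

    def response_logprobs(logits, tokens, prompt_len):
        # logits[:, i] is the distribution from which tokens[:, i + 1] was sampled, so the
        # first generated token's logprob sits at index prompt_len - 1, and the final
        # logit (which predicts a token that was never sampled) is dropped.
        start = prompt_len - 1
        logprobs = F.log_softmax(logits[:, start:-1, :], dim=-1)
        labels = tokens[:, start + 1 :]
        return torch.gather(logprobs, 2, labels.unsqueeze(-1)).squeeze(-1)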
- values = values.cpu()[:, :-1] + n_samples = padded_tok_samples.shape[0] + rollout_count = 0 # Get the logprobs and values, for tokens that are not padding, # from the end of the prompt up to the token, while also including the latter # (these are taken from the student model and not the reference model) - ends = start + attention_mask[:, start:].sum(1) + 1 - # NOTE: values[i] is the value of the state after response token i - # TODO(dahoas): Does it actually make sense to get the rewards one step early? - all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] - all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] - - kl_penalty = self.kl_ctl.value * -log_ratio.cpu() - kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)] - - rollout_count = 0 - + # NOTE: Why are we summing including a token from the prompt? + # In our case it's ok because we then subtract -1 from resulting end index + ends = attention_mask.sum(1) + 1 for sample_idx in range(n_samples): - # To compute per token reward first add in kl penalties over trajectory - # NOTE: kl_penalty[i] is kl_diff at token i+1 in the output (w/o EOS) - rewards = kl_penalty[sample_idx] + value = values[sample_idx, :ends[sample_idx]-1] + logprob = logprobs[sample_idx, :ends[sample_idx]-1] + kl_penalty = kl_penalties[sample_idx, :ends[sample_idx]-1] + query_tensor = tok_prompts[sample_idx] + response_tensor = tok_outputs[sample_idx] + if len(value) != len(logprob) or len(logprob) != len(kl_penalty) or len(kl_penalty) != len(response_tensor): + raise ValueError(f"Length mismatch between value, logprob, kl, and response_tensor:\n\ + Value: {value.shape}, {value}\n\ + Logprob: {logprob.shape}, {logprob}\n\ + KL: {kl_penalty.shape}, {kl_penalty}\n\ + Response: {response_tensor.shape}, {response_tensor}, {self.tokenizer.decode(response_tensor)}\n") + # Then add in rewards if scores.shape[1] == 1: # NOTE: Final reward given at EOS token following HHH practice - rewards[-1] += scores[sample_idx][0].cpu() + score = scores[sample_idx][0].cpu() + kl_penalty[-1] += score + rewards = kl_penalty else: score = scores[sample_idx] score_right_padding = torch.sum(scores_mask[sample_idx]) score = score[:score_right_padding].cpu() - p_score = torch.zeros_like(rewards) - p_score[: score.shape[0]] += score - rewards += p_score + if len(score) != len(kl_penalty): + raise ValueError(f"Length mismatch between score and kl penalty:\n\ + Logprob: {logprob.shape}, {logprob}\n\ + kl_penalty: {kl_penalty.shape}, {kl_penalty}\n\ + Score: {score.shape}, {score}") + rewards = kl_penalty + score + + if kl_penalty.isnan().any() or score.isnan().any(): + raise ValueError(f"nan in tensor:\n\ + KL: {kl_penalty}\n\ + Score: {score}\n\ + logprob: {logprob}\n\ + ref logprob: {ref_logprobs[sample_idx][:ends[sample_idx]-1]}\n\ + mask: {attention_mask[sample_idx]}\n\ + kl ctl: {self.kl_ctl.value}") ppo_rl_elements.append( PPORLElement( - query_tensor=prompt_tensors[sample_idx], - response_tensor=sample_outputs[sample_idx], - logprobs=all_logprobs[sample_idx], - values=all_values[sample_idx], + query_tensor=query_tensor, + response_tensor=response_tensor, + logprobs=logprob, + values=value, rewards=rewards, ) ) From 527ba2353bfa3bc1dbb4e70ab35ce3dc2ff6728e Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 17 Jul 2023 16:07:05 +0000 Subject: [PATCH 26/27] Fix BoN sampling after big refactor --- trlx/trainer/accelerate_ppo_trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git 
a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index d278821b8..6ef62ddb1 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -360,9 +360,13 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq else: all_ref_logprobs = None scores = scores.squeeze(-1) + scores_mask = scores != -np.inf + + # Remove infs so mask can be used + if self.config.method.cliprange_reward: + scores = torch.clip(scores, -self.config.method.cliprange_reward, self.config.method.cliprange_reward) # Best-of-N Sampling. - scores_mask = scores != -np.inf train_indices = self.get_topk_indices(input_tensor=scores_mask*scores, window_size=num_return_sequences, k=self.config.method.num_train_sequences, device=device) scores = scores[train_indices] scores_mask = scores_mask[train_indices] @@ -371,9 +375,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq if all_ref_logprobs is not None: all_ref_logprobs = all_ref_logprobs[train_indices] - if self.config.method.cliprange_reward: - scores = torch.clip(scores, -self.config.method.cliprange_reward, self.config.method.cliprange_reward) - # store statistics of the initial rollout as reference if self.ref_mean is None: self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum( From 676a1cd32ea7bdfeead1e43b4b0d3427b8d137f7 Mon Sep 17 00:00:00 2001 From: dahoas Date: Tue, 18 Jul 2023 13:24:35 +0000 Subject: [PATCH 27/27] Fixing style --- trlx/trainer/accelerate_base_trainer.py | 74 ++++++------- trlx/trainer/accelerate_ppo_trainer.py | 135 +++++++++++++++--------- 2 files changed, 117 insertions(+), 92 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 75c226fb2..801214bf4 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -7,7 +7,6 @@ from copy import copy from time import time from typing import Dict, List, Optional, Tuple -from copy import copy import ray import torch @@ -206,7 +205,8 @@ def decode( append_eos_token: bool = True, ) -> Tuple[List[str], List[str], List[str], List[torch.LongTensor], List[torch.LongTensor], List[torch.LongTensor]]: """ - Decode tensor generations with stopping criteria into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) and + Decode tensor generations with stopping criteria into lists + of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) and Note prompts maybe sompetimes be right padded, as well as samples """ if prompt_sizes is None: @@ -233,12 +233,16 @@ def decode( stop_ix = str_output.find(stop) if stop_ix >= 0: str_output = str_output[:stop_ix].rstrip() - + # Recover sequence of tokens corresponding to string # NOTE: Cast to torch.long in the case the input is empty tok_output = self.tokenizer(str_output, return_tensors="pt").input_ids[0].long() # Remove bos from tokenized output (if present) - if hasattr(self.tokenizer, "bos_token") and len(tok_output) > 0 and tok_output[0].item() == self.tokenizer.bos_token_id: + if ( + hasattr(self.tokenizer, "bos_token") + and len(tok_output) > 0 + and tok_output[0].item() == self.tokenizer.bos_token_id + ): tok_output = tok_output[1:] # Recover the last if it was present in the original sample @@ -248,22 +252,9 @@ def decode( # Note we append to both separately because encoding does not result in tokenizer.eos_token_id if append_eos_token: str_output += 
self.tokenizer.eos_token - tok_output = torch.cat([tok_output, torch.tensor(self.tokenizer.eos_token_id, dtype=torch.long).view((1,))]) - - '''if len(tok_output) > 91: - print("#####") - print("Sample: ", sample.shape) - print(sample) - print("output: ", tok_output.shape) - print(tok_output) - print(str_output) - print("prompt size: ", output_start_ix) - print("tok_prompt: ", tok_prompt) - print("tok prompt shape: ", tok_prompt.shape) - print("prompt: ", prompt) - print("str prompt: ", str_prompt) - print("prompt: ", prompt.shape) - exit()''' + tok_output = torch.cat( + [tok_output, torch.tensor(self.tokenizer.eos_token_id, dtype=torch.long).view((1,))] + ) str_prompts.append(str_prompt) str_outputs.append(str_output) @@ -272,7 +263,9 @@ def decode( if self.config.model.model_arch_type == "seq2seq": sample = str_prompt + self.tokenizer.sep_token + str_output - tok_sample = torch.cat([tok_prompt, torch.tensor(self.tokenizer.sep_token_id, dtype=torch.long).view((1,)), tok_output]) + tok_sample = torch.cat( + [tok_prompt, torch.tensor(self.tokenizer.sep_token_id, dtype=torch.long).view((1,)), tok_output] + ) else: sample = str_prompt + str_output tok_sample = torch.cat([tok_prompt, tok_output]) @@ -280,15 +273,6 @@ def decode( str_samples.append(sample) tok_samples.append(tok_sample) - if tok_sample.dtype == torch.float32: - print("tok_sample ", tok_sample.dtype) - print(tok_sample) - print("tok prompt: ", tok_prompt.dtype) - print(tok_prompt) - print("tok output: ", tok_output.dtype) - print(tok_output) - exit() - return str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): @@ -483,7 +467,9 @@ def evaluate(self): # noqa: C901 stats["time/generate"] = time() - generate_time if self.accelerator.is_main_process: - str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes) + str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs = self.decode( + all_prompts, all_samples, all_prompt_sizes + ) columns = ["prompt", "output"] columns_data = [str_prompts, str_outputs] @@ -497,13 +483,15 @@ def evaluate(self): # noqa: C901 if self.reward_fn: logger.info("Computing rewards") rewards = self.reward_fn( - samples=str_samples, - prompts=str_prompts, - outputs=str_outputs, - tok_samples=tok_samples, - tok_prompts=tok_prompts, - tok_outputs=tok_outputs, - model_tok=self.tokenizer, **metadata) + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + tok_samples=tok_samples, + tok_prompts=tok_prompts, + tok_outputs=tok_outputs, + model_tok=self.tokenizer, + **metadata, + ) # Remove kl terms from reward if hasattr(self, "dist_ref_model") and self.dist_ref_model: rewards = [[r[0] for r in reward] for reward in rewards] @@ -525,13 +513,15 @@ def evaluate(self): # noqa: C901 logger.info("Computing metrics") metric_time = time() metrics = self.metric_fn( - samples=str_samples, - prompts=str_prompts, - outputs=str_outputs, + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, tok_samples=tok_samples, tok_prompts=tok_prompts, tok_outputs=tok_outputs, - model_tok=self.tokenizer, **metadata) + model_tok=self.tokenizer, + **metadata, + ) stats["time/metric"] = time() - metric_time mean_metrics = { diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 6ef62ddb1..381e2ecfd 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ 
b/trlx/trainer/accelerate_ppo_trainer.py @@ -3,15 +3,13 @@ import uuid from time import time from typing import Callable, List -from copy import copy +import numpy as np import torch -import torch.nn.functional as F import transformers from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from transformers import AutoTokenizer -import numpy as np import trlx.utils.logging as logging from trlx.data.accelerate_base_datatypes import PromptBatch @@ -317,30 +315,47 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ) if self.accelerator.is_main_process: - all_str_samples, all_str_prompts, all_str_outputs, all_tok_samples, all_tok_prompts, all_tok_outputs = self.decode( - gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True - ) + ( + all_str_samples, + all_str_prompts, + all_str_outputs, + all_tok_samples, + all_tok_prompts, + all_tok_outputs, + ) = self.decode(gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True) rollout_score_time = time() # reward_fn should return list of rewards at each token per sample # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed) - # NOTE: reward_fn can optionally also compute the ref_logits. In this case size will be [batch_size, response_length, 2] - all_scores = self.reward_fn(samples=all_str_samples, - prompts=all_str_prompts, - outputs=all_str_outputs, - tok_samples=all_tok_samples, - tok_prompts=all_tok_prompts, - tok_outputs=all_tok_outputs, - model_tok=self.tokenizer, - **metadata) - all_scores = [torch.tensor(score, dtype=torch.float, device=device).view(-1,) for score in all_scores] + # NOTE: reward_fn can optionally also compute the ref_logits. 
+ # In this case size will be [batch_size, response_length, 2] + all_scores = self.reward_fn( + samples=all_str_samples, + prompts=all_str_prompts, + outputs=all_str_outputs, + tok_samples=all_tok_samples, + tok_prompts=all_tok_prompts, + tok_outputs=all_tok_outputs, + model_tok=self.tokenizer, + **metadata, + ) + all_scores = [ + torch.tensor(score, dtype=torch.float, device=device).view( + -1, + ) + for score in all_scores + ] # Pad -np.inf reward on the ends all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-np.inf) - max_len = torch.tensor(len(all_scores[0])/(1+int(self.dist_ref_model)), dtype=torch.long, device=device) + max_len = torch.tensor( + len(all_scores[0]) / (1 + int(self.dist_ref_model)), dtype=torch.long, device=device + ) stats["time/rollout_score"] = time() - rollout_score_time - all_scores = list(all_scores.reshape(self.accelerator.num_processes, len(samples), max_len, -1).unbind()) + all_scores = list( + all_scores.reshape(self.accelerator.num_processes, len(samples), max_len, -1).unbind() + ) else: all_scores = None max_len = torch.tensor(0, dtype=torch.long, device=device) @@ -348,13 +363,13 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq if torch.distributed.is_initialized(): torch.distributed.broadcast(max_len, 0) # Allocate extra space if scores include ref_logits - scores = torch.empty((len(samples), max_len, 1+int(self.dist_ref_model)), device=device) + scores = torch.empty((len(samples), max_len, 1 + int(self.dist_ref_model)), device=device) torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() - # Remove ref_logits from scores if present - if self.dist_ref_model: + # Remove ref_logits from scores if present + if self.dist_ref_model: all_ref_logprobs = scores[:, :, 1] scores = scores[:, :, 0] else: @@ -362,12 +377,17 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq scores = scores.squeeze(-1) scores_mask = scores != -np.inf - # Remove infs so mask can be used + # Remove infs so mask can be used if self.config.method.cliprange_reward: scores = torch.clip(scores, -self.config.method.cliprange_reward, self.config.method.cliprange_reward) - # Best-of-N Sampling. - train_indices = self.get_topk_indices(input_tensor=scores_mask*scores, window_size=num_return_sequences, k=self.config.method.num_train_sequences, device=device) + # Best-of-N Sampling. + train_indices = self.get_topk_indices( + input_tensor=scores_mask * scores, + window_size=num_return_sequences, + k=self.config.method.num_train_sequences, + device=device, + ) scores = scores[train_indices] scores_mask = scores_mask[train_indices] samples = samples[train_indices] @@ -391,18 +411,17 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq elif self.config.method.scale_reward == "ref": scores /= self.ref_std - # Only use these samples, prompts, outputs to compute ppo stats + # Only use these samples, prompts, outputs to compute ppo stats _, _, _, tok_samples, tok_prompts, tok_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) # Pad the sample outputs - #outputs = self.tokenizer(str_outputs).input_ids + # outputs = self.tokenizer(str_outputs).input_ids # TODO: Why is this here? Should this be a sep token? 
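A reward function compatible with the call signature above could look roughly like the sketch below; dense_reward_fn is illustrative only, and the constant 1.0 score is a placeholder for a real per-sequence or per-token signal:

    from typing import List
    import torch

    def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str],
                        tok_samples=None, tok_prompts=None, tok_outputs=None,
                        model_tok=None, **kwargs) -> List[torch.Tensor]:
        # Return one reward per response token; here a placeholder sequence-level
        # score of 1.0 is placed on the final (EOS) token of each response.
        rewards = []
        for tok_output in tok_outputs:
            per_token = torch.zeros(len(tok_output))
            if len(per_token) > 0:
                per_token[-1] = 1.0
            rewards.append(per_token)
        return rewards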
if self.config.model.model_arch_type == "seq2seq": # add to the start of the output for i in range(len(tok_outputs)): tok_outputs[i] = [self.tokenizer.pad_token_id] + outputs[i].tolist() - padded_tok_outputs = pad_sequence(tok_outputs, batch_first=True, padding_value=self.tokenizer.pad_token_id) tok_prompts = torch.stack(tok_prompts, dim=0) padded_tok_samples = pad_sequence(tok_samples, batch_first=True, padding_value=self.tokenizer.pad_token_id) attention_mask = padded_tok_samples.not_equal(self.tokenizer.pad_token_id).long() @@ -457,7 +476,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.gen_chunk_size, dim=0) attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) - for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip(all_tokens_chunks, attention_mask_chunks, position_ids_chunks): + for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip( + all_tokens_chunks, attention_mask_chunks, position_ids_chunks + ): all_tokens_chunk = all_tokens_chunk.to(device) attention_mask_chunk = attention_mask_chunk.to(device) position_ids_chunk = position_ids_chunk.to(device) @@ -467,7 +488,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq attention_mask=attention_mask_chunk, position_ids=position_ids_chunk, ) - # If all_ref_logits is not None they have already been generated during call to reward_fn + # If all_ref_logits is not None they have already been generated during call to reward_fn if all_ref_logprobs is None: if hasattr(self.model, "frozen_head"): ref_logits = self.model.forward_hydra( @@ -489,23 +510,26 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ref_logits = logits.clone() if self.config.model.model_arch_type == "seq2seq": - logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start+1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, start:-1, :], all_tokens_chunk[:, start+1:]) + logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :]) + ref_logprobs = logprobs_of_labels(ref_logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :]) else: # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled - # So need to index at start = prompt_tensors.shape[1] - 1 which is the logprob corresponding to the first sampled token + # So need to index at start = prompt_tensors.shape[1] - 1 which is + # the logprob corresponding to the first sampled token # Indexing ends at -1 because the last logprob corresponds to an unsampled token - logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start+1:]) + logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :]) if all_ref_logprobs is None: - ref_logprobs = logprobs_of_labels(ref_logits[:, start:-1, :], all_tokens_chunk[:, start+1:]) - + ref_logprobs = logprobs_of_labels( + ref_logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :] + ) + values_chunks.append(values.cpu()) log_probs_chunks.append(logprobs.cpu()) if all_ref_logprobs is None: ref_logprobs_chunks.append(ref_logprobs.cpu()) - + # Remove values before v[start] (this is the value of the state before any tokens are sampled) - # and remove the last value v[-1] (this is a terminal state after all tokens have been generated with value 0) + # and 
remove the last value v[-1] (this is a terminal state after all tokens have been generated with value 0) values = torch.cat(values_chunks, dim=0)[:, start:-1] logprobs = torch.cat(log_probs_chunks, dim=0) attention_mask = attention_mask[:, start:].cpu() @@ -515,8 +539,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # all_ref_logprobs returned from reward already has prompt prefix removed else: # Remove (some) padding from distributed communication - # So arithmetic with logprobs can be done - ref_logprobs = all_ref_logprobs[:, :logprobs.shape[1]].cpu() + # So arithmetic with logprobs can be done + ref_logprobs = all_ref_logprobs[:, : logprobs.shape[1]].cpu() # Estimate the KL divergence between the model and reference model # NOTE: nan is interfering with kl estimates since 0 * nan = 0 @@ -537,18 +561,25 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # In our case it's ok because we then subtract -1 from resulting end index ends = attention_mask.sum(1) + 1 for sample_idx in range(n_samples): - value = values[sample_idx, :ends[sample_idx]-1] - logprob = logprobs[sample_idx, :ends[sample_idx]-1] - kl_penalty = kl_penalties[sample_idx, :ends[sample_idx]-1] + value = values[sample_idx, : ends[sample_idx] - 1] + logprob = logprobs[sample_idx, : ends[sample_idx] - 1] + kl_penalty = kl_penalties[sample_idx, : ends[sample_idx] - 1] query_tensor = tok_prompts[sample_idx] response_tensor = tok_outputs[sample_idx] - if len(value) != len(logprob) or len(logprob) != len(kl_penalty) or len(kl_penalty) != len(response_tensor): - raise ValueError(f"Length mismatch between value, logprob, kl, and response_tensor:\n\ + if ( + len(value) != len(logprob) + or len(logprob) != len(kl_penalty) + or len(kl_penalty) != len(response_tensor) + ): + raise ValueError( + f"Length mismatch between value, logprob, kl, and response_tensor:\n\ Value: {value.shape}, {value}\n\ Logprob: {logprob.shape}, {logprob}\n\ KL: {kl_penalty.shape}, {kl_penalty}\n\ - Response: {response_tensor.shape}, {response_tensor}, {self.tokenizer.decode(response_tensor)}\n") - + Response: {response_tensor.shape}, {response_tensor}, \ + {self.tokenizer.decode(response_tensor)}\n" + ) + # Then add in rewards if scores.shape[1] == 1: # NOTE: Final reward given at EOS token following HHH practice @@ -560,20 +591,24 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq score_right_padding = torch.sum(scores_mask[sample_idx]) score = score[:score_right_padding].cpu() if len(score) != len(kl_penalty): - raise ValueError(f"Length mismatch between score and kl penalty:\n\ + raise ValueError( + f"Length mismatch between score and kl penalty:\n\ Logprob: {logprob.shape}, {logprob}\n\ kl_penalty: {kl_penalty.shape}, {kl_penalty}\n\ - Score: {score.shape}, {score}") + Score: {score.shape}, {score}" + ) rewards = kl_penalty + score if kl_penalty.isnan().any() or score.isnan().any(): - raise ValueError(f"nan in tensor:\n\ + raise ValueError( + f"nan in tensor:\n\ KL: {kl_penalty}\n\ Score: {score}\n\ logprob: {logprob}\n\ ref logprob: {ref_logprobs[sample_idx][:ends[sample_idx]-1]}\n\ mask: {attention_mask[sample_idx]}\n\ - kl ctl: {self.kl_ctl.value}") + kl ctl: {self.kl_ctl.value}" + ) ppo_rl_elements.append( PPORLElement(