diff --git a/examples/ppo_redemption.py b/examples/ppo_redemption.py new file mode 100644 index 000000000..84435b225 --- /dev/null +++ b/examples/ppo_redemption.py @@ -0,0 +1,82 @@ +# Generates positive movie reviews by tuning a pretrained model on IMDB dataset +# with a sentiment reward function +import json +import os +import sys +from typing import List + +import torch +from datasets import load_dataset +from transformers import pipeline + +import trlx +from trlx.data.default_configs import TRLConfig, default_ppo_config + + +def get_positive_score(scores): + "Extract value associated with a positive sentiment from pipeline's output" + return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + + +def get_negative_score(scores): + return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"] + + +def main(hparams={}): + # Merge sweep config with default config if given + config = TRLConfig.update(default_ppo_config().to_dict(), hparams) + config.method.cliprange_reward = False + config.method.gen_kwargs["max_new_tokens"] = 70 + config.method.gen_kwargs["temperature"] = 0.3 + config.train.total_steps = 20000 + config.train.checkpoint_interval = 10000000 + # config.method.init_kl_coef = 0 + + if torch.cuda.is_available(): + device = int(os.environ.get("LOCAL_RANK", 0)) + else: + device = -1 + + sentiment_fn = pipeline( + "sentiment-analysis", + "lvwerra/distilbert-imdb", + top_k=2, + truncation=True, + batch_size=256, + device=device, + ) + + def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], model_tok, **kwargs) -> List[List[float]]: + # Reward positively for an initially negative then positive review + # Reward functions should never receive padded text except for a single EOS at the end + # Reward function should return token rewards for just the response + first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples] + negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves))) + second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples] + positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves))) + text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)] + tok_scores = [] + for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores): + toks = model_tok(response).input_ids + tok_score = [0] * len(toks) + # Hacky way of assigning intermediate score + tok_score[len(tok_score) // 2] = text_score[0] + tok_score[-1] = text_score[1] + tok_scores.append(tok_score) + return tok_scores + + # Take a few words off of movie reviews as prompts + imdb = load_dataset("imdb", split="train+test") + prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] + + trlx.train( + reward_fn=dense_reward_fn, + prompts=prompts, + eval_prompts=["I don't know much about Hungarian underground"] * 256, + config=config, + ) + + +if __name__ == "__main__": + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 5277d7010..3acee97ab 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,11 +49,13 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, + num_train_sequences=1, gen_kwargs=dict( max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True, + num_return_sequences=1, ), ), ) diff --git a/trlx/models/modeling_ppo.py
b/trlx/models/modeling_ppo.py index 82d3ec637..7856f6fa0 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -112,6 +112,9 @@ class PPOConfig(MethodConfig): :param gen_experience_kwargs: if this is not None, then the experience is generated using this :type gen_experience_kwargs: Dict[str, Any] + + :param num_train_sequences: number of highest-reward sequences, out of the `num_return_sequences` sampled per prompt, to keep for training + :type num_train_sequences: int """ ppo_epochs: int @@ -131,6 +134,8 @@ class PPOConfig(MethodConfig): cliprange_reward: float gen_kwargs: dict gen_experience_kwargs: Optional[dict] = None + num_train_sequences: int = 1 + dist_ref_model: bool = False def get_advantages_and_returns( self, diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 5c82335c0..801214bf4 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -4,6 +4,7 @@ import sys from abc import abstractmethod from contextlib import contextmanager +from copy import copy from time import time from typing import Dict, List, Optional, Tuple @@ -201,81 +202,124 @@ def decode( prompts: List[torch.LongTensor], samples: List[torch.LongTensor], prompt_sizes: torch.LongTensor = None, - append_eos_token: bool = False, - ) -> Tuple[List[str], List[str], List[str]]: + append_eos_token: bool = True, + ) -> Tuple[List[str], List[str], List[str], List[torch.LongTensor], List[torch.LongTensor], List[torch.LongTensor]]: """ - Decode tensor generations into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) + Decode tensor generations with stopping criteria into lists + of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) and the corresponding lists of token tensors (`tok_samples`, `tok_prompts`, `tok_outputs`). + Note prompts may sometimes be right padded, as well as samples """ if prompt_sizes is None: # Assuming prompts were left-padded prompt_sizes = [prompts.shape[1]] * len(prompts) str_samples, str_prompts, str_outputs = [], [], [] + tok_samples, tok_prompts, tok_outputs = [], [], [] for prompt, sample, prompt_size in zip(prompts, samples, prompt_sizes): if self.config.model.model_arch_type == "seq2seq": output_start_ix = 0 else: output_start_ix = prompt_size + # We must decode by skipping padding in the middle with skip_special_tokens str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True) + # Return the prompt tensor (with the exact padding) used to generate the sample + tok_prompt = prompt[:prompt_size].cpu() + str_output = self.tokenizer.decode(sample[output_start_ix:], skip_special_tokens=True) # Trim outputs up to `self.stop_sequences` if any are present - trimmed = False if self.stop_sequences: for stop in self.stop_sequences: stop_ix = str_output.find(stop) if stop_ix >= 0: str_output = str_output[:stop_ix].rstrip() - trimmed = True + + # Recover sequence of tokens corresponding to string + # NOTE: Cast to torch.long in case the input is empty + tok_output = self.tokenizer(str_output, return_tensors="pt").input_ids[0].long() + # Remove bos from tokenized output (if present) + if ( + hasattr(self.tokenizer, "bos_token") + and len(tok_output) > 0 + and tok_output[0].item() == self.tokenizer.bos_token_id + ): + tok_output = tok_output[1:] # Recover the last eos token if it was present in the original sample # or add one if it was trimmed with `self.stop_sequences`.
# When a generation ended due to `max_new_tokens` exhaustion, # only then would an eos or pad token not be present in the original sample at the end - if append_eos_token and ( - trimmed or sample[-1] == self.tokenizer.eos_token_id or sample[-1] == self.tokenizer.pad_token_id - ): + # Note we append the eos to the string and token outputs separately because re-encoding the string does not reliably yield tokenizer.eos_token_id + if append_eos_token: str_output += self.tokenizer.eos_token + tok_output = torch.cat( + [tok_output, torch.tensor(self.tokenizer.eos_token_id, dtype=torch.long).view((1,))] + ) str_prompts.append(str_prompt) str_outputs.append(str_output) + tok_prompts.append(tok_prompt) + tok_outputs.append(tok_output) if self.config.model.model_arch_type == "seq2seq": sample = str_prompt + self.tokenizer.sep_token + str_output + tok_sample = torch.cat( + [tok_prompt, torch.tensor(self.tokenizer.sep_token_id, dtype=torch.long).view((1,)), tok_output] + ) else: sample = str_prompt + str_output + tok_sample = torch.cat([tok_prompt, tok_output]) str_samples.append(sample) + tok_samples.append(tok_sample) - return str_samples, str_prompts, str_outputs + return str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs - def generate(self, input_ids, attention_mask=None, **kwargs): + def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" + # Divide into chunks and generate samples input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) - if self.generate_experience_kwargs is not None: - kwargs = dict(self.generate_experience_kwargs, **kwargs) - else: - kwargs = dict(self.generate_kwargs, **kwargs) - with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs + generate_kwargs = copy(self.generate_kwargs) + generate_kwargs.update(kwargs) + + # Update max_new_tokens to respect max_seq_length + prompt_length = input_ids.shape[1] + if generate_kwargs.get("max_new_tokens") is not None: + generate_kwargs["max_new_tokens"] = min( + max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"] ) + else: + generate_kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) - def generate_eval(self, input_ids, attention_mask=None, **kwargs): - """Wraps hf's `generate` adding some specific method's defaults""" - input_ids = input_ids.to(self.accelerator.device) + # Repeat prompts, attention_masks for chunking if returning multiple sequences + if generate_kwargs.get("num_return_sequences") is None: + generate_kwargs["num_return_sequences"] = 1 + + num_return_sequences = generate_kwargs.pop("num_return_sequences") # Pop to hide from model.generate call + input_ids = input_ids.repeat_interleave(num_return_sequences, dim=0) if attention_mask is not None: - attention_mask = attention_mask.to(self.accelerator.device) + attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) - kwargs = dict(self.generate_kwargs, **kwargs) + if chunk_size is None: + chunk_size = input_ids.shape[0] - with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs - ) + # Chunk input_ids and attention_mask + input_ids = input_ids.split(chunk_size, dim=0) + if attention_mask is not None: + attention_mask = attention_mask.split(chunk_size, dim=0) + samples = [] + for chunk_idx in
range(len(input_ids)): + with torch.no_grad(): + sample = self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **generate_kwargs + ) + samples.append(sample) + # Concat samples + samples = torch.cat(samples, 0) + return samples def save_pretrained(self, directory: Optional[str] = None, **kwargs): """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for @@ -377,11 +421,20 @@ def evaluate(self): # noqa: C901 for i_prompt, prompts in enumerate(self.eval_dataloader): metadata = {k: v for k, v in prompts.items() if k != "input_ids" and k != "attention_mask"} if self.generate_sweep_kwarg: - samples = self.generate_eval( + samples = self.generate( prompts["input_ids"], prompts["attention_mask"], **{gen_sweep_arg: gen_sweep_value} ) else: - samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"]) + chunk_size = self.config.method.chunk_size if hasattr(self.config.method, "chunk_size") else None + samples = self.generate(prompts["input_ids"], prompts["attention_mask"], chunk_size=chunk_size) + + # Repeat prompts, metadata num_return_sequence times + num_return_sequences = 1 + if self.generate_kwargs.get("num_return_sequences") is not None: + num_return_sequences = self.generate_kwargs["num_return_sequences"] + prompts["input_ids"] = prompts["input_ids"].repeat_interleave(num_return_sequences, dim=0) + prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(num_return_sequences, dim=0) + metadata = {k: self.repeat_interleave(v, num_return_sequences) for k, v in metadata.items()} # TODO(reciprocated): this should be moved into `decode` # but that needs to be synced with indexing in `make_experience` @@ -396,9 +449,9 @@ def evaluate(self): # noqa: C901 pad_index=self.tokenizer.pad_token_id, ) ) - all_samples.extend(samples.tolist()) - all_prompts.extend(prompts.tolist()) - all_prompt_sizes.extend(prompt_sizes.tolist()) + all_samples.extend(samples) + all_prompts.extend(prompts) + all_prompt_sizes.extend(prompt_sizes) metadata = gather_dict(metadata, self.accelerator.gradient_state) all_metadata.append(metadata) @@ -414,7 +467,9 @@ def evaluate(self): # noqa: C901 stats["time/generate"] = time() - generate_time if self.accelerator.is_main_process: - str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes) + str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs = self.decode( + all_prompts, all_samples, all_prompt_sizes + ) columns = ["prompt", "output"] columns_data = [str_prompts, str_outputs] @@ -427,10 +482,25 @@ def evaluate(self): # noqa: C901 # in online setting, compute the reward for validation if self.reward_fn: logger.info("Computing rewards") - rewards = torch.tensor( - self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata), - dtype=float, + rewards = self.reward_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + tok_samples=tok_samples, + tok_prompts=tok_prompts, + tok_outputs=tok_outputs, + model_tok=self.tokenizer, + **metadata, ) + # Remove kl terms from reward + if hasattr(self, "dist_ref_model") and self.dist_ref_model: + rewards = [[r[0] for r in reward] for reward in rewards] + if type(rewards[0]) is torch.Tensor: + rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float) + elif type(rewards[0]) is list: + rewards = torch.tensor([sum(reward) for reward in rewards]) + else: + rewards = 
torch.tensor(rewards) mean_reward = rewards.mean().item() columns.append("reward") if not isinstance(rewards, list): @@ -442,7 +512,16 @@ def evaluate(self): # noqa: C901 if self.metric_fn: logger.info("Computing metrics") metric_time = time() - metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata) + metrics = self.metric_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + tok_samples=tok_samples, + tok_prompts=tok_prompts, + tok_outputs=tok_outputs, + model_tok=self.tokenizer, + **metadata, + ) stats["time/metric"] = time() - metric_time mean_metrics = { @@ -633,6 +712,15 @@ def learn(self): # noqa: C901 self.post_epoch_callback() tbar.close() + @staticmethod + def repeat_interleave(l, n): + if type(l) is torch.Tensor: + l = l.repeat_interleave(n, dim=0) + elif type(l) is list: + l = [[s] * n for s in l] + l = [item for sublist in l for item in sublist] + return l + @abstractmethod def get_arch(self, config: TRLConfig): """Returns a specific wrapper of the decoder architecture""" diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index a3af9aa3f..381e2ecfd 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -4,9 +4,10 @@ from time import time from typing import Callable, List +import numpy as np import torch -import torch.nn.functional as F import transformers +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -67,8 +68,9 @@ def __init__(self, config: TRLConfig, **kwargs): self.store.clear_history() # Clear the rollout store - # Set up a reference model when hydra heads are not used - if not hasattr(self.model, "frozen_head") and not self.model.peft_type: + # Setup a reference model when hydra heads and distributed ref model are not used + self.dist_ref_model = config.method.dist_ref_model + if not hasattr(self.model, "frozen_head") and not self.model.peft_type and not self.dist_ref_model: self.ref_model = self.get_arch(self.config) self.ref_model.to(self.accelerator.device) self.ref_model.eval() @@ -228,13 +230,15 @@ def prepare_learning(self): self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False) + self.make_experience(self.config.method.num_rollouts) + self.n_updates_per_batch = self.config.method.ppo_epochs self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) def add_prompt_pipeline(self, pipeline: PromptPipeline): """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" - prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) + prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=False) prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) self.prompt_iterator = infinite_dataloader(prompt_dataloader) @@ -265,6 +269,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ppo_rl_elements = [] accumulated_stats = [] + # Require chunk_size * num_train_sequences divides num_rollouts + assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_train_sequences) == 0 + while len(ppo_rl_elements) < num_rollouts: stats = {} # Get next batch in prompt dataset @@ -273,10 +280,20 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: 
int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"]) + samples = self.generate( + batch["input_ids"], + batch["attention_mask"], + chunk_size=self.config.method.chunk_size, + **self.generate_experience_kwargs, + ) stats["time/rollout_generate"] = time() - rollout_generate_time - prompt_tensors = batch.input_ids + num_return_sequences = ( + self.generate_experience_kwargs["num_return_sequences"] + if self.generate_experience_kwargs.get("num_return_sequences") is not None + else 1 + ) + prompt_tensors = batch.input_ids.repeat_interleave(num_return_sequences, dim=0) device = samples.device prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) @@ -289,61 +306,101 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq gathered_samples = self.accelerator.gather(padded_samples) gathered_prompts = self.accelerator.gather(padded_prompts) gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) - metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) + metadata = gather_dict( + { + k: self.repeat_interleave(v, num_return_sequences) + for k, v in batch.items() + if k != "input_ids" and k != "attention_mask" + } + ) if self.accelerator.is_main_process: - all_str_samples, all_str_prompts, all_str_outputs = self.decode( - gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True - ) + ( + all_str_samples, + all_str_prompts, + all_str_outputs, + all_tok_samples, + all_tok_prompts, + all_tok_outputs, + ) = self.decode(gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True) rollout_score_time = time() - all_scores = torch.tensor( - self.reward_fn( - samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata - ), - dtype=torch.float, - device=device, + # reward_fn should return list of rewards at each token per sample + # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed) + # NOTE: reward_fn can optionally also compute the ref_logits. 
+ # In this case size will be [batch_size, response_length, 2] + all_scores = self.reward_fn( + samples=all_str_samples, + prompts=all_str_prompts, + outputs=all_str_outputs, + tok_samples=all_tok_samples, + tok_prompts=all_tok_prompts, + tok_outputs=all_tok_outputs, + model_tok=self.tokenizer, + **metadata, ) + all_scores = [ + torch.tensor(score, dtype=torch.float, device=device).view( + -1, + ) + for score in all_scores + ] + # Pad -np.inf reward on the ends + all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-np.inf) + max_len = torch.tensor( + len(all_scores[0]) / (1 + int(self.dist_ref_model)), dtype=torch.long, device=device + ) + stats["time/rollout_score"] = time() - rollout_score_time - all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind()) + all_scores = list( + all_scores.reshape(self.accelerator.num_processes, len(samples), max_len, -1).unbind() + ) else: all_scores = None + max_len = torch.tensor(0, dtype=torch.long, device=device) if torch.distributed.is_initialized(): - scores = torch.empty(len(samples), device=device) + torch.distributed.broadcast(max_len, 0) + # Allocate extra space if scores include ref_logits + scores = torch.empty((len(samples), max_len, 1 + int(self.dist_ref_model)), device=device) torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() - str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) - - # Pad the sample outputs - outputs = self.tokenizer(str_outputs).input_ids - if self.config.model.model_arch_type == "seq2seq": - # add to the start of the output - for i in range(len(outputs)): - outputs[i] = [self.tokenizer.pad_token_id] + outputs[i] - - outputs = list(map(torch.LongTensor, outputs)) - maxsize = max(map(len, outputs)) - outputs = [ - F.pad( - output, - (0, maxsize - len(output)), - value=self.tokenizer.pad_token_id, - ) - for output in outputs - ] - sample_outputs = torch.vstack(outputs).to(device) + # Remove ref_logits from scores if present + if self.dist_ref_model: + all_ref_logprobs = scores[:, :, 1] + scores = scores[:, :, 0] + else: + all_ref_logprobs = None + scores = scores.squeeze(-1) + scores_mask = scores != -np.inf + # Remove infs so mask can be used if self.config.method.cliprange_reward: scores = torch.clip(scores, -self.config.method.cliprange_reward, self.config.method.cliprange_reward) + # Best-of-N Sampling. 
+ train_indices = self.get_topk_indices( + input_tensor=scores_mask * scores, + window_size=num_return_sequences, + k=self.config.method.num_train_sequences, + device=device, + ) + scores = scores[train_indices] + scores_mask = scores_mask[train_indices] + samples = samples[train_indices] + prompt_tensors = prompt_tensors[train_indices] + if all_ref_logprobs is not None: + all_ref_logprobs = all_ref_logprobs[train_indices] + # store statistics of the initial rollout as reference if self.ref_mean is None: - self.ref_mean, self.ref_std = scores.mean(), scores.std() - all_scores_mean, all_scores_std = self.running_moments.update(scores) + self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum( + dim=1 + ).std() + all_scores_mean, all_scores_std = self.running_moments.update(scores, scores_mask) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() stats["rollout_scores/running_mean"] = self.running_moments.mean.item() @@ -354,7 +411,30 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq elif self.config.method.scale_reward == "ref": scores /= self.ref_std + # Only use these samples, prompts, outputs to compute ppo stats + _, _, _, tok_samples, tok_prompts, tok_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + + # Pad the sample outputs + # outputs = self.tokenizer(str_outputs).input_ids + # TODO: Why is this here? Should this be a sep token? + if self.config.model.model_arch_type == "seq2seq": + # add to the start of the output + for i in range(len(tok_outputs)): + tok_outputs[i] = [self.tokenizer.pad_token_id] + outputs[i].tolist() + + tok_prompts = torch.stack(tok_prompts, dim=0) + padded_tok_samples = pad_sequence(tok_samples, batch_first=True, padding_value=self.tokenizer.pad_token_id) + attention_mask = padded_tok_samples.not_equal(self.tokenizer.pad_token_id).long() + + if self.config.model.model_arch_type == "seq2seq": + attention_mask = sample_outputs != self.tokenizer.pad_token_id + start = 0 + else: + # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response + start = tok_prompts.shape[1] - 1 + # Precompute logprobs, values + # TODO: Come back to seq2seq if self.config.model.model_arch_type == "seq2seq": attention_mask = batch.attention_mask.to(device) prompt_tensors = batch.input_ids.to(device) @@ -386,80 +466,156 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq return_dict=True, ).logits else: + values_chunks = [] + log_probs_chunks = [] + ref_logprobs_chunks = [] all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - with torch.no_grad(): - logits, *_, values = self.model( - all_tokens, attention_mask=attention_mask, position_ids=position_ids - ) - # TODO(dahoas): When hydra model works need to also support generation on hydra head - if hasattr(self.model, "frozen_head") or self.model.peft_type: - ref_logits = self.model.forward_hydra( - all_tokens, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - ).logits + all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.gen_chunk_size, dim=0) + attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) + position_ids_chunks = 
torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) + for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip( + all_tokens_chunks, attention_mask_chunks, position_ids_chunks + ): + all_tokens_chunk = all_tokens_chunk.to(device) + attention_mask_chunk = attention_mask_chunk.to(device) + position_ids_chunk = position_ids_chunk.to(device) + with torch.no_grad(): + logits, *_, values = self.model( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids_chunk, + ) + # If all_ref_logits is not None they have already been generated during call to reward_fn + if all_ref_logprobs is None: + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids_chunk, + return_dict=True, + ).logits + elif hasattr(self, "ref_model"): + ref_logits = self.ref_model( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids_chunk, + return_dict=True, + ).logits + ref_logits = ref_logits.to(device) + # If no ref model is provided then we compute no kl penalty + else: + ref_logits = logits.clone() + + if self.config.model.model_arch_type == "seq2seq": + logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :]) + ref_logprobs = logprobs_of_labels(ref_logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :]) else: - ref_logits = self.ref_model( - all_tokens, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - ).logits - ref_logits = ref_logits.to(device) - - if self.config.model.model_arch_type == "seq2seq": - logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) + # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled + # So need to index at start = prompt_tensors.shape[1] - 1 which is + # the logprob corresponding to the first sampled token + # Indexing ends at -1 because the last logprob corresponds to an unsampled token + logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :]) + if all_ref_logprobs is None: + ref_logprobs = logprobs_of_labels( + ref_logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :] + ) + + values_chunks.append(values.cpu()) + log_probs_chunks.append(logprobs.cpu()) + if all_ref_logprobs is None: + ref_logprobs_chunks.append(ref_logprobs.cpu()) + + # Remove values before v[start] (this is the value of the state before any tokens are sampled) + # and remove the last value v[-1] (this is a terminal state after all tokens have been generated with value 0) + values = torch.cat(values_chunks, dim=0)[:, start:-1] + logprobs = torch.cat(log_probs_chunks, dim=0) + attention_mask = attention_mask[:, start:].cpu() + + if all_ref_logprobs is None: + ref_logprobs = torch.cat(ref_logprobs_chunks, dim=0) + # all_ref_logprobs returned from reward already has prompt prefix removed else: - logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) - - n_samples: int = samples.shape[0] + # Remove (some) padding from distributed communication + # So arithmetic with logprobs can be done + ref_logprobs = all_ref_logprobs[:, : logprobs.shape[1]].cpu() # Estimate the KL divergence between the model and reference model - if self.config.model.model_arch_type == "seq2seq": - attention_mask = sample_outputs != 
self.tokenizer.pad_token_id - start = 0 - else: - start = prompt_tensors.shape[1] - 1 - - log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] + # NOTE: nan is interfering with kl estimates since 0 * nan = 0 + # Convert inf padding terms in ref_logprobs to number removable with attention mask mult + log_ratio = (logprobs - torch.nan_to_num(ref_logprobs)) * attention_mask[:, :-1] kl = log_ratio.exp() - 1 - log_ratio mean_kl_per_token = kl.mean() mean_kl = kl.sum(1).mean() + kl_penalties = self.kl_ctl.value * -log_ratio.cpu() - logprobs = logprobs.cpu() - ref_logprobs = ref_logprobs.cpu() - prompt_tensors = prompt_tensors.cpu() - sample_outputs = sample_outputs.cpu() - values = values.cpu()[:, :-1] + n_samples = padded_tok_samples.shape[0] + rollout_count = 0 # Get the logprobs and values, for tokens that are not padding, - # from the start of the prompt up to the token, while also including the latter + # from the end of the prompt up to the token, while also including the latter # (these are taken from the student model and not the reference model) - ends = start + attention_mask[:, start:].sum(1) + 1 - all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] - all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] - - kl_penalty = self.kl_ctl.value * -log_ratio.cpu() - kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)] - - rollout_count = 0 - + # NOTE: Why are we summing including a token from the prompt? + # In our case it's ok because we then subtract -1 from resulting end index + ends = attention_mask.sum(1) + 1 for sample_idx in range(n_samples): - rewards = kl_penalty[sample_idx] - rewards[-1] += scores[sample_idx].cpu() + value = values[sample_idx, : ends[sample_idx] - 1] + logprob = logprobs[sample_idx, : ends[sample_idx] - 1] + kl_penalty = kl_penalties[sample_idx, : ends[sample_idx] - 1] + query_tensor = tok_prompts[sample_idx] + response_tensor = tok_outputs[sample_idx] + if ( + len(value) != len(logprob) + or len(logprob) != len(kl_penalty) + or len(kl_penalty) != len(response_tensor) + ): + raise ValueError( + f"Length mismatch between value, logprob, kl, and response_tensor:\n\ + Value: {value.shape}, {value}\n\ + Logprob: {logprob.shape}, {logprob}\n\ + KL: {kl_penalty.shape}, {kl_penalty}\n\ + Response: {response_tensor.shape}, {response_tensor}, \ + {self.tokenizer.decode(response_tensor)}\n" + ) + + # Then add in rewards + if scores.shape[1] == 1: + # NOTE: Final reward given at EOS token following HHH practice + score = scores[sample_idx][0].cpu() + kl_penalty[-1] += score + rewards = kl_penalty + else: + score = scores[sample_idx] + score_right_padding = torch.sum(scores_mask[sample_idx]) + score = score[:score_right_padding].cpu() + if len(score) != len(kl_penalty): + raise ValueError( + f"Length mismatch between score and kl penalty:\n\ + Logprob: {logprob.shape}, {logprob}\n\ + kl_penalty: {kl_penalty.shape}, {kl_penalty}\n\ + Score: {score.shape}, {score}" + ) + rewards = kl_penalty + score + + if kl_penalty.isnan().any() or score.isnan().any(): + raise ValueError( + f"nan in tensor:\n\ + KL: {kl_penalty}\n\ + Score: {score}\n\ + logprob: {logprob}\n\ + ref logprob: {ref_logprobs[sample_idx][:ends[sample_idx]-1]}\n\ + mask: {attention_mask[sample_idx]}\n\ + kl ctl: {self.kl_ctl.value}" + ) ppo_rl_elements.append( PPORLElement( - query_tensor=prompt_tensors[sample_idx], - response_tensor=sample_outputs[sample_idx], - logprobs=all_logprobs[sample_idx], - values=all_values[sample_idx], + 
query_tensor=query_tensor, + response_tensor=response_tensor, + logprobs=logprob, + values=value, rewards=rewards, ) ) @@ -467,7 +623,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_count += 1 if torch.distributed.is_initialized(): - torch.distributed.all_reduce(mean_kl, torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce(mean_kl.to(self.accelerator.device), torch.distributed.ReduceOp.AVG) stats["time/rollout_time"] = clock.tick() stats["policy/sqrt_kl"] = torch.sqrt(mean_kl).item() @@ -485,3 +641,17 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # Push samples and rewards to trainer's rollout storage self.push_to_store(ppo_rl_elements) + + @staticmethod + def get_topk_indices(input_tensor, window_size: int, k: int, device): + # Sum the scores along dim 1 + input_tensor = input_tensor.sum(1).unsqueeze(1) + # Use unfold to create the sliding windows + unfolded = input_tensor.unfold(0, window_size, window_size) + # Find the topk values and indices along the unfolded dimension + _, indices = torch.topk(unfolded, k, dim=2) + # Adjust indices to be relative to original tensor + indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to( + device + ).unsqueeze(1) + return indices.reshape(-1) diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 47688f553..b0036b3f6 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -1,5 +1,5 @@ import functools -from typing import Dict, MutableMapping, Tuple, Union +from typing import Dict, MutableMapping, Optional, Tuple, Union import accelerate import numpy as np @@ -276,8 +276,11 @@ def __init__(self): self.var = 1 self.count = 1e-24 - def update(self, xs: torch.Tensor) -> Tuple[float, float]: + def update(self, xs: torch.Tensor, xs_mask: Optional[torch.Tensor] = None) -> Tuple[float, float]: """Updates running moments from batch's moments computed across ranks""" + if xs_mask is None: + xs_mask = torch.ones_like(xs) + xs = torch.sum(xs * xs_mask, dim=1) if dist.is_initialized(): xs_mean, xs_var, xs_count = get_global_statistics(xs) else:
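Editor's note (not part of the patch): the snippet below is a minimal standalone sketch of the best-of-N selection performed by the new `get_topk_indices` static method, which the trainer calls with `window_size=num_return_sequences` and `k=num_train_sequences` on the per-token score tensor. The 2-prompt, 4-sequence batch and the score values are made up for illustration; the helper simply repeats the method's logic outside the class.

import torch

def get_topk_indices(input_tensor, window_size: int, k: int, device):
    # Same logic as the static method added in accelerate_ppo_trainer.py
    # Sum the per-token scores to get one total reward per sampled sequence
    totals = input_tensor.sum(1).unsqueeze(1)
    # Group sequences into non-overlapping windows, one window per prompt
    unfolded = totals.unfold(0, window_size, window_size)
    # Pick the top-k sequences inside each prompt's window
    _, indices = torch.topk(unfolded, k, dim=2)
    # Shift window-relative indices back to indices into the original batch
    offsets = torch.arange(0, totals.size(0) - window_size + 1, window_size).to(device)
    return (indices.squeeze(1) + offsets.unsqueeze(1)).reshape(-1)

# 2 prompts x num_return_sequences=4, three per-token rewards per sequence (made-up values)
scores = torch.tensor(
    [
        [0.1, 0.0, 0.2],  # prompt 0, seq 0 -> total 0.3
        [0.5, 0.1, 0.0],  # prompt 0, seq 1 -> total 0.6
        [0.0, 0.0, 0.1],  # prompt 0, seq 2 -> total 0.1
        [0.2, 0.2, 0.3],  # prompt 0, seq 3 -> total 0.7
        [0.9, 0.0, 0.0],  # prompt 1, seq 0 -> total 0.9
        [0.0, 0.1, 0.1],  # prompt 1, seq 1 -> total 0.2
        [0.3, 0.3, 0.2],  # prompt 1, seq 2 -> total 0.8
        [0.0, 0.0, 0.0],  # prompt 1, seq 3 -> total 0.0
    ]
)
print(get_topk_indices(scores, window_size=4, k=2, device="cpu"))
# Expected: tensor([3, 1, 4, 6]), the two highest-reward sequences within each prompt's
# window, i.e. how num_train_sequences generations per prompt are kept for training.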