Dist ref kl #529

Open
wants to merge 27 commits into base: main

Commits (27)
f04446d
Implementing support for dense rewards
Dahoas Jun 5, 2023
13a01fc
added "num_return_sequences" param which corresponds to n in Best-of-…
SharathRaparthy Jun 16, 2023
5421a73
updates to "num_return_sequences" param
SharathRaparthy Jun 16, 2023
2f3ac28
BoN implementation
SharathRaparthy Jun 16, 2023
2f1dace
Changed back to default.
SharathRaparthy Jun 19, 2023
f58170d
TopK sampling instead of Top1
SharathRaparthy Jun 19, 2023
be8bc1a
summed along dim=1
SharathRaparthy Jun 26, 2023
608d812
Generating samples in chunks
SharathRaparthy Jun 26, 2023
d8557e7
added gen_chunk_size parameter
SharathRaparthy Jun 26, 2023
8ef9c36
chunking in forward prop
SharathRaparthy Jun 26, 2023
4c1d82d
chunking generations in train and eval
SharathRaparthy Jun 26, 2023
ecd5107
Implementing support for dense rewards
Dahoas Jun 5, 2023
4071604
Fix distributed ref_mean, ref_var bug for dense rewards
Dahoas Jun 15, 2023
5f41413
Make generation respect max seq length
Dahoas Jun 23, 2023
22ae83f
Make experience before first round of training
Dahoas Jun 23, 2023
7d0a4be
Refactoring .generate/.generate_eval
Dahoas Jun 27, 2023
b79dd19
Fix BoN metric support
Dahoas Jun 29, 2023
cb49dc5
Enforce chunk_size param for eval generation when present
Dahoas Jul 3, 2023
e290412
Fix: Don't shuffle prompt dataset
Dahoas Jul 4, 2023
391d04c
Move inputs to device
Dahoas Jul 18, 2023
8de84e4
Fix style
Dahoas Jul 18, 2023
404ef14
Fix: Do not shuffle empty experience dataloader
Dahoas Jun 23, 2023
67b711a
Make experience before first round of training
Dahoas Jun 23, 2023
34e185a
Refactoring .generate/.generate_eval
Dahoas Jun 27, 2023
11e1e95
Refactored decode, make_experience and added support for external ref…
Dahoas Jul 14, 2023
527ba23
Fix BoN sampling after big refactor
Dahoas Jul 17, 2023
676a1cd
Fixing style
Dahoas Jul 18, 2023
82 changes: 82 additions & 0 deletions examples/ppo_redemption.py
@@ -0,0 +1,82 @@
# Generates positive movie reviews by tuning a pretrained model on IMDB dataset
# with a sentiment reward function
import json
import os
import sys
from typing import List

import torch
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config


def get_positive_score(scores):
"Extract value associated with a positive sentiment from pipeline's output"
return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def get_negative_score(scores):
return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"]


def main(hparams={}):
# Merge sweep config with default config if given
config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
config.method.cliprange_reward = False
config.method.gen_kwargs["max_new_tokens"] = 70
config.method.gen_kwargs["temperature"] = 0.3
config.train.total_steps = 20000
config.train.checkpoint_interval = 10000000
# config.method.init_kl_coef = 0

if torch.cuda.is_available():
device = int(os.environ.get("LOCAL_RANK", 0))
else:
device = -1

sentiment_fn = pipeline(
"sentiment-analysis",
"lvwerra/distilbert-imdb",
top_k=2,
truncation=True,
batch_size=256,
device=device,
)

def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], model_tok, **kwargs) -> List[float]:
# Reward a review that starts negative and turns positive
# Reward functions should never receive padded text except for a single EOS at the end
# Reward functions should return per-token rewards for the response only
first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples]
negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves)))
second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples]
positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves)))
text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)]
tok_scores = []
for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores):
toks = model_tok(response).input_ids
tok_score = [0] * len(toks)
# Hacky way of assigning intermediate score
tok_score[len(tok_score) // 2] = text_score[0]
tok_score[-1] = text_score[1]
tok_scores.append(tok_score)
return tok_scores

# Take a few words off of movie reviews as prompts
imdb = load_dataset("imdb", split="train+test")
prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

trlx.train(
reward_fn=dense_reward_fn,
prompts=prompts,
eval_prompts=["I don't know much about Hungarian underground"] * 256,
config=config,
)


if __name__ == "__main__":
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
main(hparams)
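
As a usage note, the example reads optional hyperparameter overrides as a JSON string from the command line; the sketch below is one hedged way to shorten the run for a smoke test (the nested override shape is an assumption about what `TRLConfig.update` accepts).

```python
# Sketch: a quick smoke-test run of the new example. main() merges the JSON
# hparams into default_ppo_config(), so a nested override such as the one below
# (an assumption about the shapes TRLConfig.update accepts) shortens the run.
# Shell equivalent: python examples/ppo_redemption.py '{"train": {"total_steps": 100}}'
import json
import subprocess

overrides = {"train": {"total_steps": 100}}
subprocess.run(
    ["python", "examples/ppo_redemption.py", json.dumps(overrides)],
    check=True,
)
```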
2 changes: 2 additions & 0 deletions trlx/data/default_configs.py
@@ -49,11 +49,13 @@ def default_ppo_config():
ref_mean=None,
ref_std=None,
cliprange_reward=10,
num_train_sequences=1,
gen_kwargs=dict(
max_new_tokens=40,
top_k=0,
top_p=1.0,
do_sample=True,
num_return_sequences=1,
),
),
)
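
For context, a hedged sketch of how the two knobs added above are expected to combine: `num_return_sequences` sets how many candidates are generated per prompt, while `num_train_sequences` (per the Best-of-N commits in this PR) appears to select the top-scoring candidates kept for training.

```python
# Sketch using the fields added in this diff; the top-k reading of
# num_train_sequences is an assumption drawn from the BoN commit messages.
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
config.method.gen_kwargs["num_return_sequences"] = 4  # generate 4 candidates per prompt
config.method.num_train_sequences = 2                 # assumed: keep the top 2 by reward for training
```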
5 changes: 5 additions & 0 deletions trlx/models/modeling_ppo.py
@@ -112,6 +112,9 @@ class PPOConfig(MethodConfig):

:param gen_experience_kwargs: if this is not None, then the experience is generated using this
:type gen_experience_kwargs: Dict[str, Any]

:param num_train_sequences: top-k of the n sampled sequences from each prompt to use for training
:type num_train_sequences: int
"""

ppo_epochs: int
@@ -131,6 +134,8 @@ class PPOConfig(MethodConfig):
cliprange_reward: float
gen_kwargs: dict
gen_experience_kwargs: Optional[dict] = None
num_train_sequences: int = 1
dist_ref_model: bool = False

def get_advantages_and_returns(
self,
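
A note on the reward layout implied by the new `dist_ref_model` flag: the `evaluate()` changes further down strip KL terms with `r[0]`, which suggests that with `dist_ref_model` enabled each sample's reward is a per-token list whose entries carry the reward in position 0. The sketch below illustrates that assumed layout.

```python
# Assumed reward layout when dist_ref_model is enabled, inferred from the
# "Remove KL terms from the reward" block in evaluate(): each per-token entry is
# a pair whose first element is the reward proper and whose second is the KL term.
per_token_rewards_with_kl = [
    [(0.0, 0.01), (0.0, 0.02), (1.0, 0.00)],  # sample 1: reward on the final token
    [(0.0, 0.03), (0.5, 0.01)],               # sample 2
]
rewards = [[r[0] for r in reward] for reward in per_token_rewards_with_kl]
totals = [sum(reward) for reward in rewards]  # -> [1.0, 0.5]
```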
160 changes: 124 additions & 36 deletions trlx/trainer/accelerate_base_trainer.py
@@ -4,6 +4,7 @@
import sys
from abc import abstractmethod
from contextlib import contextmanager
from copy import copy
from time import time
from typing import Dict, List, Optional, Tuple

@@ -201,81 +202,124 @@ def decode(
prompts: List[torch.LongTensor],
samples: List[torch.LongTensor],
prompt_sizes: torch.LongTensor = None,
append_eos_token: bool = False,
) -> Tuple[List[str], List[str], List[str]]:
append_eos_token: bool = True,
) -> Tuple[List[str], List[str], List[str], List[torch.LongTensor], List[torch.LongTensor], List[torch.LongTensor]]:
"""
Decode tensor generations into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str])
Decode tensor generations with stopping criteria into lists of strings
(`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) and the corresponding
token tensors (`tok_samples`, `tok_prompts`, `tok_outputs`). Note that prompts, as well as samples, may sometimes be right-padded.
"""
if prompt_sizes is None:
# Assuming prompts were left-padded
prompt_sizes = [prompts.shape[1]] * len(prompts)

str_samples, str_prompts, str_outputs = [], [], []
tok_samples, tok_prompts, tok_outputs = [], [], []
for prompt, sample, prompt_size in zip(prompts, samples, prompt_sizes):
if self.config.model.model_arch_type == "seq2seq":
output_start_ix = 0
else:
output_start_ix = prompt_size

# We must decode by skipping padding in the middle with skip_special_tokens
str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True)
# Return the prompt tensor (with the exact padding) used to generate the sample
tok_prompt = prompt[:prompt_size].cpu()

str_output = self.tokenizer.decode(sample[output_start_ix:], skip_special_tokens=True)
# Trim outputs up to `self.stop_sequences` if any are present
trimmed = False
if self.stop_sequences:
for stop in self.stop_sequences:
stop_ix = str_output.find(stop)
if stop_ix >= 0:
str_output = str_output[:stop_ix].rstrip()
trimmed = True

# Recover sequence of tokens corresponding to string
# NOTE: Cast to torch.long in case the input is empty
tok_output = self.tokenizer(str_output, return_tensors="pt").input_ids[0].long()
# Remove bos from tokenized output (if present)
if (
hasattr(self.tokenizer, "bos_token")
and len(tok_output) > 0
and tok_output[0].item() == self.tokenizer.bos_token_id
):
tok_output = tok_output[1:]

# Recover the last <eos> if it was present in the original sample
# or add one if it was trimmed with `self.stop_sequences`.
# Only when generation ended due to `max_new_tokens` exhaustion will the <pad> or
# <eos> token be missing from the end of the original sample
if append_eos_token and (
trimmed or sample[-1] == self.tokenizer.eos_token_id or sample[-1] == self.tokenizer.pad_token_id
):
# Note we append to both separately because encoding </s> does not result in tokenizer.eos_token_id
if append_eos_token:
str_output += self.tokenizer.eos_token
tok_output = torch.cat(
[tok_output, torch.tensor(self.tokenizer.eos_token_id, dtype=torch.long).view((1,))]
)

str_prompts.append(str_prompt)
str_outputs.append(str_output)
tok_prompts.append(tok_prompt)
tok_outputs.append(tok_output)

if self.config.model.model_arch_type == "seq2seq":
sample = str_prompt + self.tokenizer.sep_token + str_output
tok_sample = torch.cat(
[tok_prompt, torch.tensor(self.tokenizer.sep_token_id, dtype=torch.long).view((1,)), tok_output]
)
else:
sample = str_prompt + str_output
tok_sample = torch.cat([tok_prompt, tok_output])

str_samples.append(sample)
tok_samples.append(tok_sample)

return str_samples, str_prompts, str_outputs
return str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs

def generate(self, input_ids, attention_mask=None, **kwargs):
def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs):
"""Wraps hf's `generate` adding some specific method's defaults"""
# Divide into chunks and generate samples
input_ids = input_ids.to(self.accelerator.device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.accelerator.device)
if self.generate_experience_kwargs is not None:
kwargs = dict(self.generate_experience_kwargs, **kwargs)
else:
kwargs = dict(self.generate_kwargs, **kwargs)

with torch.no_grad():
return self.accelerator.unwrap_model(self.model).generate(
input_ids=input_ids, attention_mask=attention_mask, **kwargs
generate_kwargs = copy(self.generate_kwargs)
generate_kwargs.update(kwargs)

# Cap max_new_tokens so the total sequence length respects self.max_length
prompt_length = input_ids.shape[1]
if generate_kwargs.get("max_new_tokens") is not None:
generate_kwargs["max_new_tokens"] = min(
max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"]
)
else:
generate_kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0)

def generate_eval(self, input_ids, attention_mask=None, **kwargs):
"""Wraps hf's `generate` adding some specific method's defaults"""
input_ids = input_ids.to(self.accelerator.device)
# Repeat prompts, attention_masks for chunking if returning multiple sequences
if generate_kwargs.get("num_return_sequences") is None:
generate_kwargs["num_return_sequences"] = 1

num_return_sequences = generate_kwargs.pop("num_return_sequences") # Pop to hide from model.generate call
input_ids = input_ids.repeat_interleave(num_return_sequences, dim=0)
if attention_mask is not None:
attention_mask = attention_mask.to(self.accelerator.device)
attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)

kwargs = dict(self.generate_kwargs, **kwargs)
if chunk_size is None:
chunk_size = input_ids.shape[0]

with torch.no_grad():
return self.accelerator.unwrap_model(self.model).generate(
input_ids=input_ids, attention_mask=attention_mask, **kwargs
)
# Chunk input_ids and attention_mask
input_ids = input_ids.split(chunk_size, dim=0)
if attention_mask is not None:
attention_mask = attention_mask.split(chunk_size, dim=0)
samples = []
for chunk_idx in range(len(input_ids)):
with torch.no_grad():
sample = self.accelerator.unwrap_model(self.model).generate(
input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **generate_kwargs
)
samples.append(sample)
# Concat samples
samples = torch.cat(samples, 0)
return samples

def save_pretrained(self, directory: Optional[str] = None, **kwargs):
"""Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for
Expand Down Expand Up @@ -377,11 +421,20 @@ def evaluate(self): # noqa: C901
for i_prompt, prompts in enumerate(self.eval_dataloader):
metadata = {k: v for k, v in prompts.items() if k != "input_ids" and k != "attention_mask"}
if self.generate_sweep_kwarg:
samples = self.generate_eval(
samples = self.generate(
prompts["input_ids"], prompts["attention_mask"], **{gen_sweep_arg: gen_sweep_value}
)
else:
samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"])
chunk_size = self.config.method.chunk_size if hasattr(self.config.method, "chunk_size") else None
samples = self.generate(prompts["input_ids"], prompts["attention_mask"], chunk_size=chunk_size)

# Repeat prompts and metadata num_return_sequences times
num_return_sequences = 1
if self.generate_kwargs.get("num_return_sequences") is not None:
num_return_sequences = self.generate_kwargs["num_return_sequences"]
prompts["input_ids"] = prompts["input_ids"].repeat_interleave(num_return_sequences, dim=0)
prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(num_return_sequences, dim=0)
metadata = {k: self.repeat_interleave(v, num_return_sequences) for k, v in metadata.items()}

# TODO(reciprocated): this should be moved into `decode`
# but that needs to be synced with indexing in `make_experience`
@@ -396,9 +449,9 @@ def evaluate(self): # noqa: C901
pad_index=self.tokenizer.pad_token_id,
)
)
all_samples.extend(samples.tolist())
all_prompts.extend(prompts.tolist())
all_prompt_sizes.extend(prompt_sizes.tolist())
all_samples.extend(samples)
all_prompts.extend(prompts)
all_prompt_sizes.extend(prompt_sizes)

metadata = gather_dict(metadata, self.accelerator.gradient_state)
all_metadata.append(metadata)
@@ -414,7 +467,9 @@ def evaluate(self): # noqa: C901
stats["time/generate"] = time() - generate_time

if self.accelerator.is_main_process:
str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes)
str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs = self.decode(
all_prompts, all_samples, all_prompt_sizes
)

columns = ["prompt", "output"]
columns_data = [str_prompts, str_outputs]
@@ -427,10 +482,25 @@ def evaluate(self): # noqa: C901
# in online setting, compute the reward for validation
if self.reward_fn:
logger.info("Computing rewards")
rewards = torch.tensor(
self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata),
dtype=float,
rewards = self.reward_fn(
samples=str_samples,
prompts=str_prompts,
outputs=str_outputs,
tok_samples=tok_samples,
tok_prompts=tok_prompts,
tok_outputs=tok_outputs,
model_tok=self.tokenizer,
**metadata,
)
# Remove KL terms from the reward
if hasattr(self, "dist_ref_model") and self.dist_ref_model:
rewards = [[r[0] for r in reward] for reward in rewards]
if type(rewards[0]) is torch.Tensor:
rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float)
elif type(rewards[0]) is list:
rewards = torch.tensor([sum(reward) for reward in rewards])
else:
rewards = torch.tensor(rewards)
mean_reward = rewards.mean().item()
columns.append("reward")
if not isinstance(rewards, list):
@@ -442,7 +512,16 @@ def evaluate(self): # noqa: C901
if self.metric_fn:
logger.info("Computing metrics")
metric_time = time()
metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata)
metrics = self.metric_fn(
samples=str_samples,
prompts=str_prompts,
outputs=str_outputs,
tok_samples=tok_samples,
tok_prompts=tok_prompts,
tok_outputs=tok_outputs,
model_tok=self.tokenizer,
**metadata,
)
stats["time/metric"] = time() - metric_time

mean_metrics = {
@@ -633,6 +712,15 @@ def learn(self): # noqa: C901
self.post_epoch_callback()
tbar.close()

@staticmethod
def repeat_interleave(l, n):
if type(l) is torch.Tensor:
l = l.repeat_interleave(n, dim=0)
elif type(l) is list:
l = [[s] * n for s in l]
l = [item for sublist in l for item in sublist]
return l

@abstractmethod
def get_arch(self, config: TRLConfig):
"""Returns a specific wrapper of the decoder architecture"""
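
Finally, a hedged sketch of a reward function compatible with the extended call sites above: `evaluate()` now forwards `tok_samples`, `tok_prompts`, `tok_outputs`, and the tokenizer as `model_tok` alongside the usual string arguments, and a dense reward function returns one score per response token.

```python
# Minimal sketch matching the keyword arguments evaluate() now passes; the
# scoring itself is a placeholder (all reward on the final response token).
from typing import List


def dense_reward_fn(
    samples: List[str],
    prompts: List[str],
    outputs: List[str],
    tok_samples=None,
    tok_prompts=None,
    tok_outputs=None,
    model_tok=None,
    **kwargs,
) -> List[List[float]]:
    rewards = []
    for toks in tok_outputs:
        per_token = [0.0] * len(toks)
        if per_token:
            per_token[-1] = 1.0  # placeholder terminal reward
        rewards.append(per_token)
    return rewards
```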