Implement BoN for training and eval #528
Open
Dahoas wants to merge 40 commits into main from BoN
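For context: Best-of-N (BoN) sampling draws n candidate completions per prompt from the policy, scores them with the reward function, and keeps only the highest-scoring candidates for evaluation or further training. Below is a minimal, self-contained sketch of that selection step, assuming a plain Hugging Face causal LM; the best_of_n helper and the gpt2 checkpoint are illustrative choices rather than the trlx implementation, though num_return_sequences mirrors the parameter this PR adds.

from typing import Callable, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def best_of_n(prompt: str, reward_fn: Callable[[List[str]], List[float]], n: int = 4) -> str:
    # Illustrative model choice; any causal LM works the same way
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer(prompt, return_tensors="pt")
    # Draw n samples for the same prompt in a single generate call
    outputs = model.generate(
        **inputs,
        do_sample=True,
        max_new_tokens=40,
        num_return_sequences=n,
        pad_token_id=tokenizer.eos_token_id,
    )
    candidates = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Score every candidate and keep the one with the highest reward
    scores = torch.tensor(reward_fn(candidates))
    return candidates[int(scores.argmax())]

The commit messages below ("TopK sampling instead of Top1", "Rename 'num_train_sequences' to 'num_topk_samples'") suggest the PR keeps the top-k samples per prompt rather than only the single best, in which case the argmax above would become a torch.topk selection.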
Commits
f04446d  Implementing support for dense rewards (Dahoas)
13a01fc  added "num_return_sequences" param which corresponds to n in Best-of-… (SharathRaparthy)
5421a73  updates to "num_return_sequences" param (SharathRaparthy)
2f3ac28  BoN implementation (SharathRaparthy)
2f1dace  Changed back to default. (SharathRaparthy)
f58170d  TopK sampling instead of Top1 (SharathRaparthy)
be8bc1a  summed along dim=1 (SharathRaparthy)
608d812  Generating samples in chunks (SharathRaparthy)
d8557e7  added gen_chunk_size parameter (SharathRaparthy)
8ef9c36  chunking in forward prop (SharathRaparthy)
4c1d82d  chunking generations in train and eval (SharathRaparthy)
ecd5107  Implementing support for dense rewards (Dahoas)
4071604  Fix distributed ref_mean, ref_var bug for dense rewards (Dahoas)
5f41413  Make generation respect max seq length (Dahoas)
22ae83f  Make experience before first round of training (Dahoas)
7d0a4be  Refactoring .generate/.generate_eval (Dahoas)
b79dd19  Fix BoN metric support (Dahoas)
cb49dc5  Enforce chunk_size param for eval generation when present (Dahoas)
e290412  Fix: Don't shuffle prompt dataset (Dahoas)
391d04c  Move inputs to device (Dahoas)
8de84e4  Fix style (Dahoas)
3d7e0d5  Fix chunked generation (Dahoas)
1fda0ce  fix(accelerate_base_trainer): order of keyword arguments (maxreciprocate)
4ac1707  Merging main (Dahoas)
de3d854  Merge branch 'BoN' of https://github.com/CarperAI/trlx into BoN (Dahoas)
3ce3c2b  Removing old example (Dahoas)
2635de5  Fix: remove extraneous method args (Dahoas)
1be2c3c  Fix: Always set generate_experience_kwargs (Dahoas)
3cba0db  Fix: Remove mask from RunningMoments update call (Dahoas)
0cb91c4  Fix: style (Dahoas)
cc92911  Fix: rename 'gen_chunk_size' to 'chunk_size' (Dahoas)
4297f98  Fix: generated samples padding (Dahoas)
36f06af  Remove prints (Dahoas)
a2980dd  Rename 'num_train_sequences' to 'num_topk_samples' (Dahoas)
3d5a639  Address nits (Dahoas)
87837b6  Fix: style (Dahoas)
ed93be8  Set 'num_return_sequences' to 1 by default (Dahoas)
24925c8  Fix: typo (Dahoas)
a022d3f  Merge branch 'main' into BoN (maxreciprocate)
9680c9f  Merge branch 'main' into bon-x (maxreciprocate)
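Several commits above ("Generating samples in chunks", "added gen_chunk_size parameter", later "Fix: rename 'gen_chunk_size' to 'chunk_size'") add chunked generation: with Best-of-N, every prompt expands into num_return_sequences candidates, so generating for all prompts at once multiplies peak memory. A minimal sketch of that idea, assuming a plain Hugging Face causal LM; the generate_in_chunks helper and gpt2 checkpoint are illustrative and not the trlx API:

from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_in_chunks(prompts: List[str], chunk_size: int = 8, num_return_sequences: int = 4) -> List[str]:
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # left-pad prompts for decoder-only generation
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    samples: List[str] = []
    # Process prompts chunk_size at a time so that only
    # chunk_size * num_return_sequences sequences are in memory during generation
    for i in range(0, len(prompts), chunk_size):
        chunk = prompts[i : i + chunk_size]
        inputs = tokenizer(chunk, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                do_sample=True,
                max_new_tokens=40,
                num_return_sequences=num_return_sequences,
                pad_token_id=tokenizer.eos_token_id,
            )
        samples.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    return samples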
@@ -0,0 +1,82 @@
# Generates positive movie reviews by tuning a pretrained model on IMDB dataset
# with a sentiment reward function
import json
import os
import sys
from typing import List

import torch
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config


def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def get_negative_score(scores):
    return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"]


def main(hparams={}):
    # Merge sweep config with default config if given
    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
    config.method.cliprange_reward = False
    config.method.gen_kwargs["max_new_tokens"] = 70
    config.method.gen_kwargs["temperature"] = 0.3
    config.train.total_steps = 20000
    config.train.checkpoint_interval = 10000000
    # config.method.init_kl_coef = 0

    if torch.cuda.is_available():
        device = int(os.environ.get("LOCAL_RANK", 0))
    else:
        device = -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], model_tok, **kwargs) -> List[List[float]]:
        # Reward positively for an initially negative, then positive review
        # Reward functions should never receive padded text except for a single EOS at the end
        # The reward function should return token-level rewards for just the response
        first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples]
        negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves)))
        second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples]
        positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves)))
        text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)]
        tok_scores = []
        for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores):
            toks = model_tok(response).input_ids
            tok_score = [0] * len(toks)
            # Hacky way of assigning intermediate score
            tok_score[len(tok_score) // 2] = text_score[0]
            tok_score[-1] = text_score[1]
            tok_scores.append(tok_score)
        return tok_scores

    # Take a few words off of movie reviews as prompts
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    trlx.train(
        reward_fn=dense_reward_fn,
        prompts=prompts,
        eval_prompts=["I don't know much about Hungarian underground"] * 256,
        config=config,
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
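The example above returns a list of per-token rewards for each response rather than a single scalar. A common way such dense rewards are consumed in PPO-style RLHF is to add them to a per-token KL penalty against the reference model; the sketch below shows that general formulation as an assumption for illustration, not the exact trlx code path this PR touches.

import torch


def combine_dense_rewards(tok_scores: torch.Tensor, logprobs: torch.Tensor, ref_logprobs: torch.Tensor, kl_coef: float = 0.05) -> torch.Tensor:
    # All tensors have shape [batch, response_len]: tok_scores holds the reward
    # function's per-token scores, logprobs/ref_logprobs the policy and reference
    # log-probabilities of the sampled response tokens
    kl_penalty = -kl_coef * (logprobs - ref_logprobs)
    return tok_scores + kl_penalty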
Review comment: I think this example got here by inertia from the previous PR.