* fix(ilql): sampling on variable sized prompts & stage simplified api
* Save strategy (#23)
* Had to add py_modules=trlx to setup.
* Added a save strategy.
* Cleaned up a few things.
* Added save_steps to ilql_config.yaml and a save steps strategy to accelerate_ilql_model.py for consistency. The save_steps parameter must be set now because of how TrainConfig.from_dict operates. If no save_steps parameter is given in the configs it throws an error.
* Adding minimal changes to enable a step based save strategy in configs/ppo_config.yml, trlx/data/configs.py, and trlx/model_accelerate_ppo_model.py
* Some problems crept in despite the merge check. This fixes them.
* Realized I am merging into stage-api not main, so fixed an issue with ilql_config.yml
* fix(ilql): eval on a set of betas & add simple timers
* fix: saving checkpoints
* refactor(ilql): subsume under base_model
* fix(ilql): mask prompts
* merge hydra
* fix(ppo): generalize and stage for api
* feat: add architext examples
* fix(ppo,ilql): ddp + accelerate
* refactor: clean pipelines
* feat: add simulacra example
* fix(ppo): single token prompting
* refactor: fully merge models
* refactor(configs): lower batch_sizes & remove dead entries
* refactor(examples): update for new api
* fix(tests,style): one way to pass tests is to change them
* fix(ppo): warnings from transformers 4.23.1, which complains if .generate() starts with a single bos token when bos=eos=pad token (see the sketch below)
* refactor(readme): add api
* chore: add doc strings
* fix: remove dropout
* chore: keep gpt2 small in examples
* chore: revert to previous default configs
* chore(docs): rename classes, remove unused, add examples
* chore(readme): add contributing.md & deepspeed note
* style(readme): US spelling
* chore(examples): add explanations for each task
1 parent 4ff712b · commit 06cd30f · 30 changed files with 1,088 additions and 1,253 deletions.
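One of the fixes above concerns .generate() warnings when GPT-2's EOS token doubles as both BOS and PAD. The snippet below is only an illustrative sketch of that situation, not code from this commit: with `pad_token = eos_token`, a prompt consisting of a single BOS token is indistinguishable from padding unless an explicit attention mask is supplied.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships no pad token; bos == eos == pad
model = AutoModelForCausalLM.from_pretrained("gpt2")

# A prompt that is just the single BOS token, as used for unconditional sampling.
prompt = torch.tensor([[tokenizer.bos_token_id]])

# Recent transformers versions warn here because the prompt token equals the pad
# token; passing an explicit attention mask (and pad_token_id) avoids the ambiguity.
output = model.generate(
    prompt,
    attention_mask=torch.ones_like(prompt),
    pad_token_id=tokenizer.eos_token_id,
    max_length=20,
    do_sample=True,
)
print(tokenizer.decode(output[0]))
```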
README.md
@@ -1,152 +1,45 @@
[docs-image]: https://readthedocs.org/projects/trlX/badge/?version=latest
[docs-url]: https://trlX.readthedocs.io/en/latest/?badge=latest

# Welcome to Transformer Reinforcement Learning X (`trlX`)
> A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
# Transformer Reinforcement Learning X

[![Docs Status][docs-image]][docs-url]
`trlx` allows you to fine-tune 🤗 Hugging Face supported language models (`gpt2`, `gpt-j`, `gpt-neo` and `gpt-neox` based) up to 20B parameters using reinforcement learning via either a provided reward function or a reward-labeled dataset. Proximal Policy Optimization ([PPO](https://arxiv.org/pdf/1909.08593.pdf)) and Implicit Language Q-Learning ([ILQL](https://sea-snell.github.io/ILQL_site/)) are implemented.

**[Documentation](https://trlX.readthedocs.io)**
## Train

## Overview
Inspired by the popular `trl` library, the `trlX` repo allows you to fine-tune Hugging Face supported language models up to 20B parameters via reinforcement learning, using either a provided scoring function or a reward-labeled dataset. We aim to support a range of both online and offline RL algorithms, including Proximal Policy Optimization (PPO), Natural Language Policy Optimization (NLPO), Advantage Actor-Critic (A2C), and Implicit Q-Learning (ILQL).

The library supports `gpt2` and `gptj` with plans to include `GPT-NeoX`, `T5` and more. PPO and ILQL algorithms are implemented. Distributed training has been implemented via HF Accelerate and tested up to two nodes, each with 8 GPUs.

## Structure

The training pipeline is broken into four pieces:
```python
import trlx

- Prompt pipeline: Handles loading of prompts/text used to prompt the model for exploration in online methods
- Rollout pipeline: Handles loading and storage of reward-labeled data
- Orchestrator: Handles exploration/rollout collection for online methods. Pushes collected rollouts to the rollout pipeline.
- Model: Wraps the supplied base model (e.g. `gpt2`) and implements the desired training method loss (e.g. PPO).
# optimize some reward function
model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])

Adding a task for RLHF training depends on the desired training method and pre-existing data. If we are online and have no reward-labeled data, this is as simple as writing a new prompt pipeline, which supplies prompts for exploration, and a new reward function to be passed into the `PPOOrchestrator` class.
# or steer a model with a collection of rated samples
model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])

## Installation
```bash
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install -e ".[dev]"
pre-commit install # see .pre-commit-config.yaml
# model is a wrapper with some logit preprocessing
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
```

## Example: How to add a task

Below, we implement a sentiment learning task.

### Configure `accelerate`
Launch distributed training with 🤗 Accelerate (only DeepSpeed integration is tested)

```bash
accelerate config
accelerate launch examples/simulacra.py
```

### Implement a prompt pipeline

```python
@register_datapipeline
class PPOPipeline(BasePipeline):
    def __init__(self, tokenizer, config, prompt_dataset_path=None):
        super().__init__()

        ds = load_dataset("imdb", split="test")
        ds = ds.rename_columns({"text": "review", "label": "sentiment"})
        ds = ds.filter(lambda x: len(x["review"]) < 500, batched=False)

        self.tokens = [
            tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=config.train.input_size,
                return_tensors="pt",
            )["input_ids"]
            .long()
            .flatten()
            for text in ds["review"]
        ]
        self.text = [tokenizer.decode(tokens.tolist()) for tokens in self.tokens]

    def __getitem__(self, index: int) -> PromptElement:
        return PromptElement(self.text[index], self.tokens[index])

    def __len__(self) -> int:
        return len(self.text)

    def create_loader(
        self,
        batch_size: int,
        shuffle: bool,
        prep_fn: Callable = None,
        num_workers: int = 0,
    ) -> DataLoader:
        # TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on the fly
        def collate_fn(elems: Iterable[PromptElement]) -> PromptElement:
            return PromptBatch(
                [elem.text for elem in elems],
                torch.stack(
                    [elem.tokens for elem in elems]
                ),  # Assumes token tensors are all the same size
            )
For more usage see [examples](./examples)

        return DataLoader(
            self, batch_size, shuffle, collate_fn=collate_fn, num_workers=num_workers
        )
```

### Launch training

```python
from typing import List

import torch
from transformers import pipeline

import wandb
from trlx.data.configs import TRLConfig
from trlx.model.accelerate_ppo_model import AcceleratePPOModel
from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator
from trlx.pipeline.ppo_pipeline import PPOPipeline
from trlx.utils.loading import get_model, get_orchestrator, get_pipeline

if __name__ == "__main__":
    cfg = TRLConfig.load_yaml("configs/ppo_config.yml")

    sentiment_pipe = pipeline(
        "sentiment-analysis", "lvwerra/distilbert-imdb", device=-1
    )

    def reward_fn(samples: List[str]):
        sent_kwargs = {
            "return_all_scores": True,
            "function_to_apply": None,
            "batch_size": cfg.method.chunk_size,
        }
        pipe_outputs = sentiment_pipe(samples, **sent_kwargs)
        scores = torch.tensor([output[1]["score"] for output in pipe_outputs])
        return scores

    model: AcceleratePPOModel = get_model(cfg.model.model_type)(cfg)
    if model.accelerator.is_main_process:
        wandb.watch(model.model)

    pipeline: PPOPipeline = get_pipeline(cfg.train.pipeline)(model.tokenizer, cfg)
    orch: PPOOrchestrator = get_orchestrator(cfg.train.orchestrator)(
        model, pipeline, reward_fn=reward_fn, chunk_size=cfg.method.chunk_size
    )
    orch.make_experience(cfg.method.num_rollouts)
    model.learn()

    print("DONE!")
## Install
```bash
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113 # for cuda
pip install -e .
```

And run `accelerate launch my_script.py`

## References
For development check out these [guidelines](./CONTRIBUTING.md)
and also read our [docs](https://trlX.readthedocs.io)

### Proximal Policy Optimisation
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
## Acknowledgements

### Language models
The language models utilize the `transformers` library by 🤗 Hugging Face.
Thanks Leandro for starting the original [trl](https://github.com/lvwerra/trl/)
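Under the new API shown in the README diff above, the manual pipeline/orchestrator wiring from the removed example collapses into a single `trlx.train` call. Below is a rough, hypothetical sketch of the same sentiment task using that API; only `trlx.train(model_name, reward_fn=...)` and the sentiment-pipeline scoring are taken from the diff, and any further keyword arguments are assumptions.

```python
# Hypothetical sketch: the old PPO sentiment example rewritten against the
# simplified trlx.train API shown in the README diff above.
from typing import List

from transformers import pipeline

import trlx

# Same reward model the removed example used (runs on CPU with device=-1).
sentiment_pipe = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)

def reward_fn(samples: List[str]) -> List[float]:
    # Score each generated sample by the probability of the positive class.
    outputs = sentiment_pipe(samples, return_all_scores=True, function_to_apply=None)
    return [output[1]["score"] for output in outputs]

model = trlx.train("lvwerra/gpt2-imdb", reward_fn=reward_fn)
```

As with the scripts above, such a file would be launched with `accelerate launch`.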
configs/ilql_config.yml
@@ -1,42 +1,36 @@
model:
  model_path : "gpt2"
  model_type : "ILQLModel"
  device : "cuda"
  tokenizer_path: "gpt2"
  model_type : "ILQLModel"
  num_layers_unfrozen: -1

train:
  n_ctx : 512
  epochs : 1
  total_steps : 80000
  batch_size : 80
  grad_clip : 1.0
  seq_length: 64
  batch_size: 128
  epochs: 10
  total_steps: 10000

  lr_ramp_steps : 100
  lr_decay_steps : 3366
  weight_decay : 1.0e-6
  learning_rate_init : 1.0e-3
  learning_rate_target : 1.0e-3
  lr_ramp_steps: 100
  lr_decay_steps: 3366
  weight_decay: 1e-6
  learning_rate_init: 1e-4
  learning_rate_target: 1e-4
  opt_betas: [0.9, 0.95]

  log_interval : 25
  checkpoint_interval : 100
  eval_interval : 50

  input_size: 1
  gen_size: 32
  checkpoint_interval: 1000
  eval_interval: 16

  pipeline : "OfflinePipeline"
  orchestrator : "OfflineOrchestrator"

  accelerate : true
  seed: 1000

method:
  name: "ilqlconfig"
  tau: 0.7
  gamma: 0.99
  cql_scale: 0.1
  awac_scale: 1
  alpha: 1
  steps_for_target_q_sync: 10
  beta: 4
  alpha: 0.005
  steps_for_target_q_sync: 1
  betas: [16]
  two_qs: true
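Configs like the one above are loaded through `TRLConfig.load_yaml`, as in the PPO example script earlier in this diff. The snippet below is a minimal sketch of reading a few of the ILQL hyperparameters; the attribute-style access mirrors `cfg.method.chunk_size` from that script, but the exact ILQL field names are an assumption.

```python
# Sketch only: assumes configs/ilql_config.yml exists in the working directory
# and that train/method fields are exposed as attributes, as the PPO script suggests.
from trlx.data.configs import TRLConfig

config = TRLConfig.load_yaml("configs/ilql_config.yml")

print(config.train.batch_size)                # 128 in the new config
print(config.method.tau)                      # expectile used for the value loss, 0.7
print(config.method.betas)                    # inverse temperature(s) used when sampling at eval time, [16]
print(config.method.steps_for_target_q_sync)  # 1: sync the target Q network every step
```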
configs/ppo_config.yml
@@ -1,52 +1,44 @@
model:
  model_path : "lvwerra/gpt2-imdb" # Name of hf model to load
  tokenizer_path : "gpt2" # Name of hf tokenizer to load
  model_type : "AcceleratePPOModel" # Name of accelerate model type to load
  device : "cuda" # Train device
  num_layers_unfrozen : 2 # Number of bottom layers to freeze during training
  model_path: "lvwerra/gpt2-imdb" # Name of hf model to load
  tokenizer_path: "gpt2" # Name of hf tokenizer to load
  model_type: "AcceleratePPOModel" # Name of accelerate model type to load
  num_layers_unfrozen: 2 # Number of bottom layers to freeze during training

train:
  n_ctx : 512 # Size of LM context
  epochs : 10 # Train for max(epochs, total_steps)
  total_steps : 80000 # Train for max(epochs, total_steps)
  batch_size : 128 # batch size
  grad_clip : 1.0 # gradient clipping threshold
  seq_length: 48 # Size of LM context
  epochs: 1000 # Train for max(epochs, total_steps)
  total_steps: 10000 # Train for max(epochs, total_steps)
  batch_size: 128 # batch size

  lr_ramp_steps : 100 # learning rate warm up
  lr_decay_steps : 79000 # learning rate decay
  weight_decay : 1.0e-6 # weight decay param
  learning_rate_init : 1.412e-4 # init learning rate
  learning_rate_target : 1.412e-4 # target final learning rate
  lr_ramp_steps: 100 # learning rate warm up
  lr_decay_steps: 79000 # learning rate decay
  weight_decay: 1.0e-6 # weight decay param
  learning_rate_init: 1.412e-4 # init learning rate
  learning_rate_target: 1.412e-4 # target final learning rate
  opt_betas: [0.9, 0.95] # adam betas

  log_interval : 25 # log interval
  checkpoint_interval : 1000000 # checkpoint interval
  eval_interval : 16 # eval interval
  checkpoint_interval: 10000 # checkpoint interval
  eval_interval: 16 # eval interval

  pipeline : "PPOPipeline" # prompt pipeline to load
  orchestrator : "PPOOrchestrator" # orchestrator to load

  input_size : 4 # max input size
  gen_size : 48 # max gen size

  accelerate : True # Use accelerate
  accelerate_config_path : "" # Path to accelerate config (for logging purposes)
  pipeline: "PPOPipeline" # prompt pipeline to load
  orchestrator: "PPOOrchestrator" # orchestrator to load

method:
  name : 'ppoconfig' # Name of RL method config
  num_rollouts : 128 # Number of rollouts to collect per epoch
  chunk_size : 128 # Number of rollouts to collect in one loop of orchestrator
  ppo_epochs : 4 # Number of ppo epochs
  init_kl_coef : 0.2 # init kl coefficient
  target : 6 # target kl coefficient, set None for fixed kl coef
  horizon : 10000 # PPO horizon
  gamma : 1 # PPO discount
  lam : 0.95 # PPO lambda
  cliprange : 0.2 # clip range
  cliprange_value : 0.2 # clip range
  vf_coef : 2.3 # value term weight
  gen_kwargs :
    max_length : 48 # LM max sample gen length
    min_length : 48 # LM min sample gen length
    top_k : 0.0 # top k
    top_p : 1.0 # top p
    do_sample : True # sample
  name: 'ppoconfig' # Name of RL method config
  num_rollouts: 128 # Number of rollouts to collect per epoch
  chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
  ppo_epochs: 4 # Number of ppo epochs
  init_kl_coef: 0.2 # init kl coefficient
  target: 6 # target kl coefficient, set None for fixed kl coef
  horizon: 10000 # PPO horizon
  gamma: 1 # PPO discount
  lam: 0.95 # PPO lambda
  cliprange: 0.2 # clip range
  cliprange_value: 0.2 # clip range
  vf_coef: 2.3 # value term weight
  gen_kwargs:
    max_length: 48 # LM max sample gen length
    min_length: 48 # LM min sample gen length
    top_k: 0.0 # top k
    top_p: 1.0 # top p
    do_sample: True # sample
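In this config, `init_kl_coef`, `target`, and `horizon` parameterize the adaptive KL controller from Ziegler et al., the paper referenced in the README above: the KL penalty coefficient is nudged up or down so the observed KL between the fine-tuned policy and the original model tracks the target. Below is a sketch of that update rule written from the paper's released code, not necessarily trlx's exact implementation.

```python
# Adaptive KL coefficient update following Ziegler et al. (2019, lm-human-preferences);
# a paraphrase for illustration, not a copy of trlx's own controller.
class AdaptiveKLController:
    def __init__(self, init_kl_coef: float = 0.2, target: float = 6.0, horizon: int = 10000):
        self.kl_coef = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, observed_kl: float, n_steps: int) -> float:
        # Proportional controller: grow the coefficient when KL overshoots the
        # target, shrink it when KL undershoots, clipped for stability.
        error = max(min(observed_kl / self.target - 1.0, 0.2), -0.2)
        self.kl_coef *= 1.0 + error * n_steps / self.horizon
        return self.kl_coef


controller = AdaptiveKLController()  # defaults mirror the PPO config above
print(controller.update(observed_kl=8.0, n_steps=128))  # KL above target -> coefficient increases
```

Setting `target` to None, as the comment in the config notes, would correspond to keeping the coefficient fixed at `init_kl_coef` instead of running this controller.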