Restructure sweeps for reuse (#102)
* chore(readme): update instructions

* refactor(sweep): reuse existing examples and configs

* fix(sweep): enable checkpointing for hyperband

* feat(sweep): add accelerate support

* fix(sweep): report with new params space

* feat(sweep): replace generic names

* chore(ppo_config): update to better values

* chore(sweep): set max_concurrent_trials to default

* chore(examples): update the rest of the examples to the new main signature

* chore(readme): update sweep instruction

* chore(sweep): add warning/confirmation check before importing

* chore(sweep): update sweep instruction

* update(config): switch to more stable values
maxreciprocate authored Nov 21, 2022
1 parent 3db86ca commit ff0d077
Showing 17 changed files with 292 additions and 225 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -41,7 +41,7 @@ accelerate launch examples/simulacra.py

#### Use Ray Tune to launch hyperparameter sweep
```bash
python train_sweep.py --config configs/ray_tune_configs/ppo_config.yml --example-name ppo_sentiments
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
```

For more usage see [examples](./examples)
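The sweep entry point drives any example script that exposes a `main(hparams)` function and merges the sampled hyperparameters into that example's default config, as the example diffs below show. A minimal sketch of that contract (the model name, prompts, and reward function here are placeholders, not part of the repository):

```python
# Sketch of the example-script contract assumed by `python -m trlx.sweep`,
# mirroring the restructured examples in this commit. Model, prompts, and
# reward function are placeholders.
import yaml

import trlx
from trlx.data.configs import TRLConfig

default_config = yaml.safe_load(open("configs/ppo_config.yml"))


def main(hparams={}):
    # Merge sweep-sampled hyperparameters into the default config.
    config = TRLConfig.update(default_config, hparams)

    trlx.train(
        "gpt2",  # placeholder model name
        reward_fn=lambda samples: [float(len(s)) for s in samples],  # placeholder reward
        prompts=["dummy prompt"] * 8,  # placeholder prompts
        config=config,
    )


if __name__ == "__main__":
    main()
```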
14 changes: 7 additions & 7 deletions configs/ilql_config.yml
@@ -8,16 +8,16 @@ train:
seq_length: 64
batch_size: 128
epochs: 100
total_steps: 10000
total_steps: 1000

lr_init: 1.0e-4
lr_target: 1.0e-4
lr_init: 5.0e-5
lr_target: 5.0e-5
opt_betas: [0.9, 0.95]
opt_eps: 1.0e-8
weight_decay: 1.0e-6

checkpoint_interval: 1000
eval_interval: 128
eval_interval: 100

pipeline: "PromptPipeline"
orchestrator: "OfflineOrchestrator"
@@ -29,7 +29,7 @@ method:
gamma: 0.99
cql_scale: 0.1
awac_scale: 1
alpha: 0.005
steps_for_target_q_sync: 1
betas: [16]
alpha: 0.001
steps_for_target_q_sync: 5
betas: [4]
two_qs: true
12 changes: 6 additions & 6 deletions configs/ppo_config.yml
@@ -10,14 +10,14 @@ train:
total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 128 # batch size

lr_init: 1.412e-4 # init learning rate
lr_target: 1.412e-4 # target final learning rate
lr_init: 1.0e-4 # init learning rate
lr_target: 1.0e-4 # target final learning rate
opt_betas: [0.9, 0.95] # adam betas
opt_eps: 1.0e-8 # adam eps
weight_decay: 1.0e-6 # weight decay param

checkpoint_interval: 10000 # checkpoint interval
eval_interval: 16 # eval interval
eval_interval: 100 # eval interval

pipeline: "PromptPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load
@@ -28,15 +28,15 @@ method:
num_rollouts: 128 # Number of rollouts to collect per epoch
chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
ppo_epochs: 4 # Number of ppo epochs
init_kl_coef: 0.2 # init kl coefficient
init_kl_coef: 0.05 # init kl coefficient
target: 6 # target kl coefficient, set None for fixed kl coef
horizon: 10000 # PPO horizon
gamma: 1 # PPO discount
lam: 0.95 # PPO lambda
cliprange: 0.2 # clip range
cliprange_value: 0.2 # clip range
vf_coef: 2.3 # value term weight
scale_reward: "running" # False | "ref" | "running" estimate against which to scale rewards
vf_coef: 1 # value term weight
scale_reward: False # False | "ref" | "running" estimate against which to scale rewards
ref_mean: null
ref_std: null # rescale rewards with this deviation
cliprange_reward: 10
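For context, `scale_reward` selects the standard-deviation estimate that rewards are divided by: `"ref"` uses the fixed `ref_std`, `"running"` maintains a running estimate, and `False` (the new value here) disables scaling. A minimal sketch of what a running estimate involves, as an illustration of the general technique rather than trlx's implementation:

```python
# Running standard deviation via Welford's algorithm; illustrates what a
# "running" reward-scaling estimate means. Not trlx's actual code.
class RunningStd:
    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0

    def update(self, x: float) -> None:
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    @property
    def std(self) -> float:
        if self.count < 2:
            return 1.0  # avoid dividing by an undefined deviation early on
        return (self.m2 / self.count) ** 0.5


stats = RunningStd()
rewards = [0.2, -0.1, 0.4]
for r in rewards:
    stats.update(r)
scaled_rewards = [r / stats.std for r in rewards]
```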
68 changes: 0 additions & 68 deletions configs/ray_tune_configs/ppo_config.yml

This file was deleted.

19 changes: 19 additions & 0 deletions configs/sweeps/ilql_sweep.yml
@@ -0,0 +1,19 @@
tune_config:
  mode: "max"
  metric: "metrics/sentiments"
  search_alg: "random"
  scheduler: "fifo"
  num_samples: 32

lr_init:
  strategy: "loguniform"
  values: [0.00001, 0.01]
tau:
  strategy: "uniform"
  values: [0.6, 0.9]
steps_for_target_q_sync:
  strategy: "choice"
  values: [1, 5, 10]
alpha:
  strategy: "loguniform"
  values: [0.001, 1.0]
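The `tune_config` block configures the Ray Tune run itself: optimization direction, target metric, search algorithm, scheduler, and number of trials. Roughly, the values above correspond to Ray Tune objects along the following lines; this is an illustrative mapping, not the code `trlx.sweep` actually runs:

```python
# Illustrative mapping of the tune_config block above onto Ray Tune objects.
from ray import tune
from ray.tune.schedulers import FIFOScheduler
from ray.tune.search.basic_variant import BasicVariantGenerator

tune_config = tune.TuneConfig(
    mode="max",                          # maximize the target metric
    metric="metrics/sentiments",         # metric reported during training
    search_alg=BasicVariantGenerator(),  # "random": plain random sampling
    scheduler=FIFOScheduler(),           # "fifo": run trials in order, no early stopping
    num_samples=32,                      # number of trials
    # max_concurrent_trials is left at Ray's default, per this commit
)
```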
17 changes: 17 additions & 0 deletions configs/sweeps/ppo_sweep.yml
@@ -0,0 +1,17 @@
tune_config:
  mode: "max"
  metric: "mean_reward"
  search_alg: "random"
  scheduler: "fifo"
  num_samples: 32

# https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs
lr_init:
  strategy: "loguniform"
  values: [0.00001, 0.01]
init_kl_coef:
  strategy: "uniform"
  values: [0, 0.2]
vf_coef:
  strategy: "uniform"
  values: [0.5, 2]
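The remaining top-level keys are the hyperparameters to sweep, each with a sampling `strategy` and its `values` (bounds for continuous distributions, options for `choice`). Following the Ray Tune search-space docs linked above, these entries translate roughly as follows; a sketch assuming a direct strategy-to-API mapping, not necessarily how `trlx.sweep` builds it:

```python
# Rough translation of the parameter entries above into a Ray Tune param_space.
from ray import tune

param_space = {
    "lr_init": tune.loguniform(1e-5, 1e-2),  # strategy "loguniform", values = [lower, upper]
    "init_kl_coef": tune.uniform(0.0, 0.2),  # strategy "uniform", values = [lower, upper]
    "vf_coef": tune.uniform(0.5, 2.0),       # strategy "uniform"
    # a "choice" strategy (e.g. steps_for_target_q_sync in the ILQL sweep above)
    # would become tune.choice([1, 5, 10])
}
```

Ray Tune then hands each sampled configuration to the chosen example's `main(hparams)`, which merges it into the default `TRLConfig`, as the example diffs below show.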
47 changes: 29 additions & 18 deletions examples/architext.py
@@ -9,22 +9,33 @@ def reward_fn(samples):
return [-sample.count(":") for sample in samples]


prompts = [
"[prompt] the bedroom is adjacent to the living room [layout]",
"[prompt] a bedroom is adjacent to the living room [layout]",
"[prompt] the bedroom is adjacent to the kitchen [layout]",
"[prompt] a bedroom is adjacent to the kitchen [layout]",
"[prompt] the bedroom is adjacent to the kitchen [layout]",
"[prompt] the kitchen is adjacent to the bathroom [layout]",
"[prompt] a bathroom is adjacent to the living room [layout]",
"[prompt] the bathroom is adjacent to the living room [layout]",
"[prompt] the bedroom is not adjacent to the living room [layout]",
"[prompt] a bedroom is not adjacent to the living room [layout]",
"[prompt] the bedroom is not adjacent to the kitchen [layout]",
"[prompt] a bedroom is not adjacent to the kitchen [layout]",
"[prompt] the bedroom is not adjacent to the kitchen [layout]",
"[prompt] the kitchen is not adjacent to the bathroom [layout]",
]

default_config = yaml.safe_load(open("configs/ppo_config.yml"))


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

model = trlx.train(
"architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config
)


if __name__ == "__main__":
prompts = [
"[prompt] the bedroom is adjacent to the living room [layout]",
"[prompt] a bedroom is adjacent to the living room [layout]",
"[prompt] the bedroom is adjacent to the kitchen [layout]",
"[prompt] a bedroom is adjacent to the kitchen [layout]",
"[prompt] the bedroom is adjacent to the kitchen [layout]",
"[prompt] the kitchen is adjacent to the bathroom [layout]",
"[prompt] a bathroom is adjacent to the living room [layout]",
"[prompt] the bathroom is adjacent to the living room [layout]",
"[prompt] the bedroom is not adjacent to the living room [layout]",
"[prompt] a bedroom is not adjacent to the living room [layout]",
"[prompt] the bedroom is not adjacent to the kitchen [layout]",
"[prompt] a bedroom is not adjacent to the kitchen [layout]",
"[prompt] the bedroom is not adjacent to the kitchen [layout]",
"[prompt] the kitchen is not adjacent to the bathroom [layout]",
]

model = trlx.train("architext/gptj-162M", reward_fn=reward_fn, prompts=prompts)
main()
34 changes: 21 additions & 13 deletions examples/experiments/grounded_program_synthesis/train_trlx.py
@@ -1,10 +1,9 @@
# Toy example of optimizing textual interior designs to output the least number of rooms
# Also see https://architext.design/
import trlx
from trlx.data.configs import TRLConfig
from lang import Interpreter
import json
import logging
import yaml


logger = logging.getLogger(__name__)
@@ -50,6 +49,25 @@ def reward_fn(samples):
return reward_list


default_config = yaml.safe_load(open("config/trlx_ppo_config.yml"))


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

# Dataset
dataset = DSLDataset()
train_prompts = list(dataset.load_datapoints(split="train"))[:1000]

model = trlx.train(
"reshinthadith/codegen_350M_list_manip_5_len",
reward_fn=reward_fn,
prompts=train_prompts,
config=config,
)
model.save_pretrained("dataset/trained_model")


if __name__ == "__main__":
# TEST REWARD FUNCTION
assert (
@@ -67,15 +85,5 @@ def reward_fn(samples):
["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -3]),1)"]
)
) == [-0.5]
# Datset
dataset = DSLDataset()
train_prompts = list(dataset.load_datapoints(split="train"))[:1000]
trl_config = TRLConfig.load_yaml("config/trlx_ppo_config.yml")

model = trlx.train(
"reshinthadith/codegen_350M_list_manip_5_len",
reward_fn=reward_fn,
prompts=train_prompts,
config=trl_config,
)
model.save_pretrained("dataset/trained_model")
main()
10 changes: 9 additions & 1 deletion examples/ilql_sentiments.py
@@ -2,16 +2,23 @@
from transformers import pipeline

import trlx
import yaml
from typing import List, Dict
import os
from trlx.data.configs import TRLConfig


def get_positive_score(scores):
"Extract value associated with a positive sentiment from pipeline's output"
return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def main():
default_config = yaml.safe_load(open("configs/ilql_config.yml"))


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

sentiment_fn = pipeline(
"sentiment-analysis",
"lvwerra/distilbert-imdb",
@@ -32,6 +39,7 @@ def metric_fn(samples: List[str]) -> Dict[str, List[float]]:
dataset=(imdb["text"], imdb["label"]),
eval_prompts=["I don't know much about Hungarian underground"] * 64,
metric_fn=metric_fn,
config=config,
)


10 changes: 9 additions & 1 deletion examples/ppo_sentiments.py
@@ -4,18 +4,25 @@
from datasets import load_dataset
from transformers import pipeline
import os
import yaml

import trlx
import torch
from typing import List
from trlx.data.configs import TRLConfig


def get_positive_score(scores):
"Extract value associated with a positive sentiment from pipeline's output"
return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def main():
default_config = yaml.safe_load(open("configs/ppo_config.yml"))


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

if torch.cuda.is_available():
device = int(os.environ.get("LOCAL_RANK", 0))
else:
@@ -43,6 +50,7 @@ def reward_fn(samples: List[str]) -> List[float]:
reward_fn=reward_fn,
prompts=prompts,
eval_prompts=["I don't know much about Hungarian underground"] * 64,
config=config,
)

