Feat zoo #19

Open · wants to merge 17 commits into base: master
27 changes: 27 additions & 0 deletions .github/workflows/perf_breaker.yml
@@ -0,0 +1,27 @@
# This workflow installs Python dependencies and runs the performance check
# with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python application

on:
  pull_request:
    types: [ labeled ]

jobs:
  build:
    if: ${{ github.event.label.name == 'breaker' }}
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python 3.9
      uses: actions/setup-python@v2
      with:
        python-version: 3.9
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
    - name: Test performance of all algorithms to ensure nothing broke
      run: |
        python train.py --load_config configs/*
13 changes: 13 additions & 0 deletions README.md
@@ -1,6 +1,19 @@
# squiRL
An RL library written in PyTorch and embedded within the PyTorch Lightning framework, aiming to provide a comprehensive platform for developing and testing RL algorithms.

## Performance checker
DRL research is painful. Writing DRL code is even more so. Throughout development, this repo is bound to go through many changes, and some of those changes may break the performance of older code.

To ensure major pull requests don't have undesirable consequences, and to build a comprehensive zoo of algorithms and envs, we introduce the `performance checker` feature. This is a GitHub workflow (`.github/workflows/perf_breaker.yml`) that runs automatically on a pull request labelled `breaker`.
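
To trigger the check, add the label to an open pull request, either through the GitHub UI or with the `gh` CLI, e.g. `gh pr edit <number> --add-label breaker`.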

The workflow runs all experiments specified in the `configs` folder (5 random seeds each). It then compares the average `mean_episode_reward` of the 5 seeds against the respective `env` thresholds specified in `performance_thresh.json`.

For example, `configs/openai_gym/cartpole/ppo.json` holds the experiment configuration for running `PPO` on Gym's `CartPole-v0`. The workflow runs 5 random seeds, and a mean reward above `150` counts as solving the env. That value, `150`, is stored in `performance_thresh.json` under the env name `CartPole-v0`, so if the mean reward across the 5 seeds doesn't exceed `150`, the workflow fails with an error listing the specific runs that missed the threshold.
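
For illustration, here is a minimal sketch of the per-benchmark check (the numbers are made up; the real logic lives in `performance_checker.py`):

```python
import numpy as np

# Hypothetical per-seed averages of the last 100 `mean_episode_reward` values
seed_means = [162.3, 148.9, 171.0, 155.4, 160.2]
threshold = 150  # value stored in performance_thresh.json for CartPole-v0

# The benchmark passes only if the mean across all seeds clears the threshold
assert np.mean(seed_means) >= threshold, "CartPole-v0 fell below its threshold"
```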

All runs can be found [here](https://wandb.ai/squirl/squirl). They are grouped under their respective git commits.

We ask that any newly implemented algorithm come with a corresponding config file to serve as a benchmark. Pull requests benchmarking on new envs are also more than welcome.
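
As a rough guide, a benchmark config for a new algorithm would sit next to the existing ones (e.g. `configs/openai_gym/cartpole/<algorithm>.json`) and should run locally with something like `python train.py --load_config <path-to-config>`, assuming `--load_config` is pointed at a single config file.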

## Branch names
Branch names should start with one of these prefixes:
wip - Works in progress; stuff I know won't be finished soon (like a release)
12 changes: 12 additions & 0 deletions configs/openai_gym/cartpole/a2c.json
@@ -0,0 +1,12 @@
{
    "project": "squirl",
    "algorithm": "A2C",
    "policy": "MLP",
    "env": "CartPole-v0",
    "lr_actor": 0.0005,
    "lr_critic": 0.0005,
    "gamma": 0.99,
    "episodes_per_batch": 5,
    "num_envs": 5,
    "max_epochs": 500
}
14 changes: 14 additions & 0 deletions configs/openai_gym/cartpole/ppo.json
@@ -0,0 +1,14 @@
{
    "project": "squirl",
    "algorithm": "PPO",
    "policy": "MLP",
    "env": "CartPole-v0",
    "actor_updates_per_iter": 20,
    "clip_rt": 0.1,
    "lr_actor": 0.0005,
    "lr_critic": 0.0005,
    "gamma": 0.99,
    "episodes_per_batch": 1,
    "num_envs": 1,
    "max_epochs": 500
}
11 changes: 11 additions & 0 deletions configs/openai_gym/cartpole/vpg.json
@@ -0,0 +1,11 @@
{
    "project": "squirl",
    "algorithm": "VPG",
    "policy": "MLP",
    "env": "CartPole-v0",
    "lr": 0.0005,
    "gamma": 0.99,
    "episodes_per_batch": 5,
    "num_envs": 5,
    "max_epochs": 500
}
51 changes: 51 additions & 0 deletions performance_checker.py
@@ -0,0 +1,51 @@
import wandb
import json
import os
import numpy as np

# Load the per-env reward thresholds that every benchmark run must meet
with open("performance_thresh.json", 'rt') as f:
    thresh = json.load(f)
print("Thresholds on file:")
print(thresh)

api = wandb.Api()
failures = {}
alg_means = {}
alg_envs = {}
data = {}
for model in os.listdir("models"):
    # Each run saves its full config as models/<run_id>/<run_id>_init.json
    config_file = "models/" + model + "/" + model + "_init.json"
    with open(config_file, 'rt') as f:
        data[model] = json.load(f)
    algorithm = data[model]['algorithm']
    if algorithm not in alg_means:
        alg_means[algorithm] = {}
        failures[algorithm] = {}
    failures[algorithm][model] = {}
    alg_envs[algorithm] = data[model]['env']
    # Fetch the run's reward history from wandb and average the last 100 entries
    run = api.run("squirl/squirl/" + model)
    wandb_mean_rewards = run.history(keys=['mean_episode_reward'],
                                     pandas=False)
    mean_reward = np.mean(
        [i['mean_episode_reward'] for i in wandb_mean_rewards][-100:])
    print(model, mean_reward)
    alg_means[algorithm][model] = mean_reward
    if mean_reward < thresh[data[model]['env']]:
        failures[algorithm][model]["env"] = data[model]['env']
        failures[algorithm][model]["threshold"] = thresh[data[model]['env']]
        failures[algorithm][model]["mean_last_100_steps"] = mean_reward

# An algorithm fails only if the mean across all of its seeds misses the threshold
alg_failures = {}
for k, v in alg_means.items():
    alg_mean = np.mean(list(v.values()))
    print(k, alg_mean)
    if alg_mean < thresh[alg_envs[k]]:
        alg_failures[k] = alg_mean

assert not alg_failures, (
    "The following algorithms have failed:\n" + str(list(alg_failures.keys())) +
    "\nHere are all failed runs of each algorithm:\n" + str(failures))
3 changes: 3 additions & 0 deletions performance_thresh.json
@@ -0,0 +1,3 @@
{
    "CartPole-v0": 150
}
2 changes: 2 additions & 0 deletions requirements.txt
@@ -11,3 +11,5 @@ torchtext==0.4.0
torchvision==0.7.0
typing==3.7.4.1
typing-extensions==3.7.4.1
wandb==0.10.22
python-git-info==0.6.1
4 changes: 3 additions & 1 deletion squiRL/a2c/a2c.py
@@ -50,6 +50,9 @@ def __init__(self, hparams: argparse.Namespace) -> None:

        self.actor = reg_policies[self.hparams.policy](obs_size, n_actions)
        self.critic = reg_policies[self.hparams.policy](obs_size, 1)
        if hparams.logger:
            hparams.logger.watch(self.actor)
            hparams.logger.watch(self.critic)
        self.replay_buffer = RolloutCollector(self.hparams.episodes_per_batch)

        self.agent = Agent(self.env, self.replay_buffer)
@@ -120,7 +123,6 @@ def a2c_loss(
                                               actions]

        discounted_rewards = reward_to_go(rewards, self.gamma)
        discounted_rewards = torch.tensor(discounted_rewards).float()
        advantage = discounted_rewards - values
        advantage = advantage.type_as(log_probs)

12 changes: 6 additions & 6 deletions squiRL/ppo/ppo.py
@@ -45,12 +45,17 @@ def __init__(self, hparams: argparse.Namespace) -> None:
                                     env_kwargs={"id": self.hparams.env})
        self.gamma = self.hparams.gamma
        self.eps = self.hparams.eps
        self.actor_updates_per_iter = self.hparams.actor_updates_per_iter
        obs_size = self.env.ob_space.size
        n_actions = self.env.ac_space.eltype.n

        self.actor = reg_policies[self.hparams.policy](obs_size, n_actions)
        self.new_actor = reg_policies[self.hparams.policy](obs_size, n_actions)
        self.critic = reg_policies[self.hparams.policy](obs_size, 1)
        if hparams.logger:
            hparams.logger.watch(self.actor)
            hparams.logger.watch(self.new_actor)
            hparams.logger.watch(self.critic)
        self.replay_buffer = RolloutCollector(self.hparams.episodes_per_batch)

        self.agent = Agent(self.env, self.replay_buffer)
@@ -70,10 +75,6 @@ def add_model_specific_args(
                            type=str,
                            default='MLP',
                            help="NN policy used by agent")
        parser.add_argument("--custom_optimizers",
                            type=bool,
                            default=True,
                            help="this value must not be changed")
        parser.add_argument("--actor_updates_per_iter",
                            type=int,
                            default=10,
@@ -131,13 +132,12 @@ def ppo_loss(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            dim=-1).squeeze(0)[range(len(actions)),
                               actions]
        discounted_rewards = reward_to_go(rewards, self.gamma)
        discounted_rewards = torch.tensor(discounted_rewards).float()
        advantage = discounted_rewards - values
        advantage = advantage.type_as(log_probs)
        criterion = torch.nn.MSELoss()
        critic_loss = criterion(discounted_rewards, values.view(-1).float())

        for _ in range(self.hparams.actor_updates_per_iter):
        for _ in range(self.actor_updates_per_iter):
            actor_optimizer.zero_grad()
            new_action_logits = self.new_actor(states.float())
            new_log_probs = F.log_softmax(
64 changes: 0 additions & 64 deletions squiRL/vpg/config_file.json

This file was deleted.

3 changes: 2 additions & 1 deletion squiRL/vpg/vpg.py
@@ -47,6 +47,8 @@ def __init__(self, hparams: argparse.Namespace) -> None:
        n_actions = self.env.ac_space.eltype.n

        self.net = reg_policies[self.hparams.policy](obs_size, n_actions)
        if hparams.logger:
            hparams.logger.watch(self.net)
        self.replay_buffer = RolloutCollector(self.hparams.episodes_per_batch)

        self.agent = Agent(self.env, self.replay_buffer)
@@ -114,7 +116,6 @@ def vpg_loss(
                                               actions]

        discounted_rewards = reward_to_go(rewards, self.gamma)
        discounted_rewards = torch.tensor(discounted_rewards)
        advantage = (discounted_rewards - discounted_rewards.mean()) / (
            discounted_rewards.std() + self.eps)
        advantage = advantage.type_as(log_probs)
22 changes: 19 additions & 3 deletions train.py
@@ -8,7 +8,10 @@
"""
import os
import json
import random
import argparse
import gitinfo
from shutil import copyfile
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.seed import seed_everything
import pytorch_lightning as pl
@@ -32,19 +35,25 @@ def train(hparams) -> None:
    profiler = None
    cwd = os.getcwd()
    path = os.path.join(cwd, 'models')
    if hparams.git_commit is None:
        hparams.git_commit = gitinfo.get_git_info()['commit']
    if not os.path.exists(path):
        os.mkdir(path)
    path = os.path.join(path, hparams.logger.version)
    if not os.path.exists(path):
        os.mkdir(path)
    path = os.path.join(path, hparams.logger.version)
    if hparams.save_config:
        with open(path + '.json', 'wt') as f:
        with open(path + '_init.json', 'wt') as f:
            config = vars(hparams).copy()
            config.pop("logger")
            config.pop("gpus")
            config.pop("tpu_cores")
            json.dump(config, f, indent=4)
        copyfile(path + '_init.json',
                 hparams.logger.save_dir + "/config_all.json")
        copyfile(hparams.load_config,
                 hparams.logger.save_dir + "/config_init.json")

    seed_everything(hparams.seed)
    algorithm = squiRL.reg_algorithms[hparams.algorithm](hparams)
@@ -53,12 +62,19 @@ def train(hparams) -> None:


if __name__ == '__main__':
    __spec__ = "ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>)"
    # enables pdb debugging
    __spec__ = '''ModuleSpec(name='builtins', loader=<class '_frozen_importlib.
    BuiltinImporter'>)'''

    parser = argparse.ArgumentParser(add_help=False)
    group_prog = parser.add_argument_group("program_args")
    group_env = parser.add_argument_group("environment_args")

    # add PROGRAM level args
    parser.add_argument('--git_commit',
                        type=str,
                        default=None,
                        help='current git commit')
    parser.add_argument(
        '--save_config',
        type=bool,
@@ -69,7 +85,7 @@
        help='Load from json file. Command line override.')
    group_prog.add_argument('--seed',
                            type=int,
                            default=random.randint(0, 1000),
                            help="experiment seed")
    group_prog.add_argument(
        '--debug',
'--debug',
Expand Down