Commit

Add configurable results dir and save metrics
Kaixhin committed Jun 17, 2019
1 parent bdf8e39 commit 8a34488
Showing 3 changed files with 24 additions and 21 deletions.
README.md: 5 changes (3 additions, 2 deletions)
@@ -23,9 +23,10 @@ python main.py --target-update 2000 \
 --memory-capacity 100000 \
 --replay-frequency 1 \
 --multi-step 20 \
---architecture canonical \
+--architecture data-efficient \
 --hidden-size 256 \
---learning-rate 0.0001
+--learning-rate 0.0001 \
+--evaluation-interval 10000
 ```

 Requirements
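The README example now runs the data-efficient architecture and sets an explicit `--evaluation-interval`. For context, a minimal sketch of how these options are consumed downstream (the `--id` flag and `results_dir` wiring come from the main.py diff below; the `--evaluation-interval` type and default here are assumptions for illustration):

```python
# Illustrative sketch only: shows how --id maps a run to results/<id>.
# The --evaluation-interval default used here is an assumption.
import argparse
import os

parser = argparse.ArgumentParser(description='Rainbow')
parser.add_argument('--id', type=str, default='default', help='Experiment ID')
parser.add_argument('--evaluation-interval', type=int, default=100000,
                    help='Number of training steps between evaluations')

args = parser.parse_args(['--id', 'data-efficient', '--evaluation-interval', '10000'])
results_dir = os.path.join('results', args.id)  # results/data-efficient
os.makedirs(results_dir, exist_ok=True)         # safe if the directory already exists
```

With distinct `--id` values, concurrent or repeated runs no longer overwrite each other's checkpoints and plots.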
main.py: 10 changes (8 additions, 2 deletions)
@@ -1,4 +1,6 @@
 import argparse
+from math import inf
+import os
 from datetime import datetime
 import atari_py
 import numpy as np
@@ -13,6 +15,7 @@

 # Note that hyperparameters may originally be reported in ATARI game frames instead of agent steps
 parser = argparse.ArgumentParser(description='Rainbow')
+parser.add_argument('--id', type=str, default='default', help='Experiment ID')
 parser.add_argument('--seed', type=int, default=123, help='Random seed')
 parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
 parser.add_argument('--game', type=str, default='space_invaders', choices=atari_py.list_games(), help='ATARI game')
@@ -51,6 +54,9 @@
 print(' ' * 26 + 'Options')
 for k, v in vars(args).items():
   print(' ' * 26 + k + ': ' + str(v))
+results_dir = os.path.join('results', args.id)
+os.makedirs(results_dir, exist_ok=True)
+metrics = {'steps': [], 'rewards': [], 'Qs': [], 'best_avg_reward': -inf}
 np.random.seed(args.seed)
 torch.manual_seed(np.random.randint(1, 10000))
 if torch.cuda.is_available() and not args.disable_cuda:
@@ -92,7 +98,7 @@ def log(s):

 if args.evaluate:
   dqn.eval()  # Set DQN (online network) to evaluation mode
-  avg_reward, avg_Q = test(args, 0, dqn, val_mem, evaluate=True)  # Test
+  avg_reward, avg_Q = test(args, 0, dqn, val_mem, metrics, results_dir, evaluate=True)  # Test
   print('Avg. reward: ' + str(avg_reward) + ' | Avg. Q: ' + str(avg_Q))
 else:
   # Training loop
@@ -120,7 +126,7 @@ def log(s):

     if T % args.evaluation_interval == 0:
       dqn.eval()  # Set DQN (online network) to evaluation mode
-      avg_reward, avg_Q = test(args, T, dqn, val_mem)  # Test
+      avg_reward, avg_Q = test(args, T, dqn, val_mem, metrics, results_dir)  # Test
       log('T = ' + str(T) + ' / ' + str(args.T_max) + ' | Avg. reward: ' + str(avg_reward) + ' | Avg. Q: ' + str(avg_Q))
       dqn.train()  # Set DQN (online network) back to training mode
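The key behavioral change above: evaluation bookkeeping now lives in an explicit `metrics` dict created in main.py and threaded through both `test()` call sites, instead of module-level globals in test.py. Seeding `best_avg_reward` with `-inf` (rather than the old `-1e10` sentinel) guarantees the first evaluation always counts as an improvement. A minimal sketch of that comparison pattern, with made-up reward values:

```python
# Sketch of the best-so-far pattern used by test(); values are illustrative.
from math import inf

metrics = {'best_avg_reward': -inf}
for avg_reward in [-250.0, 110.5, 98.2, 180.0]:
    if avg_reward > metrics['best_avg_reward']:  # -inf: the first result always wins
        metrics['best_avg_reward'] = avg_reward
        # ... this is where test() calls dqn.save(results_dir)
print(metrics['best_avg_reward'])  # 180.0
```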
test.py: 30 changes (13 additions, 17 deletions)
@@ -7,16 +7,11 @@
 from env import Env


-# Globals
-Ts, rewards, Qs, best_avg_reward = [], [], [], -1e10
-
-
 # Test DQN
-def test(args, T, dqn, val_mem, evaluate=False):
-  global Ts, rewards, Qs, best_avg_reward
+def test(args, T, dqn, val_mem, metrics, results_dir, evaluate=False):
   env = Env(args)
   env.eval()
-  Ts.append(T)
+  metrics['steps'].append(T)
   T_rewards, T_Qs = [], []

   # Test performance over several episodes
@@ -43,18 +38,19 @@ def test(args, T, dqn, val_mem, evaluate=False):

   avg_reward, avg_Q = sum(T_rewards) / len(T_rewards), sum(T_Qs) / len(T_Qs)
   if not evaluate:
-    # Append to results
-    rewards.append(T_rewards)
-    Qs.append(T_Qs)
+    # Save model parameters if improved
+    if avg_reward > metrics['best_avg_reward']:
+      metrics['best_avg_reward'] = avg_reward
+      dqn.save(results_dir)

-    # Plot
-    _plot_line(Ts, rewards, 'Reward', path='results')
-    _plot_line(Ts, Qs, 'Q', path='results')
+    # Append to results and save metrics
+    metrics['rewards'].append(T_rewards)
+    metrics['Qs'].append(T_Qs)
+    torch.save(metrics, os.path.join(results_dir, 'metrics.pth'))

-    # Save model parameters if improved
-    if avg_reward > best_avg_reward:
-      best_avg_reward = avg_reward
-      dqn.save('results')
+    # Plot
+    _plot_line(metrics['steps'], metrics['rewards'], 'Reward', path=results_dir)
+    _plot_line(metrics['steps'], metrics['Qs'], 'Q', path=results_dir)

   # Return average reward and Q-value
   return avg_reward, avg_Q
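Because `test()` now persists the whole `metrics` dict after every evaluation, a run's learning curve can be inspected offline. A minimal sketch, assuming a run launched with the default `--id` so the file sits at `results/default/metrics.pth`:

```python
# Sketch: reload the persisted metrics for offline analysis (path assumed).
import os
import torch

metrics = torch.load(os.path.join('results', 'default', 'metrics.pth'))
print('Evaluations recorded:', len(metrics['steps']))
print('Best avg. reward so far:', metrics['best_avg_reward'])

# metrics['rewards'][i] holds the per-episode returns of evaluation i,
# taken at training step metrics['steps'][i]
last = metrics['rewards'][-1]
print('Latest avg. reward:', sum(last) / len(last))
```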
