-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
94 lines (83 loc) · 2.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Standard-library imports for CLI parsing, logging, and CPU-count detection.
import argparse
import logging
import multiprocessing
import tensorflow as tf
# TF 1.x API: switch to eager execution globally. Must run before any graph
# ops are created, hence its placement directly after the tensorflow import.
tf.enable_eager_execution()
# Project-local imports: shared constants, the play loop, and one agent
# implementation per supported algorithm.
from rl import constants, play
from rl.agent.rnd import RandomAgent
from rl.agent.a3c import A3CAgent
from rl.agent.dqn import DQNAgent
from rl.agent.ppo import PPOAgent
def run():
    """Parse CLI arguments, build the requested RL agent, then train or play.

    Side effects: configures root logging at the requested level, constructs
    one agent (random / dqn / a3c / ppo), and either calls its ``train()``
    method or loads its saved model and plays episodes in the gym env.
    """
    parser = argparse.ArgumentParser(
        description="Run a RL agent on an AI gym environment"
    )
    parser.add_argument(
        "--env-name",
        default=constants.DEFAULT_ENV_NAME,
        help="The gym environment to run",
    )
    parser.add_argument(
        "--algorithm",
        default=constants.DEFAULT_ALGORITHM,
        type=str,
        choices=["random", "dqn", "a3c", "ppo"],
        # Fixed typo in help text: "algorihtm" -> "algorithm".
        help="The algorithm to use for the RL agent.",
    )
    parser.add_argument(
        "--train", dest="train", action="store_true", help="Train our model."
    )
    parser.add_argument(
        "--lr",
        default=constants.DEFAULT_LEARNING_RATE,
        # Bug fix: without type=float a command-line value arrives as str,
        # breaking optimizer arithmetic. The default is untouched by `type`.
        type=float,
        help="Learning rate for the shared optimizer.",
    )
    parser.add_argument(
        "--update-freq",
        default=constants.DEFAULT_UPDATE_FREQUENCY,
        type=int,
        help="How often to update the global model.",
    )
    parser.add_argument(
        "--max-episodes",
        default=constants.DEFAULT_MAX_EPISODES,
        type=int,
        help="Global maximum number of episodes to run.",
    )
    parser.add_argument(
        "--gamma",
        default=constants.DEFAULT_GAMMA,
        # Bug fix: same str-vs-float conversion issue as --lr.
        type=float,
        help="Discount factor of rewards.",
    )
    parser.add_argument(
        "--save-dir", help="Directory in which you desire to save the model."
    )
    parser.add_argument("--log-level", default="DEBUG")
    parser.add_argument("--thread-count", type=int, default=multiprocessing.cpu_count())
    args = parser.parse_args()

    logging.basicConfig(level=args.log_level)

    # argparse `choices` guarantees args.algorithm is one of these four,
    # so exactly one branch runs and `agent` is always bound.
    if args.algorithm == "random":
        agent = RandomAgent(args.env_name, args.max_episodes)
    elif args.algorithm == "dqn":
        agent = DQNAgent(
            args.env_name, max_episodes=args.max_episodes, save_dir=args.save_dir
        )
    elif args.algorithm == "a3c":
        agent = A3CAgent(
            env_name=args.env_name,
            learning_rate=args.lr,
            max_episodes=args.max_episodes,
            save_dir=args.save_dir,
        )
    elif args.algorithm == "ppo":
        agent = PPOAgent(
            env_name=args.env_name,
            save_dir=args.save_dir,
            max_episodes=args.max_episodes,
        )

    if args.train:
        agent.train()
    else:
        # Playing a saved policy: restore weights first when the agent has a
        # model (the random agent has none and is played as-is).
        if hasattr(agent, "model"):
            agent.model = agent.load_model(agent.model)
        play.play(agent, args.env_name, args.max_episodes)
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    run()