continuous_A3C.py (forked from MorvanZhou/pytorch-A3C)
"""
Reinforcement Learning (A3C) using Pytroch + multiprocessing.
The most simple implementation for continuous action.
View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""
import torch
import torch.nn as nn
from utils import v_wrap, set_init, push_and_pull, record
import torch.nn.functional as F
import torch.multiprocessing as mp
from shared_adam import SharedAdam
import gym
import math, os
os.environ["OMP_NUM_THREADS"] = "1"
UPDATE_GLOBAL_ITER = 5
GAMMA = 0.9
MAX_EP = 3000
MAX_EP_STEP = 200
env = gym.make('Pendulum-v0')
N_S = env.observation_space.shape[0]
N_A = env.action_space.shape[0]
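# Illustrative note (hypothetical check, not part of the training logic): for Pendulum-v0
# the observation is 3-dimensional (cos(theta), sin(theta), angular velocity) and the
# action is a single torque bounded to [-2, 2], so N_S == 3 and N_A == 1. Assuming the
# classic gym API, this could be verified with:
#
#   assert N_S == 3 and N_A == 1
#   assert float(env.action_space.low[0]) == -2.0
#   assert float(env.action_space.high[0]) == 2.0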
class Net(nn.Module):
    def __init__(self, s_dim, a_dim):
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        # actor head: hidden layer producing the mean and std of a Gaussian policy
        self.a1 = nn.Linear(s_dim, 100)
        self.mu = nn.Linear(100, a_dim)
        self.sigma = nn.Linear(100, a_dim)
        # critic head: separate hidden layer producing the state value
        self.c1 = nn.Linear(s_dim, 100)
        self.v = nn.Linear(100, 1)
        set_init([self.a1, self.mu, self.sigma, self.c1, self.v])
        self.distribution = torch.distributions.Normal

    def forward(self, x):
        a1 = F.relu(self.a1(x))
        mu = 2 * torch.tanh(self.mu(a1))            # scale mean into Pendulum's action range [-2, 2]
        sigma = F.softplus(self.sigma(a1)) + 0.001  # keep std strictly positive, avoid 0
        c1 = F.relu(self.c1(x))
        values = self.v(c1)
        return mu, sigma, values

    def choose_action(self, s):
        self.training = False
        mu, sigma, _ = self.forward(s)
        m = self.distribution(mu.view(1, ).data, sigma.view(1, ).data)
        return m.sample().numpy()

    def loss_func(self, s, a, v_t):
        self.train()
        mu, sigma, values = self.forward(s)
        td = v_t - values
        c_loss = td.pow(2)
        m = self.distribution(mu, sigma)
        log_prob = m.log_prob(a)
        entropy = 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(m.scale)  # Gaussian entropy, encourages exploration
        exp_v = log_prob * td.detach() + 0.005 * entropy
        a_loss = -exp_v
        total_loss = (a_loss + c_loss).mean()
        return total_loss
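# Illustrative usage sketch (hypothetical helper, defined but never called): feeds a dummy
# Pendulum-like state through Net to show the tensor shapes involved, and checks that the
# hand-written entropy term in loss_func, 0.5 + 0.5*log(2*pi) + log(sigma), matches the
# closed-form Gaussian entropy returned by torch.distributions.Normal.entropy().
def _net_shapes_sketch():
    net = Net(N_S, N_A)
    s = torch.zeros(1, N_S)                        # one dummy state, shape (1, N_S)
    mu, sigma, value = net.forward(s)              # shapes: (1, N_A), (1, N_A), (1, 1)
    a = net.choose_action(s)                       # numpy array of shape (N_A,)
    dist = torch.distributions.Normal(mu, sigma)
    manual_entropy = 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(dist.scale)
    assert torch.allclose(manual_entropy, dist.entropy())
    return mu.shape, sigma.shape, value.shape, a.shape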
class Worker(mp.Process):
    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name):
        super(Worker, self).__init__()
        self.name = 'w%i' % name
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.gnet, self.opt = gnet, opt
        self.lnet = Net(N_S, N_A)           # local network
        self.env = gym.make('Pendulum-v0').unwrapped

    def run(self):
        total_step = 1
        while self.g_ep.value < MAX_EP:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            for t in range(MAX_EP_STEP):
                if self.name == 'w0':
                    self.env.render()
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a.clip(-2, 2))
                if t == MAX_EP_STEP - 1:
                    done = True
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append((r + 8.1) / 8.1)    # normalize
                if total_step % UPDATE_GLOBAL_ITER == 0 or done:  # update global and assign to local net
                    # sync
                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    if done:  # done and print information
                        record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name)
                        break
                s = s_
                total_step += 1
        self.res_queue.put(None)
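# Sketch of the value targets that push_and_pull (imported from utils, which is not shown
# in this file) is expected to build from the buffers and GAMMA before calling
# lnet.loss_func: bootstrap from V(s_) unless the episode is done, then accumulate
# discounted returns backwards through the buffer. Hypothetical helper for illustration
# only, under that assumption about utils.py.
def _discounted_targets_sketch(lnet, done, s_, buffer_r, gamma):
    v_s_ = 0. if done else lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
    targets = []
    for r in buffer_r[::-1]:                # iterate newest -> oldest
        v_s_ = r + gamma * v_s_             # R_t = r_t + gamma * R_{t+1}
        targets.append(v_s_)
    targets.reverse()                       # back to oldest -> newest
    return targets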
if __name__ == "__main__":
gnet = Net(N_S, N_A) # global network
gnet.share_memory() # share the global parameters in multiprocessing
opt = SharedAdam(gnet.parameters(), lr=0.0002) # global optimizer
global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
# parallel training
workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i) for i in range(mp.cpu_count())]
[w.start() for w in workers]
res = [] # record episode reward to plot
while True:
r = res_queue.get()
if r is not None:
res.append(r)
else:
break
[w.join() for w in workers]
import matplotlib.pyplot as plt
plt.plot(res)
plt.ylabel('Moving average ep reward')
plt.xlabel('Step')
plt.show()
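# Usage note: run with `python continuous_A3C.py`. The script assumes the classic gym API
# where env.step returns (s_, r, done, info) and the 'Pendulum-v0' id exists; newer
# gym/gymnasium releases rename it to 'Pendulum-v1' and change the reset/step return
# values, so an older gym may be needed. Worker 'w0' calls env.render(), so a display
# (or a virtual one) is required.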