-
Notifications
You must be signed in to change notification settings - Fork 40
/
model.py
80 lines (68 loc) · 2.61 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.multiprocessing as mp
class Model(nn.Module):
def __init__(self, num_inputs, num_outputs):
super(Model, self).__init__()
h_size_1 = 100
h_size_2 = 100
self.p_fc1 = nn.Linear(num_inputs, h_size_1)
self.p_fc2 = nn.Linear(h_size_1, h_size_2)
self.v_fc1 = nn.Linear(num_inputs, h_size_1*5)
self.v_fc2 = nn.Linear(h_size_1*5, h_size_2)
self.mu = nn.Linear(h_size_2, num_outputs)
self.log_std = nn.Parameter(torch.zeros(1, num_outputs))
self.v = nn.Linear(h_size_2,1)
for name, p in self.named_parameters():
# init parameters
if 'bias' in name:
p.data.fill_(0)
'''
if 'mu.weight' in name:
p.data.normal_()
p.data /= torch.sum(p.data**2,0).expand_as(p.data)'''
# mode
self.train()
def forward(self, inputs):
# actor
x = F.tanh(self.p_fc1(inputs))
x = F.tanh(self.p_fc2(x))
mu = self.mu(x)
sigma_sq = torch.exp(self.log_std)
# critic
x = F.tanh(self.v_fc1(inputs))
x = F.tanh(self.v_fc2(x))
v = self.v(x)
return mu, sigma_sq, v
class Shared_grad_buffers():
    """Per-parameter gradient accumulators placed in shared memory.

    One buffer is created per entry of ``model.named_parameters()``, keyed
    by ``<parameter name>_grad`` and moved to shared memory with
    ``share_memory_()`` so every process holding this object adds into the
    same storage. No locking is done here; concurrent ``add_gradient``
    calls are not synchronized by this class.

    Args:
        model: Module whose named parameters define buffer names and shapes.
    """

    def __init__(self, model):
        self.grads = {}
        for name, p in model.named_parameters():
            # The buffer is a running sum, so it must start at zero; any
            # other initial value would be added to the first real gradient.
            self.grads[name + '_grad'] = torch.zeros(p.size()).share_memory_()

    def add_gradient(self, model):
        """Add ``model``'s current ``.grad`` tensors into the buffers.

        Parameters whose ``.grad`` is ``None`` (not reached by a backward
        pass) contribute a zero gradient and are skipped.

        Raises:
            KeyError: if ``model`` has a parameter name that was not present
                in the model used to build the buffers.
        """
        for name, p in model.named_parameters():
            if p.grad is None:
                continue
            self.grads[name + '_grad'] += p.grad.data

    def reset(self):
        """Zero every buffer in place, keeping the shared storage."""
        for grad in self.grads.values():
            grad.zero_()
class Shared_obs_stats():
    """Running per-feature mean and variance of observations, shared memory.

    Statistics are updated incrementally (Welford's online algorithm) and
    all four tensors are placed in shared memory with ``share_memory_()``,
    so every process holding this object reads and writes the same storage.
    No locking is done here.

    Args:
        num_inputs: Number of observation features.
    """

    def __init__(self, num_inputs):
        self.n = torch.zeros(num_inputs).share_memory_()
        self.mean = torch.zeros(num_inputs).share_memory_()
        self.mean_diff = torch.zeros(num_inputs).share_memory_()
        self.var = torch.zeros(num_inputs).share_memory_()

    def observes(self, obs):
        """Fold one observation into the running statistics.

        Args:
            obs: Tensor whose ``squeeze()`` yields ``num_inputs`` values
                (e.g. shape ``(1, num_inputs)``).
        """
        x = obs.data.squeeze()
        self.n += 1.
        last_mean = self.mean.clone()
        self.mean += (x - self.mean) / self.n
        # Welford: accumulate the sum of squared deviations from the mean.
        self.mean_diff += (x - last_mean) * (x - self.mean)
        # Write in place so the shared-memory tensor itself is updated;
        # rebinding self.var would leave other processes on the stale buffer.
        # The 1e-2 floor keeps normalize() from dividing by a tiny std.
        self.var.copy_(torch.clamp(self.mean_diff / self.n, min=1e-2))

    def normalize(self, inputs):
        """Standardize ``inputs`` with the running stats, clipped to [-5, 5].

        ``inputs`` needs a last dimension of ``num_inputs``; the statistics
        broadcast over any leading dimensions.

        NOTE: ``var`` starts at all zeros, so calling this before the first
        observes() divides by zero and the result is not meaningful.
        """
        obs_std = torch.sqrt(self.var)
        return torch.clamp((inputs - self.mean) / obs_std, -5., 5.)