eval.py
import numpy as np
import torch

from collect_data import DoomTakeCover
from config import cfg
from es_train import encode_action
from model import VAE, RNNModel, Controller
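
# Evaluation under MPI: rank 0 aggregates results while every other rank rolls
# out the trained world model (VAE encoder + recurrent model + controller) in
# DoomTakeCover and reports how many steps it survived per episode.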

def slave(comm):
    # Load the trained world-model components from their checkpoints.
    vae = VAE()
    vae.load_state_dict(torch.load(cfg.vae_save_ckpt, map_location=lambda storage, loc: storage)['model'])
    model = RNNModel()
    model.load_state_dict(torch.load(cfg.rnn_save_ckpt, map_location=lambda storage, loc: storage)['model'])
    controller = Controller()
    controller.load_state_dict(torch.load(cfg.ctrl_save_ckpt, map_location=lambda storage, loc: storage)['model'])

    env = DoomTakeCover(False)
    rewards = []
    for epi in range(cfg.trials_per_pop * 4):
        obs = env.reset()
        model.reset()
        for step in range(cfg.max_steps):
            # HWC uint8 frame -> normalized NCHW float tensor.
            obs = torch.from_numpy(obs.transpose(2, 0, 1)).unsqueeze(0).float() / 255.0
            mu, logvar, _, z = vae(obs)
            # The controller acts on the RNN hidden/cell states plus the latent code.
            inp = torch.cat((model.hx.detach(), model.cx.detach(), z), dim=1)
            y = controller(inp).item()
            action = encode_action(y)
            model.step(z.unsqueeze(0), action.unsqueeze(0))
            obs, reward, done, _ = env.step(action.item())
            if done:
                break
        # In DoomTakeCover the score is survival time, so steps survived is the reward.
        rewards.append(step)
        print('Worker {} got reward {} at epi {}'.format(comm.rank, step, epi))
    rewards = np.array(rewards)
    comm.send(rewards, dest=0, tag=1)
    print('Worker {} sent rewards to master'.format(comm.rank))

if __name__ == '__main__':
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.rank
    size = comm.size
    if rank == 0:
        # Master: collect the per-episode reward arrays from every worker.
        f = open('result.txt', 'a')
        rewards = []
        for idx in range(1, size):
            reward = comm.recv(source=idx, tag=1)
            print('Master received rewards from slave {}'.format(idx))
            rewards.append(reward)
        rewards = np.array(rewards)
        info = 'Mean {}\t Max {}\t Min {}\t Std {}'.format(rewards.mean(), rewards.max(), rewards.min(), rewards.std())
        f.write(info + '\n')
        f.flush()
        f.close()
        print(info)
    else:
        slave(comm)
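
# Example launch (a sketch; the worker count below is an assumption, not fixed
# by this script): one master plus four evaluation workers:
#   mpiexec -n 5 python eval.py
# Rank 0 appends the aggregate statistics to result.txt; each of the other
# ranks runs cfg.trials_per_pop * 4 episodes and sends its rewards with tag=1.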