# runner_for_test.py
import torch
import ray
import numpy as np
from config import config
from model import Model
from env import Env
cfg = config()
class TestRunner(object):
    def __init__(self, metaAgentID, cfg, decode_type='sampling', plot=False):
        self.ID = metaAgentID
        self.decode_type = decode_type
        self.model = Model(cfg, self.decode_type, training=False)
        self.model.to(cfg.device)
        self.local_decoder_gradient = []
        self.local_agent_encoder_gradient = []
        self.local_target_encoder_gradient = []
        self.agent_amount = cfg.agent_amount
        self.plot = plot

    def run(self, env):
        # Forward pass of the policy model over the environment for all agents.
        return self.model(env, self.agent_amount)

    def set_weights(self, global_weights):
        # Load the latest global network weights into the local model.
        self.model.load_state_dict(global_weights)

    def sample(self, env):
        # Sample routes without tracking gradients and return the max route length;
        # when plotting is enabled, also return the sampled routes.
        with torch.no_grad():
            route_set, _, _, max_length, _ = self.run(env)
        if self.plot:
            return max_length, route_set  # use this return value for plotting
        else:
            return max_length
# Remote actor used for training and for sampling tests; each actor gets an equal GPU share.
@ray.remote(num_gpus=cfg.number_of_gpu / cfg.meta_agent_amount, num_cpus=1)
class RayTestRunner(TestRunner):
    pass
if __name__ == '__main__':
    cfg = config()
    env = Env(cfg)
    env1 = Env(cfg)
    # device = 'cuda:0'
    # agent_encoder = AgentEncoder(cfg)
    # agent_encoder.to(device)
    # target_encoder = TargetEncoder(cfg)
    # target_encoder.to(device)
    # decoder = Decoder(cfg)
    # decoder.to(device)
    # workerList = [Worker(agentID=i, local_agent_encoder=agent_encoder, local_target_encoder=target_encoder,
    #                      local_decoder=decoder, target_inputs=env.target_inputs) for i in range(cfg.agent_amount)]
    # agent_inputs = env.get_agent_inputs(workerList)
    runner = TestRunner(1, cfg)
    # cost_set, route_set, log_p_set, reward_set, reward = runner.single_thread_job(cfg=cfg, env=env)
    # baseline = torch.Tensor([0]).cuda()
    # advantage = runner.get_advantage(reward.expand_as(cost_set), baseline)
    # loss = runner.get_loss(advantage, log_p_set)
    # loss.backward()
    reward1 = runner.sample(env)
    reward2 = runner.sample(env1)
    # Build a baseline tensor of size [buffer_size, agent_size, 1] from the two sampled rewards.
    baseline = torch.stack([(reward1 - 1).unsqueeze(0).unsqueeze(0).repeat(5, 1),
                            (reward2 - 1).unsqueeze(0).unsqueeze(0).repeat(5, 1)])
    # Note: return_gradient is not defined on TestRunner in this file; it is assumed to come
    # from the training runner this test script was adapted from.
    g = runner.return_gradient(baseline)
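
    # Hypothetical distributed-sampling sketch (not in the original script), showing how the
    # RayTestRunner actor defined above could be used. ray.init / .remote / ray.get are standard
    # Ray calls, but the actor count and the averaging step are illustrative assumptions.
    # ray.init(num_gpus=cfg.number_of_gpu)
    # runners = [RayTestRunner.remote(i, cfg) for i in range(cfg.meta_agent_amount)]
    # max_lengths = ray.get([r.sample.remote(Env(cfg)) for r in runners])
    # print('mean max route length:', np.mean([float(m) for m in max_lengths]))
    # ray.shutdown()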