play_game.py

# -*- coding: utf-8 -*-
"""
Created on Thu Dec 31 10:36:52 2020

@author: sgyhit
"""
import copy
import pickle
import random
import sys

import matplotlib.pyplot as plt
import networkx as nx

from gamestate import *
from policies import *
from utilities import *


def mcts_process(game, budget=400):
    '''
    Run Monte Carlo Tree Search from the current game state.

    game   : a GameState object for the current state
    budget : the computational budget, i.e. the number of MCTS iterations
    '''
    MCTS = MCTSPolicy(game)
    for i in range(budget):
        # select the node to expand in the search tree
        node_exp = MCTS.selection(0)
        # expand the selected node and return the frontier node
        node_fron = MCTS.expansion(node_exp)
        # roll out from the frontier node to estimate its value
        reward = MCTS.simulation(node_fron)
        # backpropagate the rollout reward up the tree
        MCTS.backpropagation(node_fron, reward)
    # optionally visualize the final search tree
    # s = visualize_MCTS(MCTS)
    # s.view()
    # you need to rewrite this action_selection function
    nextNode = action_selection(MCTS)
    return (nextNode, MCTS)
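

# The action_selection used above is, per the comment, meant to be rewritten.
# A minimal sketch of one common choice, picking the root child with the most
# visits, is given below.  The names `action_selection_by_visits`, `mcts.tree`,
# `mcts.root`, and the per-node visit attribute 'n' are assumptions for
# illustration, not MCTSPolicy's actual interface.
def action_selection_by_visits(mcts):
    # hypothetical: assumes the search tree is a networkx DiGraph in mcts.tree,
    # rooted at mcts.root, with each node storing its visit count under 'n'
    children = list(mcts.tree.successors(mcts.root))
    return max(children, key=lambda c: mcts.tree.nodes[c]['n'])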


def terminal_condition():
    # placeholder termination check: always True, so the planning loop in
    # play_game below runs until this is replaced with a real end-of-game test
    return True
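

# A hedged alternative to the placeholder above: bound the planning loop by the
# number of steps already taken.  The step and horizon parameters are
# hypothetical additions, not part of the original signature.
def terminal_condition_by_steps(step, horizon):
    # keep planning while fewer than `horizon` moves have been made
    return step < horizon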


def play_game(starts, budget):
    # start a game from the given start positions
    horizon = 10
    game = GameState(starts, horizon)
    # path of joint robot positions visited so far
    path = [starts]
    # NOTE: terminal_condition() always returns True as written, so this loop
    # only ends once the placeholder is replaced with a real termination check
    while terminal_condition():
        # run MCTS with the given computational budget and take one step
        next_node, mcts = mcts_process(game, budget)
        path.append(next_node)
        game = GameState(next_node, horizon)
    # extract each robot's trajectory from the joint path
    trajs = {}
    for robot in path[0]:
        traj = [node[robot] for node in path]
        print('robot {} trajectory:'.format(robot))
        print(traj)
        trajs[robot] = traj
    print('reward is {}'.format(game.collected_reward()))
    return game.collected_reward()


if __name__ == "__main__":
    # budgets = [100, 400, 800, 1600]
    starts = {0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0)}
    play_game(starts, 800)
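    # The commented-out budgets list above suggests a budget sweep.  A minimal
    # sketch, assuming results are stored with the already-imported pickle
    # module (the 'results.pkl' filename is an assumption):
    #
    # budgets = [100, 400, 800, 1600]
    # rewards = {b: play_game(starts, b) for b in budgets}
    # with open('results.pkl', 'wb') as f:
    #     pickle.dump(rewards, f)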