Vanilla MCTS (without networks) for testing #67

Open
wants to merge 5 commits into master
6 changes: 4 additions & 2 deletions games/connect4.py
@@ -26,7 +26,9 @@ def __init__(self):
self.muzero_player = 0 # Turn Muzero begins to play (0: MuZero plays first, 1: MuZero plays second)
self.opponent = "expert" # Hard coded agent that MuZero faces to assess his progress in multiplayer games. It doesn't influence training. None, "random" or "expert" if implemented in the Game class


# Vanilla MCTS opponent, used if self.opponent == "MCTS"
self.num_simulations_vanilla = 200  # Number of MCTS simulations per move
self.n_rollouts = 50  # Number of random rollouts used to estimate the value of a position

### Self-Play
self.num_workers = 1 # Number of simultaneous threads/workers self-playing to feed the replay buffer
@@ -345,4 +347,4 @@ def expert_action(self):
return action

def render(self):
print(self.board[::-1])
print(self.board[::-1])
2 changes: 1 addition & 1 deletion games/tictactoe.py
@@ -350,4 +350,4 @@ def expert_action(self):
return action

def render(self):
print(self.board[::-1])
print(self.board[::-1])
5 changes: 4 additions & 1 deletion muzero.py
@@ -463,6 +463,7 @@ def hyperparameter_search(
"Render some self play games",
"Play against MuZero",
"Test the game manually",
"Test agaisnt MCTS",
"Hyperparameter search",
"Exit",
]
@@ -498,6 +499,8 @@ def hyperparameter_search(
elif choice == 4:
muzero.test(render=True, opponent="human", muzero_player=0)
elif choice == 5:
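# MuZero plays first against the vanilla MCTS opponent implemented in self_play.py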
muzero.test(render=True, opponent="MCTS", muzero_player=0)
elif choice == 6:
env = muzero.Game()
env.reset()
env.render()
@@ -508,7 +511,7 @@
observation, reward, done = env.step(action)
print(f"\nAction: {env.action_to_string(action)}\nReward: {reward}")
env.render()
elif choice == 6:
elif choice == 7:
# Define here the parameters to tune
# Parametrization documentation: https://facebookresearch.github.io/nevergrad/parametrization.html
budget = 50
88 changes: 88 additions & 0 deletions self_play.py
@@ -152,6 +152,7 @@ def play_game(
self.game.to_play(),
True,
)

action = self.select_action(
root,
temperature
@@ -218,6 +219,10 @@ def select_opponent_action(self, opponent, stacked_observations):
), "Legal actions should be a subset of the action space."

return numpy.random.choice(self.game.legal_actions()), None
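# Vanilla MCTS opponent: search the current game state without any network, then pick a move from the root visit counts (temperature 1)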
elif opponent == "MCTS":
root, mcts_info = MCTS(self.config).run_vanilla_mcts(self.game)

return self.select_action(root, 1), root
else:
raise NotImplementedError(
'Wrong argument: "opponent" argument should be "self", "human", "expert" or "random"'
@@ -358,8 +363,91 @@ def run(
"max_tree_depth": max_tree_depth,
"root_predicted_value": root_predicted_value,
}

return root, extra_info


def run_vanilla_mcts(self, game):
"""
Run MCTS without any network, using random rollouts to estimate the value of a position.
This can only be used with games that provide a simulator, i.e. the Game object can be deep-copied and stepped.
"""

def uniform_policy(legal_actions):
    # Uniform prior over the legal actions; the array is wrapped in a list to
    # match the policy shape expected by Node.expand (policy[0][action])
    p = numpy.zeros(len(self.config.action_space))
    p[legal_actions] = 1 / len(legal_actions)

    return [p]

root = Node(0)
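# Expand the root right away, storing the real game object as its hidden state; children will store deep copies advanced by the selected actions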
root.expand(
game.legal_actions(),
game.to_play(),
0,
uniform_policy(game.legal_actions()),
game
)
min_max_stats = MinMaxStats()
max_tree_depth = 0

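# Each simulation: selection with UCB, expansion with a uniform prior, evaluation with random rollouts, backpropagation along the search path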
for _ in range(self.config.num_simulations_vanilla):
virtual_to_play = game.to_play()
node = root
search_path = [node]
current_tree_depth = 0

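# Work on a copy of the game so the real environment is never modified by the search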
expanded_game = copy.deepcopy(game)
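# Selection: descend to the child with the highest UCB score, replaying each chosen action in the copied game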
while node.expanded():
current_tree_depth += 1
action, node = self.select_child(node, min_max_stats)
search_path.append(node)

if virtual_to_play + 1 < len(self.config.players):
virtual_to_play = self.config.players[virtual_to_play + 1]
else:
virtual_to_play = self.config.players[0]

_, _, done = expanded_game.step(action)

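# Expansion: create the leaf's children with a uniform prior over the legal actions of the simulated state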
node.expand(
expanded_game.legal_actions(),
virtual_to_play,
0,
uniform_policy(expanded_game.legal_actions()),
expanded_game
)

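# Evaluation and backup: estimate the leaf value with random playouts and propagate it along the search path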
value = self.rollout(node, virtual_to_play)
self.backpropagate(search_path, value, virtual_to_play, min_max_stats)

max_tree_depth = max(max_tree_depth, current_tree_depth)

extra_info = {
"max_tree_depth": max_tree_depth,
}

return root, extra_info

def rollout(self, node, to_play):
"""
Estimates the value of a position through rollouts.

"""

for _ in range(self.config.n_rollouts):
game = copy.deepcopy(node.hidden_state)
done = len(game.legal_actions()) == 0
v = 0
reward = 0

while not done:
action = numpy.random.choice(game.legal_actions(), 1)
obs, reward, done = game.step(action)

v += -reward if game.to_play() == to_play else reward

return v/self.config.n_rollouts

def select_child(self, node, min_max_stats):
"""
Select the child with the highest UCB score.