diff --git a/games/connect4.py b/games/connect4.py
index 090c9261..a86d519c 100644
--- a/games/connect4.py
+++ b/games/connect4.py
@@ -26,7 +26,9 @@ def __init__(self):
         self.muzero_player = 0  # Turn Muzero begins to play (0: MuZero plays first, 1: MuZero plays second)
         self.opponent = "expert"  # Hard coded agent that MuZero faces to assess his progress in multiplayer games. It doesn't influence training. None, "random" or "expert" if implemented in the Game class
-
+        # Vanilla MCTS, used if self.opponent == "MCTS"
+        self.num_simulations_vanilla = 200  # Number of simulations of the vanilla MCTS opponent
+        self.n_rollouts = 50  # Number of random rollouts used to estimate the value of a position
 
         ### Self-Play
         self.num_workers = 1  # Number of simultaneous threads/workers self-playing to feed the replay buffer
@@ -345,4 +347,4 @@ def expert_action(self):
         return action
 
     def render(self):
-        print(self.board[::-1])
+        print(self.board[::-1])
\ No newline at end of file
diff --git a/games/tictactoe.py b/games/tictactoe.py
index 35f54388..5c375357 100644
--- a/games/tictactoe.py
+++ b/games/tictactoe.py
@@ -350,4 +350,4 @@ def expert_action(self):
         return action
 
     def render(self):
-        print(self.board[::-1])
+        print(self.board[::-1])
\ No newline at end of file
diff --git a/muzero.py b/muzero.py
index f9047d3d..46401a60 100644
--- a/muzero.py
+++ b/muzero.py
@@ -463,6 +463,7 @@ def hyperparameter_search(
                 "Render some self play games",
                 "Play against MuZero",
                 "Test the game manually",
+                "Test against MCTS",
                 "Hyperparameter search",
                 "Exit",
             ]
@@ -498,6 +499,8 @@ def hyperparameter_search(
             elif choice == 4:
                 muzero.test(render=True, opponent="human", muzero_player=0)
             elif choice == 5:
+                muzero.test(render=True, opponent="MCTS", muzero_player=0)
+            elif choice == 6:
                 env = muzero.Game()
                 env.reset()
                 env.render()
@@ -508,7 +511,7 @@ def hyperparameter_search(
                 observation, reward, done = env.step(action)
                 print(f"\nAction: {env.action_to_string(action)}\nReward: {reward}")
                 env.render()
-            elif choice == 6:
+            elif choice == 7:
                 # Define here the parameters to tune
                 # Parametrization documentation: https://facebookresearch.github.io/nevergrad/parametrization.html
                 budget = 50
diff --git a/self_play.py b/self_play.py
index 4286f649..ecf6d76f 100644
--- a/self_play.py
+++ b/self_play.py
@@ -152,6 +152,7 @@ def play_game(
                     self.game.to_play(),
                     True,
                 )
+
                 action = self.select_action(
                     root,
                     temperature
@@ -218,6 +219,10 @@ def select_opponent_action(self, opponent, stacked_observations):
             ), "Legal actions should be a subset of the action space."
 
             return numpy.random.choice(self.game.legal_actions()), None
+        elif opponent == "MCTS":
+            root, mcts_info = MCTS(self.config).run_vanilla_mcts(self.game)
+
+            return self.select_action(root, 1), root
         else:
             raise NotImplementedError(
                 'Wrong argument: "opponent" argument should be "self", "human", "expert" or "random"'
@@ -358,8 +363,91 @@ def run(
             "max_tree_depth": max_tree_depth,
             "root_predicted_value": root_predicted_value,
         }
+
         return root, extra_info
 
+    def run_vanilla_mcts(self, game):
+        """
+        Runs MCTS without any neural network, using random rollouts to estimate the value of a position.
+        This can only be used with games where we have access to a simulator (the Game object itself).
+        """
+
+        def uniform_policy(legal_actions):
+            # Uniform prior over the legal actions, shaped like the policy logits
+            # expected by Node.expand (a batch of one vector over the action space)
+            p = numpy.zeros(len(self.config.action_space))
+            if legal_actions:
+                p[legal_actions] = 1 / len(legal_actions)
+            return [p]
+
+        root = Node(0)
+        root.expand(
+            game.legal_actions(),
+            game.to_play(),
+            0,
+            uniform_policy(game.legal_actions()),
+            game,
+        )
+        min_max_stats = MinMaxStats()
+        max_tree_depth = 0
+
+        for _ in range(self.config.num_simulations_vanilla):
+            virtual_to_play = game.to_play()
+            node = root
+            search_path = [node]
+            current_tree_depth = 0
+
+            expanded_game = copy.deepcopy(game)
+            done = False
+            reward = 0
+            while node.expanded() and not done:
+                current_tree_depth += 1
+                action, node = self.select_child(node, min_max_stats)
+                search_path.append(node)
+
+                # Players play turn by turn
+                if virtual_to_play + 1 < len(self.config.players):
+                    virtual_to_play = self.config.players[virtual_to_play + 1]
+                else:
+                    virtual_to_play = self.config.players[0]
+
+                _, reward, done = expanded_game.step(action)
+
+            # Terminal positions are expanded with no children so they remain leaves
+            legal_actions = [] if done else expanded_game.legal_actions()
+            node.expand(
+                legal_actions,
+                virtual_to_play,
+                0,
+                uniform_policy(legal_actions),
+                expanded_game,
+            )
+
+            if done:
+                # The last move ended the game: its reward was earned by the player
+                # who made it, i.e. the opponent of virtual_to_play
+                value = -reward
+            else:
+                value = self.rollout(node, virtual_to_play)
+            self.backpropagate(search_path, value, virtual_to_play, min_max_stats)
+
+            max_tree_depth = max(max_tree_depth, current_tree_depth)
+
+        extra_info = {
+            "max_tree_depth": max_tree_depth,
+        }
+
+        return root, extra_info
+
+    def rollout(self, node, to_play):
+        """
+        Estimates the value of a position by averaging the returns of random rollouts
+        played from the node's game state until the end of the game.
+        """
+        total_value = 0
+        for _ in range(self.config.n_rollouts):
+            game = copy.deepcopy(node.hidden_state)
+            done = len(game.legal_actions()) == 0
+
+            while not done:
+                action = numpy.random.choice(game.legal_actions())
+                _, reward, done = game.step(action)
+
+                # The reward is given from the perspective of the player who just moved
+                total_value += -reward if game.to_play() == to_play else reward
+
+        return total_value / self.config.n_rollouts
+
     def select_child(self, node, min_max_stats):
         """
         Select the child with the highest UCB score.
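
A minimal sketch of how the new opponent would be exercised outside the interactive menu, assuming the connect4 configuration above and that self_play.py imports copy at module level (the new helpers rely on copy.deepcopy); the MuZero class, train() and test() calls below are the same ones already used by the menu entries in muzero.py:

    from muzero import MuZero

    muzero = MuZero("connect4")  # loads games/connect4.py, including num_simulations_vanilla and n_rollouts
    muzero.train()               # or load an existing checkpoint before testing
    # Equivalent to the new "Test against MCTS" menu entry (choice == 5)
    muzero.test(render=True, opponent="MCTS", muzero_player=0)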