From c6ea13631908294e3ebaee8f28ca1f6ee8097265 Mon Sep 17 00:00:00 2001 From: Joseph Denman Date: Sun, 18 Dec 2022 13:14:31 -0600 Subject: [PATCH] FIXUP - Simplifications --- diagnose_model.py | 26 ++--- game_history.py | 18 +-- games/breakout.py | 8 +- games/cartpole.py | 18 +-- games/connect4.py | 10 +- mcts.py | 29 ++--- models/muzero_network.py | 2 +- models/muzero_sampled.py | 232 +++++++++++++-------------------------- node.py | 15 +-- replay_buffer.py | 24 ++-- self_play.py | 12 +- trainer.py | 61 +++++----- 12 files changed, 186 insertions(+), 269 deletions(-) diff --git a/diagnose_model.py b/diagnose_model.py index 0dfbc9a6..0b535d7d 100644 --- a/diagnose_model.py +++ b/diagnose_model.py @@ -54,18 +54,18 @@ def get_virtual_trajectory_from_obs( virtual_to_play = self.config.players[0] # Generate new root - value, reward, policy_logits, hidden_state = self.model.recurrent_inference( + value, reward, policy_parameters, hidden_state = self.model.recurrent_inference( root.hidden_state, torch.tensor([[action]]).to(root.hidden_state.device), ) value = support_to_scalar(value, self.config.support_size).item() reward = support_to_scalar(reward, self.config.support_size).item() root = Node(0) + sampled_actions = self.model.sample_actions(policy_parameters) root.expand( - self.config.action_space, + sampled_actions, virtual_to_play, reward, - policy_logits, hidden_state, ) @@ -208,10 +208,10 @@ def __init__(self, title, config): self.policies_after_planning = [] # Not implemented, need to store them in every nodes of the mcts self.prior_values = [] - self.values_after_planning = [[numpy.NaN] * len(self.config.action_space)] + self.values_after_planning = [[numpy.NaN] * sum(self.config.action_shape)] self.prior_root_value = [] self.root_value_after_planning = [] - self.prior_rewards = [[numpy.NaN] * len(self.config.action_space)] + self.prior_rewards = [[numpy.NaN] * sum(self.config.action_shape)] self.mcts_depth = [] def store_info(self, root, mcts_info, action, reward, new_prior_root_value=None): @@ -222,25 +222,19 @@ def store_info(self, root, mcts_info, action, reward, new_prior_root_value=None) self.prior_policies.append( [ root.children[action].prior - if action in root.children.keys() - else numpy.NaN - for action in self.config.action_space + for action in root.children.keys() ] ) self.policies_after_planning.append( [ root.children[action].visit_count / self.config.num_simulations - if action in root.children.keys() - else numpy.NaN - for action in self.config.action_space + for action in root.children.keys() ] ) self.values_after_planning.append( [ root.children[action].value() - if action in root.children.keys() - else numpy.NaN - for action in self.config.action_space + for action in root.children.keys() ] ) self.prior_root_value.append( @@ -252,9 +246,7 @@ def store_info(self, root, mcts_info, action, reward, new_prior_root_value=None) self.prior_rewards.append( [ root.children[action].reward - if action in root.children.keys() - else numpy.NaN - for action in self.config.action_space + for action in root.children.keys() ] ) self.mcts_depth.append(mcts_info["max_tree_depth"]) diff --git a/game_history.py b/game_history.py index 52b02f7b..c06be7c4 100644 --- a/game_history.py +++ b/game_history.py @@ -12,30 +12,24 @@ def __init__(self): self.to_play_history = [] self.child_visits = [] self.root_values = [] + self.sampled_actions_history = [] self.reanalysed_predicted_root_values = None # For PER self.priorities = None self.game_priority = None - def store_search_statistics(self, 
root, action_space): + def store_search_statistics(self, root): # Turn visit count from root into a policy if root is not None: sum_visits = sum(child.visit_count for child in root.children.values()) - self.child_visits.append( - [ - root.children[a].visit_count / sum_visits - if a in root.children - else 0 - for a in action_space - ] - ) - + self.child_visits.append([root.children[a].visit_count / sum_visits for a in root.children.keys()]) + self.sampled_actions_history.append(root.sampled_actions) self.root_values.append(root.value()) else: self.root_values.append(None) def get_stacked_observations( - self, index, num_stacked_observations, action_space_size + self, index, num_stacked_observations ): """ Generate a new observation with the observation at the index position @@ -55,7 +49,7 @@ def get_stacked_observations( [ numpy.ones_like(stacked_observations[0]) * self.action_history[past_observation_index + 1] - / action_space_size + / len(self.sampled_actions_history[past_observation_index + 1]) ], ) ) diff --git a/games/breakout.py b/games/breakout.py index 8a078d90..e5ce8b93 100644 --- a/games/breakout.py +++ b/games/breakout.py @@ -4,6 +4,7 @@ import gym import numpy import torch +from torch.distributions import Categorical from .abstract_game import AbstractGame @@ -54,7 +55,7 @@ def __init__(self): ### Network - self.network = "resnet" # "resnet" / "fullyconnected" + self.network = "sampled" # "resnet" / "fullyconnected" / "sampled" self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. Choose it so that support_size <= sqrt(max(abs(discounted reward))) # Residual Network @@ -76,7 +77,10 @@ def __init__(self): self.fc_value_layers = [] # Define the hidden layers in the value network self.fc_policy_layers = [] # Define the hidden layers in the policy network - + # Sampled + self.sample_size = 4 + self.action_shape = [4] + self.policy_distribution = Categorical ### Training self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S") # Path to store the model weights and TensorBoard logs diff --git a/games/cartpole.py b/games/cartpole.py index fa1e8bbf..21c47569 100644 --- a/games/cartpole.py +++ b/games/cartpole.py @@ -49,9 +49,9 @@ def __init__(self): ### Network - self.network = "fullyconnected" # "resnet" / "fullyconnected" + self.network = "sampled" # "resnet" / "fullyconnected" self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. 
Choose it so that support_size <= sqrt(max(abs(discounted reward))) - + # Residual Network self.downsample = False # Downsample observations before representation network, False / "CNN" (lighter) / "resnet" (See paper appendix Network Architecture) self.blocks = 1 # Number of blocks in the ResNet @@ -66,18 +66,22 @@ def __init__(self): # Fully Connected Network self.encoding_size = 8 self.fc_representation_layers = [] # Define the hidden layers in the representation network - self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network - self.fc_reward_layers = [16] # Define the hidden layers in the reward network - self.fc_value_layers = [16] # Define the hidden layers in the value network - self.fc_policy_layers = [16] # Define the hidden layers in the policy network + self.fc_dynamics_layers = [32] # Define the hidden layers in the dynamics network + self.fc_reward_layers = [32] # Define the hidden layers in the reward network + self.fc_value_layers = [32] # Define the hidden layers in the value network + self.fc_policy_layers = [128, 128] # Define the hidden layers in the policy network + # Sampled + self.sample_size = 50 + self.action_shape = [2] + self.policy_distribution = torch.distributions.Categorical ### Training self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S") # Path to store the model weights and TensorBoard logs self.save_model = True # Save the checkpoint in results_path as model.checkpoint self.training_steps = 10000 # Total number of training steps (ie weights update according to a batch) - self.batch_size = 128 # Number of parts of games to train on at each training step + self.batch_size = 256 # Number of parts of games to train on at each training step self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing self.value_loss_weight = 1 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze) self.train_on_gpu = torch.cuda.is_available() # Train on GPU if available diff --git a/games/connect4.py b/games/connect4.py index a01e6551..e8fdfb31 100644 --- a/games/connect4.py +++ b/games/connect4.py @@ -3,6 +3,7 @@ import numpy import torch +from torch.distributions import Categorical from .abstract_game import AbstractGame @@ -48,7 +49,7 @@ def __init__(self): ### Network - self.network = "resnet" # "resnet" / "fullyconnected" + self.network = "sampled" # "resnet" / "fullyconnected" self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. 
Choose it so that support_size <= sqrt(max(abs(discounted reward))) # Residual Network @@ -70,14 +71,17 @@ def __init__(self): self.fc_value_layers = [] # Define the hidden layers in the value network self.fc_policy_layers = [] # Define the hidden layers in the policy network - + # Sampled + self.sample_size = 7 + self.action_shape = [7] + self.policy_distribution = Categorical ### Training self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S") # Path to store the model weights and TensorBoard logs self.save_model = True # Save the checkpoint in results_path as model.checkpoint self.training_steps = 100000 # Total number of training steps (ie weights update according to a batch) self.batch_size = 64 # Number of parts of games to train on at each training step - self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing + self.checkpoint_interval = 200 # Number of training steps before using the model for self-playing self.value_loss_weight = 0.25 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze) self.train_on_gpu = torch.cuda.is_available() # Train on GPU if available diff --git a/mcts.py b/mcts.py index d44ab573..83a3bf7e 100644 --- a/mcts.py +++ b/mcts.py @@ -45,12 +45,7 @@ def run( .unsqueeze(0) .to(next(model.parameters()).device) ) - ( - root_predicted_value, - reward, - policy_logits, - hidden_state, - ) = model.initial_inference(observation) + root_predicted_value, reward, policy_parameters, hidden_state = model.initial_inference(observation) root_predicted_value = support_to_scalar( root_predicted_value, self.config.support_size ).item() @@ -58,14 +53,11 @@ def run( assert ( legal_actions ), f"Legal actions should not be an empty array. Got {legal_actions}." - assert set(legal_actions).issubset( - set(self.config.action_space) - ), "Legal actions should be a subset of the action space." 
+ sampled_actions = model.sample_actions(policy_parameters) root.expand( - legal_actions, + sampled_actions, to_play, reward, - policy_logits, hidden_state, ) @@ -98,17 +90,17 @@ def run( # Inside the search tree we use the dynamics function to obtain the next hidden # state given an action and the previous hidden state parent = search_path[-2] - value, reward, policy_logits, hidden_state = model.recurrent_inference( + value, reward, policy_parameters, hidden_state = model.recurrent_inference( parent.hidden_state, torch.tensor([[action]]).to(parent.hidden_state.device), ) + sampled_actions = model.sample_actions(policy_parameters) value = support_to_scalar(value, self.config.support_size).item() reward = support_to_scalar(reward, self.config.support_size).item() node.expand( - self.config.action_space, + sampled_actions, virtual_to_play, reward, - policy_logits, hidden_state, ) @@ -130,13 +122,8 @@ def select_child(self, node, min_max_stats): self.ucb_score(node, child, min_max_stats) for action, child in node.children.items() ) - action = numpy.random.choice( - [ - action - for action, child in node.children.items() - if self.ucb_score(node, child, min_max_stats) == max_ucb - ] - ) + actions = [action for action, child in node.children.items() if self.ucb_score(node, child, min_max_stats) == max_ucb] + action = actions[numpy.random.choice(range(len(actions)))] return action, node.children[action] def ucb_score(self, parent, child, min_max_stats): diff --git a/models/muzero_network.py b/models/muzero_network.py index cc71afbb..452efde5 100644 --- a/models/muzero_network.py +++ b/models/muzero_network.py @@ -42,7 +42,7 @@ def __new__(cls, config): config.action_shape, config.encoding_size, config.sample_size, - config.blocks, + config.policy_distribution, config.fc_reward_layers, config.fc_value_layers, config.fc_policy_layers, diff --git a/models/muzero_sampled.py b/models/muzero_sampled.py index bd58bf66..12ce1618 100644 --- a/models/muzero_sampled.py +++ b/models/muzero_sampled.py @@ -1,6 +1,6 @@ -from stable_baselines3.common.distributions import MultiCategoricalDistribution -from torch import tanh, relu, log, zeros, tensor, softmax -from torch.nn import Module, Linear, LayerNorm, DataParallel, Tanh + +import torch +from torch.nn.functional import linear from models.abstract_network import AbstractNetwork from models.utils import mlp @@ -13,7 +13,7 @@ def __init__(self, action_shape, encoding_size, sample_size, - blocks, + policy_distribution, fc_reward_layers, fc_value_layers, fc_policy_layers, @@ -24,49 +24,26 @@ def __init__(self, self.action_shape = action_shape self.sample_size = sample_size self.full_support_size = 2 * support_size + 1 - self.policy_distribution = MultiCategoricalDistribution(torch.softmax(torch.rand(1, sum(action_shape)), dim=-1)) - self.representation_network = DataParallel( - RepresentationNetwork( - observation_shape[0] * observation_shape[1] * observation_shape[2] * \ - (stacked_observations + 1) + stacked_observations * observation_shape[1] * \ - observation_shape[2], - encoding_size, - fc_representation_layers, - blocks - ) - ) - self.dynamics_network = DataParallel( - DynamicsNetwork( - encoding_size, - len(action_shape), - fc_reward_layers, - fc_dynamics_layers, - blocks, - self.full_support_size - ) - ) - self.prediction_network = DataParallel( - PredictionNetwork( - encoding_size, - action_shape, - fc_value_layers, - fc_policy_layers, - self.full_support_size - ) - ) - - def initial_inference(self, observation, legal_actions=None, sampled_actions=None): 
+ self.policy_distribution = policy_distribution + self.representation_network = torch.nn.DataParallel( + RepresentationNetwork(encoding_size, observation_shape, stacked_observations, fc_representation_layers)) + self.dynamics_network = torch.nn.DataParallel( + DynamicsNetwork(encoding_size, action_shape, fc_dynamics_layers, fc_reward_layers, self.full_support_size)) + self.prediction_network = torch.nn.DataParallel( + PredictionNetwork(encoding_size, action_shape, fc_policy_layers, fc_value_layers, self.full_support_size)) + + def initial_inference(self, observation, legal_actions=None): encoded_state = self.representation(observation) - reward = log((zeros(1, self.full_support_size) - .scatter(1, tensor([[self.full_support_size // 2]]).long(), 1.0) + reward = torch.log((torch.zeros(1, self.full_support_size) + .scatter(1, torch.tensor([[self.full_support_size // 2]]).long(), 1.0) .repeat(len(observation), 1).to(observation.device))) - sampled_actions, policy_probabilities, value = self.prediction(encoded_state, sampled_actions) - return sampled_actions, value, reward, policy_probabilities, encoded_state + policy_parameters, value = self.prediction(encoded_state) + return value, reward, policy_parameters, encoded_state - def recurrent_inference(self, encoded_state, action, sampled_actions=None): + def recurrent_inference(self, encoded_state, action): next_encoded_state, reward = self.dynamics(encoded_state, action) - sampled_actions, policy_probabilities, value = self.prediction(next_encoded_state, sampled_actions) - return sampled_actions, value, reward, policy_probabilities, next_encoded_state + policy_parameters, value = self.prediction(next_encoded_state) + return value, reward, policy_parameters, next_encoded_state def representation(self, observation): encoded_state = self.representation_network(observation.view(observation.shape[0], -1)) @@ -79,7 +56,7 @@ def representation(self, observation): return encoded_state_normalized def dynamics(self, encoded_state, action): - next_encoded_state, reward = self.dynamics_network(encoded_state, action) + reward, next_encoded_state = self.dynamics_network(encoded_state, action.float()) # Scale encoded state between [0, 1] (See paper appendix Training) min_next_encoded_state = next_encoded_state.min(1, keepdim=True)[0] max_next_encoded_state = next_encoded_state.max(1, keepdim=True)[0] @@ -88,122 +65,67 @@ def dynamics(self, encoded_state, action): next_encoded_state_normalized = (next_encoded_state - min_next_encoded_state) / scale_next_encoded_state return next_encoded_state_normalized, reward - def prediction(self, encoded_state, sampled_actions): - if sampled_actions is None: - if encoded_state.shape[0] != 1: - raise Exception('State batch prediction requires sample batch') - else: - if encoded_state.shape[0] == 1: - raise Exception('State prediction requires no sample batch') - if encoded_state.shape[0] != len(sampled_actions): - raise Exception('State batch and sample batch must be the same length') - policy_params, value = self.prediction_network(encoded_state) - if sampled_actions is None: - distribution = self.policy_distribution.proba_distribution(policy_params) - sampled_actions = self.sample_actions(distribution) - probabilities = self.policy_probabilities(distribution, sampled_actions) - else: - distributions = [self.policy_distribution.proba_distribution(params) for params in - policy_params.unsqueeze(1)] - probabilities = [self.policy_probabilities(distribution, actions) - for distribution, actions in zip(distributions, 
sampled_actions)] - return sampled_actions, probabilities, value - - def sample_actions(self, distribution): - return torch.stack([distribution.sample().float() for _ in range(self.sample_size)]) - - @staticmethod - def policy_probabilities(distribution, sampled_actions): - return torch.stack( - [distribution.log_prob(tensor(sampled_action)).squeeze(-1) for sampled_action in sampled_actions]) - - -class RepresentationNetwork(Module): - def __init__(self, - observation_size, - encoding_size, - fc_representation_layers, - blocks): - super().__init__() - self.ll = Linear(observation_size, encoding_size) - self.ln = LayerNorm(encoding_size) - self.rbs = [PreActivationResidualBlock(encoding_size, - fc_representation_layers) for _ in range(blocks)] + def prediction(self, encoded_state): + policy_parameters, value = self.prediction_network(encoded_state) + return policy_parameters, value - def forward(self, observation): - state = self.ll(observation) - state = self.ln(state) - state = tanh(state) - for residual_block in self.rbs: - state = residual_block(state) - return state + def sample_actions(self, policy_parameters): + distribution = self.policy_distribution(policy_parameters) + return torch.stack([distribution.sample() for _ in range(self.sample_size)]).squeeze(-1) + def policy_probabilities(self, policy_parameters_batch, sampled_actions_batch): + policies = [] + for policy_parameters, sampled_actions in zip(policy_parameters_batch, sampled_actions_batch): + distribution = self.policy_distribution(policy_parameters) + policies.append(distribution.log_prob(sampled_actions)) + return policies -class DynamicsNetwork(Module): - def __init__(self, - encoding_size, - action_size, - fc_reward_layers, - fc_dynamics_layers, - blocks, - full_support_size): - super().__init__() - self.ll1 = Linear(action_size, encoding_size) - self.ln1 = LayerNorm(encoding_size) - self.ll2 = Linear(encoding_size, encoding_size) - self.ln2 = LayerNorm(encoding_size) - self.rbs = [PreActivationResidualBlock(encoding_size, fc_dynamics_layers) for _ in range(blocks)] - self.mlp = mlp(encoding_size, fc_reward_layers, full_support_size) - - def forward(self, state, action): - x = self.ll1(action) - x = self.ln1(x) - x = relu(x) - x = self.ll2(x + state) - x = self.ln2(x) - x = tanh(x) - for residual_block in self.rbs: - x = residual_block(x) - state = x - reward = self.mlp(x) - return state, reward - - -class PredictionNetwork(Module): - def __init__(self, - encoding_size, - action_shape, - fc_value_layers, - fc_policy_layers, - full_support_size): + +class RepresentationNetwork(torch.nn.Module): + def __init__(self, encoding_size, observation_shape, stacked_observations, fc_representation_layers): super().__init__() - self.value_mlp = mlp(encoding_size, fc_value_layers, full_support_size) - self.policy_mlp = mlp(encoding_size, fc_policy_layers, sum(action_shape), activation=Tanh) + self.representation_network = mlp( + observation_shape[0] + * observation_shape[1] + * observation_shape[2] + * (stacked_observations + 1) + + stacked_observations * observation_shape[1] * observation_shape[2], + fc_representation_layers, + encoding_size, + ) - def forward(self, state): - """ - :state: tensor of shape [[...]] or [[[...]]...] 
- """ - policy = self.policy_mlp(state) - policy = softmax(policy, dim=1) - value = self.value_mlp(state) - return policy, value + def forward(self, observation): + encoded_state = self.representation_network(observation) + return encoded_state -class PreActivationResidualBlock(Module): - def __init__(self, input_size, layers): +class DynamicsNetwork(torch.nn.Module): + def __init__(self, encoding_size, action_shape, fc_dynamics_layers, fc_reward_layers, full_support_size): + super().__init__() + self.action_block_linear = torch.nn.Linear(len(action_shape), len(action_shape)) + self.action_block_norm = torch.nn.LayerNorm(len(action_shape)) + self.action_block_relu = torch.nn.ReLU() + self.dynamics_encoded_state_network = mlp(encoding_size + len(action_shape), fc_dynamics_layers, encoding_size) + self.dynamics_reward_network = mlp(encoding_size, fc_reward_layers, full_support_size) + + def forward(self, encoded_state, action): + action = self.action_block_linear(action) + action = self.action_block_norm(action) + action = self.action_block_relu(action) + state_action = torch.cat((encoded_state, action), dim=1) + state = self.dynamics_encoded_state_network(state_action) + reward = self.dynamics_reward_network(state) + return reward, state + + +class PredictionNetwork(torch.nn.Module): + def __init__(self, encoding_size, action_shape, fc_policy_layers, fc_value_layers, full_support_size): super().__init__() - self.ln1 = LayerNorm(input_size) - self.mlp1 = mlp(input_size, layers, input_size) - self.ln2 = LayerNorm(input_size) - self.mlp2 = mlp(input_size, layers, input_size) - - def forward(self, x): - out = self.ln1(x) - out = relu(out) - out = self.mlp1(out) - out = self.ln2(out) - out = relu(out) - out = self.mlp2(out) - out += x - return out + self.prediction_policy_network = mlp(encoding_size, fc_policy_layers, sum(action_shape), activation=torch.nn.Tanh) + self.prediction_value_network = mlp(encoding_size, fc_value_layers, full_support_size) + + def forward(self, encoded_state): + policy_parameters = self.prediction_policy_network(encoded_state) + policy_parameters = torch.softmax(policy_parameters, dim=-1) + value = self.prediction_value_network(encoded_state) + return policy_parameters, value \ No newline at end of file diff --git a/node.py b/node.py index b3fc0c08..5365458e 100644 --- a/node.py +++ b/node.py @@ -8,6 +8,7 @@ def __init__(self, prior): self.prior = prior self.value_sum = 0 self.children = {} + self.sampled_actions = None self.hidden_state = None self.reward = 0 @@ -19,7 +20,7 @@ def value(self): return 0 return self.value_sum / self.visit_count - def expand(self, actions, to_play, reward, policy_logits, hidden_state): + def expand(self, sampled_actions, to_play, reward, hidden_state): """ We expand a node using the value, reward and policy prediction obtained from the neural network. 
@@ -28,12 +29,12 @@ def expand(self, actions, to_play, reward, policy_logits, hidden_state): self.reward = reward self.hidden_state = hidden_state - policy_values = torch.softmax( - torch.tensor([policy_logits[0][a] for a in actions]), dim=0 - ).tolist() - policy = {a: policy_values[i] for i, a in enumerate(actions)} - for action, p in policy.items(): - self.children[action] = Node(p) + uniques, counts = torch.unique(sampled_actions, dim=0, return_counts=True, sorted=True) + self.sampled_actions = uniques + empirical_probabilities = counts / counts.sum() + for action, p in zip(uniques, empirical_probabilities): + self.children[action.item()] = Node(p.item()) + def add_exploration_noise(self, dirichlet_alpha, exploration_fraction): """ diff --git a/replay_buffer.py b/replay_buffer.py index bfa1c1cb..1b042bb9 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -76,8 +76,9 @@ def get_batch(self): reward_batch, value_batch, policy_batch, + sampled_actions_batch, gradient_scale_batch, - ) = ([], [], [], [], [], [], []) + ) = ([], [], [], [], [], [], [], []) weight_batch = [] if self.config.PER else None for game_id, game_history, game_prob in self.sample_n_games( @@ -85,7 +86,7 @@ def get_batch(self): ): game_pos, pos_prob = self.sample_position(game_history) - values, rewards, policies, actions = self.make_target( + values, rewards, policies, sampled_actions, actions = self.make_target( game_history, game_pos ) @@ -93,14 +94,14 @@ def get_batch(self): observation_batch.append( game_history.get_stacked_observations( game_pos, - self.config.stacked_observations, - len(self.config.action_space), + self.config.stacked_observations ) ) action_batch.append(actions) value_batch.append(values) reward_batch.append(rewards) policy_batch.append(policies) + sampled_actions_batch.append(sampled_actions) gradient_scale_batch.append( [ min( @@ -133,6 +134,7 @@ def get_batch(self): value_batch, reward_batch, policy_batch, + sampled_actions_batch, weight_batch, gradient_scale_batch, ), @@ -266,7 +268,7 @@ def make_target(self, game_history, state_index): """ Generate targets for every unroll steps. 
""" - target_values, target_rewards, target_policies, actions = [], [], [], [] + target_values, target_rewards, target_policies, sampled_actions, actions = [], [], [], [], [] for current_index in range( state_index, state_index + self.config.num_unroll_steps + 1 ): @@ -276,6 +278,7 @@ def make_target(self, game_history, state_index): target_values.append(value) target_rewards.append(game_history.reward_history[current_index]) target_policies.append(game_history.child_visits[current_index]) + sampled_actions.append(game_history.sampled_actions_history[current_index]) actions.append(game_history.action_history[current_index]) elif current_index == len(game_history.root_values): target_values.append(0) @@ -287,6 +290,7 @@ def make_target(self, game_history, state_index): for _ in range(len(game_history.child_visits[0])) ] ) + sampled_actions.append(game_history.sampled_actions_history[0]) actions.append(game_history.action_history[current_index]) else: # States past the end of games are treated as absorbing states @@ -299,9 +303,10 @@ def make_target(self, game_history, state_index): for _ in range(len(game_history.child_visits[0])) ] ) - actions.append(numpy.random.choice(self.config.action_space)) - - return target_values, target_rewards, target_policies, actions + sampled_actions.append(game_history.sampled_actions_history[0]) + actions.append(game_history.sampled_actions_history[0][ + numpy.random.choice(len(game_history.sampled_actions_history[0]))]) + return target_values, target_rewards, target_policies, sampled_actions, actions @ray.remote @@ -347,8 +352,7 @@ def reanalyse(self, replay_buffer, shared_storage): [ game_history.get_stacked_observations( i, - self.config.stacked_observations, - len(self.config.action_space), + self.config.stacked_observations ) for i in range(len(game_history.root_values)) ] diff --git a/self_play.py b/self_play.py index 0120ae57..ba7add54 100644 --- a/self_play.py +++ b/self_play.py @@ -137,7 +137,7 @@ def play_game( numpy.array(observation).shape == self.config.observation_shape ), f"Observation should match the observation_shape defined in MuZeroConfig. Expected {self.config.observation_shape} but got {numpy.array(observation).shape}." stacked_observations = game_history.get_stacked_observations( - -1, self.config.stacked_observations, len(self.config.action_space) + -1, self.config.stacked_observations ) # Choose the action @@ -173,7 +173,7 @@ def play_game( print(f"Played action: {self.game.action_to_string(action)}") self.game.render() - game_history.store_search_statistics(root, self.config.action_space) + game_history.store_search_statistics(root) # Next batch game_history.action_history.append(action) @@ -210,10 +210,6 @@ def select_opponent_action(self, opponent, stacked_observations): assert ( self.game.legal_actions() ), f"Legal actions should not be an empty array. Got {self.game.legal_actions()}." - assert set(self.game.legal_actions()).issubset( - set(self.config.action_space) - ), "Legal actions should be a subset of the action space." 
- return numpy.random.choice(self.game.legal_actions()), None else: raise NotImplementedError( @@ -234,13 +230,13 @@ def select_action(node, temperature): if temperature == 0: action = actions[numpy.argmax(visit_counts)] elif temperature == float("inf"): - action = numpy.random.choice(actions) + action = actions[numpy.random.choice(range(len(actions)))] else: # See paper appendix Data Generation visit_count_distribution = visit_counts ** (1 / temperature) visit_count_distribution = visit_count_distribution / sum( visit_count_distribution ) - action = numpy.random.choice(actions, p=visit_count_distribution) + action = actions[numpy.random.choice(range(len(actions)),p=visit_count_distribution)] return action \ No newline at end of file diff --git a/trainer.py b/trainer.py index 1534b28f..97991b98 100644 --- a/trainer.py +++ b/trainer.py @@ -1,4 +1,5 @@ import copy +import math import time import numpy @@ -122,6 +123,13 @@ def continuous_update_weights(self, replay_buffer, shared_storage): ): time.sleep(0.5) + @staticmethod + def get_section(batch, i): + section = [] + for game in batch: + section.append(game[i]) + return section + def update_weights(self, batch): """ Perform one training step. @@ -133,6 +141,7 @@ def update_weights(self, batch): target_value, target_reward, target_policy, + sampled_action_batch, weight_batch, gradient_scale_batch, ) = batch @@ -150,13 +159,12 @@ def update_weights(self, batch): action_batch = torch.tensor(action_batch).long().to(device).unsqueeze(-1) target_value = torch.tensor(target_value).float().to(device) target_reward = torch.tensor(target_reward).float().to(device) - target_policy = torch.tensor(target_policy).float().to(device) gradient_scale_batch = torch.tensor(gradient_scale_batch).float().to(device) # observation_batch: batch, channels, height, width # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze) # target_value: batch, num_unroll_steps+1 # target_reward: batch, num_unroll_steps+1 - # target_policy: batch, num_unroll_steps+1, len(action_space) + # target_policy: batch, num_unroll_steps+1, k <= self.sample_size # gradient_scale_batch: batch, num_unroll_steps+1 target_value = scalar_to_support(target_value, self.config.support_size) @@ -167,30 +175,32 @@ def update_weights(self, batch): # target_reward: batch, num_unroll_steps+1, 2*support_size+1 ## Generate predictions - value, reward, policy_logits, hidden_state = self.model.initial_inference( + value, reward, policy_parameters, hidden_state = self.model.initial_inference( observation_batch ) - predictions = [(value, reward, policy_logits)] + policy = self.model.policy_probabilities(policy_parameters, self.get_section(sampled_action_batch, 0)) + predictions = [(value, reward, policy)] for i in range(1, action_batch.shape[1]): - value, reward, policy_logits, hidden_state = self.model.recurrent_inference( + value, reward, policy_parameters, hidden_state = self.model.recurrent_inference( hidden_state, action_batch[:, i] ) + policy = self.model.policy_probabilities(policy_parameters, self.get_section(sampled_action_batch, i)) # Scale the gradient at the start of the dynamics function (See paper appendix Training) hidden_state.register_hook(lambda grad: grad * 0.5) - predictions.append((value, reward, policy_logits)) + predictions.append((value, reward, policy)) # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim) ## Compute losses value_loss, reward_loss, policy_loss = (0, 0, 0) - value, reward, policy_logits = predictions[0] + 
value, reward, policy = predictions[0] # Ignore reward loss for the first batch step current_value_loss, _, current_policy_loss = self.loss_function( value.squeeze(-1), reward.squeeze(-1), - policy_logits, + policy, target_value[:, 0], target_reward[:, 0], - target_policy[:, 0], + self.get_section(target_policy, 0), ) value_loss += current_value_loss policy_loss += current_policy_loss @@ -208,7 +218,7 @@ def update_weights(self, batch): ) for i in range(1, len(predictions)): - value, reward, policy_logits = predictions[i] + value, reward, policy = predictions[i] ( current_value_loss, current_reward_loss, @@ -216,10 +226,10 @@ def update_weights(self, batch): ) = self.loss_function( value.squeeze(-1), reward.squeeze(-1), - policy_logits, + policy, target_value[:, i], target_reward[:, i], - target_policy[:, i], + self.get_section(target_policy, i), ) # Scale gradient by the number of unroll steps (See paper appendix Training) @@ -277,25 +287,24 @@ def update_lr(self): """ Update learning rate """ - lr = self.config.lr_init * self.config.lr_decay_rate ** ( - self.training_step / self.config.lr_decay_steps - ) + lr = self.config.lr_init * 0.5 * (1 + math.cos(math.pi * (self.training_step / self.config.lr_decay_steps))) for param_group in self.optimizer.param_groups: param_group["lr"] = lr @staticmethod def loss_function( - value, - reward, - policy_logits, - target_value, - target_reward, - target_policy, + values, + rewards, + policies, + target_values, + target_rewards, + target_policies, ): # Cross-entropy seems to have a better convergence than MSE - value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1) - reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1) - policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum( - 1 - ) + value_loss = (-target_values * torch.nn.LogSoftmax(dim=1)(values)).sum(1) + reward_loss = (-target_rewards * torch.nn.LogSoftmax(dim=1)(rewards)).sum(1) + policy_loss = [] + for target_policy, policy in zip(target_policies, policies): + policy_loss.append((-torch.tensor(target_policy).float().to(policy.device) * policy).sum(0)) + policy_loss = torch.stack(policy_loss) return value_loss, reward_loss, policy_loss
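
Note (not part of the patch): the core of the change above is that MCTS nodes are no longer expanded over the full action space; instead the network's policy head parameterizes a distribution, K actions are sampled from it, and the empirical frequencies of the unique samples become the children's priors (see mcts.py, node.py and MuZeroSampledNetwork.sample_actions). The following standalone sketch illustrates that scheme under the assumption of a plain Categorical policy; SampledNode and the 4-action/50-sample sizes are illustrative names and numbers, not the patch's own.

    import torch
    from torch.distributions import Categorical

    def sample_actions(policy_parameters, sample_size):
        # Draw K actions from the predicted policy; duplicates are expected and are
        # what give frequently sampled actions a larger empirical prior below.
        distribution = Categorical(policy_parameters)
        return torch.stack([distribution.sample() for _ in range(sample_size)])

    class SampledNode:
        def __init__(self, prior):
            self.prior = prior
            self.children = {}
            self.sampled_actions = None

        def expand(self, sampled_actions):
            # Deduplicate the K samples and use their empirical frequencies as priors,
            # replacing the softmax-over-logits prior used when expanding over the
            # full action space.
            uniques, counts = torch.unique(sampled_actions, return_counts=True, sorted=True)
            self.sampled_actions = uniques
            for action, p in zip(uniques, counts / counts.sum()):
                self.children[action.item()] = SampledNode(p.item())

    # Usage: a 4-way categorical policy head, 50 samples.
    policy_parameters = torch.softmax(torch.rand(4), dim=-1)
    root = SampledNode(0)
    root.expand(sample_actions(policy_parameters, sample_size=50))
    print(root.sampled_actions, [child.prior for child in root.children.values()])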
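
On the training side, each target policy is now a variable-length visit-count distribution over a root's unique sampled actions, so the policy loss in trainer.py is computed per batch element from the log-probabilities of those same actions rather than from fixed-size logits. A minimal sketch of that idea, assuming Categorical is the configured policy_distribution (as in the cartpole/connect4/breakout configs above); it mirrors, but is not, the patch's policy_probabilities and loss_function:

    import torch
    from torch.distributions import Categorical

    def policy_log_probs(policy_parameters_batch, sampled_actions_batch):
        # One distribution per batch element; log-probability of that element's own samples.
        return [Categorical(p).log_prob(a)
                for p, a in zip(policy_parameters_batch, sampled_actions_batch)]

    def sampled_policy_loss(target_policies, log_probs):
        # Cross-entropy against the MCTS visit-count targets; each batch element has
        # its own (variable-length) target vector, so the loss is built element-wise.
        losses = []
        for target, log_prob in zip(target_policies, log_probs):
            target = torch.tensor(target, dtype=torch.float32, device=log_prob.device)
            losses.append((-target * log_prob).sum(0))
        return torch.stack(losses)

    # Usage: batch of 2 roots with 3 and 2 unique sampled actions respectively.
    policy_parameters = torch.softmax(torch.rand(2, 4), dim=-1)
    sampled_actions = [torch.tensor([0, 1, 3]), torch.tensor([2, 3])]
    visit_targets = [[0.5, 0.3, 0.2], [0.6, 0.4]]
    loss = sampled_policy_loss(visit_targets,
                               policy_log_probs(policy_parameters, sampled_actions)).mean()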