FIXUP - Simplifications
JosephDenman committed Dec 18, 2022
1 parent 9966f72 commit c6ea136
Showing 12 changed files with 186 additions and 269 deletions.
26 changes: 9 additions & 17 deletions diagnose_model.py
@@ -54,18 +54,18 @@ def get_virtual_trajectory_from_obs(
                 virtual_to_play = self.config.players[0]
 
             # Generate new root
-            value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
+            value, reward, policy_parameters, hidden_state = self.model.recurrent_inference(
                 root.hidden_state,
                 torch.tensor([[action]]).to(root.hidden_state.device),
             )
             value = support_to_scalar(value, self.config.support_size).item()
             reward = support_to_scalar(reward, self.config.support_size).item()
             root = Node(0)
+            sampled_actions = self.model.sample_actions(policy_parameters)
             root.expand(
-                self.config.action_space,
+                sampled_actions,
                 virtual_to_play,
                 reward,
-                policy_logits,
                 hidden_state,
             )
 
@@ -208,10 +208,10 @@ def __init__(self, title, config):
         self.policies_after_planning = []
         # Not implemented, need to store them in every nodes of the mcts
         self.prior_values = []
-        self.values_after_planning = [[numpy.NaN] * len(self.config.action_space)]
+        self.values_after_planning = [[numpy.NaN] * sum(self.config.action_shape)]
         self.prior_root_value = []
         self.root_value_after_planning = []
-        self.prior_rewards = [[numpy.NaN] * len(self.config.action_space)]
+        self.prior_rewards = [[numpy.NaN] * sum(self.config.action_shape)]
         self.mcts_depth = []
 
@@ -222,25 +222,19 @@ def store_info(self, root, mcts_info, action, reward, new_prior_root_value=None):
         self.prior_policies.append(
             [
                 root.children[action].prior
-                if action in root.children.keys()
-                else numpy.NaN
-                for action in self.config.action_space
+                for action in root.children.keys()
             ]
         )
         self.policies_after_planning.append(
             [
                 root.children[action].visit_count / self.config.num_simulations
-                if action in root.children.keys()
-                else numpy.NaN
-                for action in self.config.action_space
+                for action in root.children.keys()
             ]
         )
         self.values_after_planning.append(
             [
                 root.children[action].value()
-                if action in root.children.keys()
-                else numpy.NaN
-                for action in self.config.action_space
+                for action in root.children.keys()
             ]
         )
         self.prior_root_value.append(
@@ -252,9 +246,7 @@ def store_info(self, root, mcts_info, action, reward, new_prior_root_value=None):
         self.prior_rewards.append(
             [
                 root.children[action].reward
-                if action in root.children.keys()
-                else numpy.NaN
-                for action in self.config.action_space
+                for action in root.children.keys()
             ]
         )
         self.mcts_depth.append(mcts_info["max_tree_depth"])
18 changes: 6 additions & 12 deletions game_history.py
@@ -12,30 +12,24 @@ def __init__(self):
         self.to_play_history = []
         self.child_visits = []
         self.root_values = []
+        self.sampled_actions_history = []
         self.reanalysed_predicted_root_values = None
         # For PER
         self.priorities = None
         self.game_priority = None
 
-    def store_search_statistics(self, root, action_space):
+    def store_search_statistics(self, root):
         # Turn visit count from root into a policy
         if root is not None:
             sum_visits = sum(child.visit_count for child in root.children.values())
-            self.child_visits.append(
-                [
-                    root.children[a].visit_count / sum_visits
-                    if a in root.children
-                    else 0
-                    for a in action_space
-                ]
-            )
-
+            self.child_visits.append([root.children[a].visit_count / sum_visits for a in root.children.keys()])
+            self.sampled_actions_history.append(root.sampled_actions)
             self.root_values.append(root.value())
         else:
             self.root_values.append(None)
 
     def get_stacked_observations(
-        self, index, num_stacked_observations, action_space_size
+        self, index, num_stacked_observations
     ):
         """
         Generate a new observation with the observation at the index position
@@ -55,7 +49,7 @@ def get_stacked_observations(
                         [
                             numpy.ones_like(stacked_observations[0])
                             * self.action_history[past_observation_index + 1]
-                            / action_space_size
+                            / len(self.sampled_actions_history[past_observation_index + 1])
                         ],
                     )
                 )
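
Note (illustrative, not part of the commit): a minimal sketch of what the reworked store_search_statistics now records, assuming the class in game_history.py is named GameHistory. The stand-in root only mimics the attributes the method reads (children with visit_count, sampled_actions, value()); the real Node class lives in the MCTS code.

from types import SimpleNamespace

from game_history import GameHistory  # assumed class name

root = SimpleNamespace(
    children={0: SimpleNamespace(visit_count=30), 1: SimpleNamespace(visit_count=10)},
    sampled_actions=[0, 1],  # children are keyed by the sampled actions only
    value=lambda: 0.5,
)

history = GameHistory()
history.store_search_statistics(root)
print(history.child_visits[-1])             # [0.75, 0.25], normalized over the sampled children
print(history.sampled_actions_history[-1])  # [0, 1]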
8 changes: 6 additions & 2 deletions games/breakout.py
@@ -4,6 +4,7 @@
 import gym
 import numpy
 import torch
+from torch.distributions import Categorical
 
 from .abstract_game import AbstractGame
 
@@ -54,7 +55,7 @@ def __init__(self):
 
 
         ### Network
-        self.network = "resnet" # "resnet" / "fullyconnected"
+        self.network = "sampled" # "resnet" / "fullyconnected" / "sampled"
         self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. Choose it so that support_size <= sqrt(max(abs(discounted reward)))
 
         # Residual Network
@@ -76,7 +77,10 @@ def __init__(self):
         self.fc_value_layers = [] # Define the hidden layers in the value network
         self.fc_policy_layers = [] # Define the hidden layers in the policy network
 
-
+        # Sampled
+        self.sample_size = 4
+        self.action_shape = [4]
+        self.policy_distribution = Categorical
 
         ### Training
         self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S") # Path to store the model weights and TensorBoard logs
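
Note (illustrative, not part of the commit): a rough sketch of how the new Sampled fields could be consumed, assuming the policy head emits sum(action_shape) parameters (as the diagnose_model.py change above suggests) and that policy_distribution wraps them into a torch distribution from which sample_size candidate actions are drawn for MCTS. The tensor below is a placeholder, not the repository's actual policy output.

import torch
from torch.distributions import Categorical

sample_size = 4                    # config.sample_size
action_shape = [4]                 # config.action_shape
policy_distribution = Categorical  # config.policy_distribution

policy_parameters = torch.zeros(1, sum(action_shape))  # placeholder logits from the policy head
dist = policy_distribution(logits=policy_parameters)   # distribution over Breakout's 4 actions
candidate_actions = dist.sample((sample_size,))         # sample_size candidate actions for MCTS
print(candidate_actions.shape)                          # torch.Size([4, 1])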
18 changes: 11 additions & 7 deletions games/cartpole.py
@@ -49,9 +49,9 @@ def __init__(self):
 
 
         ### Network
-        self.network = "fullyconnected" # "resnet" / "fullyconnected"
+        self.network = "sampled" # "resnet" / "fullyconnected"
         self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. Choose it so that support_size <= sqrt(max(abs(discounted reward)))
 
         # Residual Network
         self.downsample = False # Downsample observations before representation network, False / "CNN" (lighter) / "resnet" (See paper appendix Network Architecture)
         self.blocks = 1 # Number of blocks in the ResNet
@@ -66,18 +66,22 @@ def __init__(self):
         # Fully Connected Network
         self.encoding_size = 8
         self.fc_representation_layers = [] # Define the hidden layers in the representation network
-        self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network
-        self.fc_reward_layers = [16] # Define the hidden layers in the reward network
-        self.fc_value_layers = [16] # Define the hidden layers in the value network
-        self.fc_policy_layers = [16] # Define the hidden layers in the policy network
+        self.fc_dynamics_layers = [32] # Define the hidden layers in the dynamics network
+        self.fc_reward_layers = [32] # Define the hidden layers in the reward network
+        self.fc_value_layers = [32] # Define the hidden layers in the value network
+        self.fc_policy_layers = [128, 128] # Define the hidden layers in the policy network
 
 
-
+        # Sampled
+        self.sample_size = 50
+        self.action_shape = [2]
+        self.policy_distribution = torch.distributions.Categorical
+
         ### Training
         self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S") # Path to store the model weights and TensorBoard logs
         self.save_model = True # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 10000 # Total number of training steps (ie weights update according to a batch)
-        self.batch_size = 128 # Number of parts of games to train on at each training step
+        self.batch_size = 256 # Number of parts of games to train on at each training step
         self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
         self.value_loss_weight = 1 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze)
         self.train_on_gpu = torch.cuda.is_available() # Train on GPU if available
10 changes: 7 additions & 3 deletions games/connect4.py
@@ -3,6 +3,7 @@
 
 import numpy
 import torch
+from torch.distributions import Categorical
 
 from .abstract_game import AbstractGame
 
@@ -48,7 +49,7 @@ def __init__(self):
 
 
         ### Network
-        self.network = "resnet" # "resnet" / "fullyconnected"
+        self.network = "sampled" # "resnet" / "fullyconnected"
         self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. Choose it so that support_size <= sqrt(max(abs(discounted reward)))
 
         # Residual Network
@@ -70,14 +71,17 @@ def __init__(self):
         self.fc_value_layers = [] # Define the hidden layers in the value network
         self.fc_policy_layers = [] # Define the hidden layers in the policy network
 
-
+        # Sampled
+        self.sample_size = 7
+        self.action_shape = [7]
+        self.policy_distribution = Categorical
 
         ### Training
         self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S") # Path to store the model weights and TensorBoard logs
         self.save_model = True # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = 100000 # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 64 # Number of parts of games to train on at each training step
-        self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
+        self.checkpoint_interval = 200 # Number of training steps before using the model for self-playing
         self.value_loss_weight = 0.25 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze)
         self.train_on_gpu = torch.cuda.is_available() # Train on GPU if available
 
29 changes: 8 additions & 21 deletions mcts.py
@@ -45,27 +45,19 @@ def run(
                 .unsqueeze(0)
                 .to(next(model.parameters()).device)
             )
-            (
-                root_predicted_value,
-                reward,
-                policy_logits,
-                hidden_state,
-            ) = model.initial_inference(observation)
+            root_predicted_value, reward, policy_parameters, hidden_state = model.initial_inference(observation)
             root_predicted_value = support_to_scalar(
                 root_predicted_value, self.config.support_size
             ).item()
             reward = support_to_scalar(reward, self.config.support_size).item()
             assert (
                 legal_actions
             ), f"Legal actions should not be an empty array. Got {legal_actions}."
-            assert set(legal_actions).issubset(
-                set(self.config.action_space)
-            ), "Legal actions should be a subset of the action space."
+            sampled_actions = model.sample_actions(policy_parameters)
             root.expand(
-                legal_actions,
+                sampled_actions,
                 to_play,
                 reward,
-                policy_logits,
                 hidden_state,
             )
 
@@ -98,17 +90,17 @@ def run(
             # Inside the search tree we use the dynamics function to obtain the next hidden
             # state given an action and the previous hidden state
             parent = search_path[-2]
-            value, reward, policy_logits, hidden_state = model.recurrent_inference(
+            value, reward, policy_parameters, hidden_state = model.recurrent_inference(
                 parent.hidden_state,
                 torch.tensor([[action]]).to(parent.hidden_state.device),
             )
+            sampled_actions = model.sample_actions(policy_parameters)
             value = support_to_scalar(value, self.config.support_size).item()
             reward = support_to_scalar(reward, self.config.support_size).item()
             node.expand(
-                self.config.action_space,
+                sampled_actions,
                 virtual_to_play,
                 reward,
-                policy_logits,
                 hidden_state,
             )
 
@@ -130,13 +122,8 @@ def select_child(self, node, min_max_stats):
             self.ucb_score(node, child, min_max_stats)
             for action, child in node.children.items()
         )
-        action = numpy.random.choice(
-            [
-                action
-                for action, child in node.children.items()
-                if self.ucb_score(node, child, min_max_stats) == max_ucb
-            ]
-        )
+        actions = [action for action, child in node.children.items() if self.ucb_score(node, child, min_max_stats) == max_ucb]
+        action = actions[numpy.random.choice(range(len(actions)))]
         return action, node.children[action]
 
     def ucb_score(self, parent, child, min_max_stats):
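
Note (illustrative, not part of the commit): the select_child rewrite keeps the same behavior, a uniform pick among the children whose UCB score equals the maximum, but it indexes into a Python list instead of handing the actions themselves to numpy.random.choice, which requires a 1-D array and therefore rejects non-scalar actions such as tuples. A self-contained illustration with made-up UCB values:

import numpy

children = {(0,): object(), (1,): object(), (2,): object()}  # sampled actions might be tuples
ucb = {(0,): 1.0, (1,): 0.7, (2,): 1.0}

max_ucb = max(ucb[a] for a in children)
actions = [a for a in children if ucb[a] == max_ucb]
# numpy.random.choice(actions) would raise "a must be 1-dimensional" for tuple actions,
# so draw an index instead and pick uniformly among the tied actions.
action = actions[numpy.random.choice(range(len(actions)))]
print(action)  # (0,) or (2,)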
2 changes: 1 addition & 1 deletion models/muzero_network.py
@@ -42,7 +42,7 @@ def __new__(cls, config):
                 config.action_shape,
                 config.encoding_size,
                 config.sample_size,
-                config.blocks,
+                config.policy_distribution,
                 config.fc_reward_layers,
                 config.fc_value_layers,
                 config.fc_policy_layers,
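
Note (illustrative, not part of the commit): a hedged end-to-end sketch of the call pattern these hunks imply. The class names MuZeroConfig and MuZeroNetwork, the "sampled" dispatch inside __new__, and the placeholder observation shape are assumptions; only the four-tuple returned by initial_inference and the sample_actions call are taken from the diffs above.

import torch

from games.cartpole import MuZeroConfig          # assumed class name
from models.muzero_network import MuZeroNetwork  # assumed class name

config = MuZeroConfig()        # config.network == "sampled" after this commit
model = MuZeroNetwork(config)  # presumably dispatches to the sampled network variant

observation = torch.zeros(1, 1, 1, 4)  # placeholder CartPole observation batch
value, reward, policy_parameters, hidden_state = model.initial_inference(observation)
sampled_actions = model.sample_actions(policy_parameters)  # candidates passed to Node.expand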