Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates on Reanalyse / Sample Efficiency (Re-executing MCTS, Parallelization, Stabilization with a target model, etc.) #142

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion games/atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False

self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.


### Adjust the self play / training ratio to avoid over/underfitting
Expand Down
4 changes: 3 additions & 1 deletion games/breakout.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = False # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False

self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.


### Adjust the self play / training ratio to avoid over/underfitting
Expand Down
4 changes: 3 additions & 1 deletion games/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False

self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.


### Adjust the self play / training ratio to avoid over/underfitting
Expand Down
204 changes: 204 additions & 0 deletions games/cartpole_sample_efficient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import datetime
import os

import gym
import numpy
import torch

from .abstract_game import AbstractGame


class MuZeroConfig:
def __init__(self):
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization

self.seed = 0 # Seed for numpy, torch and the game
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available



### Game
self.observation_shape = (1, 1, 4) # Dimensions of the game observation, must be 3D (channel, height, width). For a 1D array, please reshape it to (1, 1, length of array)
self.action_space = list(range(2)) # Fixed list of all possible actions. You should only edit the length
self.players = list(range(1)) # List of players. You should only edit the length
self.stacked_observations = 0 # Number of previous observations and previous actions to add to the current observation

# Evaluate
self.muzero_player = 0 # Turn Muzero begins to play (0: MuZero plays first, 1: MuZero plays second)
self.opponent = None # Hard coded agent that MuZero faces to assess his progress in multiplayer games. It doesn't influence training. None, "random" or "expert" if implemented in the Game class



### Self-Play
self.num_workers = 1 # Number of simultaneous threads/workers self-playing to feed the replay buffer
self.selfplay_on_gpu = False
self.max_moves = 500 # Maximum number of moves if game is not finished before
self.num_simulations = 50 # Number of future moves self-simulated
self.discount = 0.997 # Chronological discount of the reward
self.temperature_threshold = None # Number of moves before dropping the temperature given by visit_softmax_temperature_fn to 0 (ie selecting the best action). If None, visit_softmax_temperature_fn is used every time

# Root prior exploration noise
self.root_dirichlet_alpha = 0.25
self.root_exploration_fraction = 0.25

# UCB formula
self.pb_c_base = 19652
self.pb_c_init = 1.25



### Network
self.network = "fullyconnected" # "resnet" / "fullyconnected"
self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. Choose it so that support_size <= sqrt(max(abs(discounted reward)))

# Residual Network
self.downsample = False # Downsample observations before representation network, False / "CNN" (lighter) / "resnet" (See paper appendix Network Architecture)
self.blocks = 1 # Number of blocks in the ResNet
self.channels = 2 # Number of channels in the ResNet
self.reduced_channels_reward = 2 # Number of channels in reward head
self.reduced_channels_value = 2 # Number of channels in value head
self.reduced_channels_policy = 2 # Number of channels in policy head
self.resnet_fc_reward_layers = [] # Define the hidden layers in the reward head of the dynamic network
self.resnet_fc_value_layers = [] # Define the hidden layers in the value head of the prediction network
self.resnet_fc_policy_layers = [] # Define the hidden layers in the policy head of the prediction network

# Fully Connected Network
self.encoding_size = 8
self.fc_representation_layers = [] # Define the hidden layers in the representation network
self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network
self.fc_reward_layers = [16] # Define the hidden layers in the reward network
self.fc_value_layers = [16] # Define the hidden layers in the value network
self.fc_policy_layers = [16] # Define the hidden layers in the policy network



### Training
self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S")) # Path to store the model weights and TensorBoard logs
self.save_model = True # Save the checkpoint in results_path as model.checkpoint
self.training_steps = 10000 # Total number of training steps (ie weights update according to a batch)
self.batch_size = 128 # Number of parts of games to train on at each training step
self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
self.value_loss_weight = 1 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze)
self.train_on_gpu = torch.cuda.is_available() # Train on GPU if available

self.optimizer = "Adam" # "Adam" or "SGD". Paper uses SGD
self.weight_decay = 1e-4 # L2 weights regularization
self.momentum = 0.9 # Used only if optimizer is SGD

# Exponential learning rate schedule
self.lr_init = 0.02 # Initial learning rate
self.lr_decay_rate = 1 # Set it to 1 to use a constant learning rate
self.lr_decay_steps = 1000



### Replay Buffer
self.replay_buffer_size = 500 # Number of self-play games to keep in the replay buffer
self.num_unroll_steps = 5 # Number of game moves to keep for every batch element
self.td_steps = 5 # Number of steps in the future to take into account for calculating the target value
self.PER = True # Prioritized Replay (See paper appendix Training), select in priority the elements in the replay buffer which are unexpected for the network
self.PER_alpha = 1.0 # How much prioritization is used, 0 corresponding to the uniform case, paper suggests 1

# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False
self.num_reanalyse_workers = 1
self.value_target_update_freq = 10 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = True # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.


### Adjust the self play / training ratio to avoid over/underfitting
self.self_play_delay = 0 # Number of seconds to wait after each played game
self.training_delay = 0 # Number of seconds to wait after each training step
self.ratio = 20 # Desired training steps per self played step ratio. Equivalent to a synchronous version, training can take much longer. Set it to None to disable it


def visit_softmax_temperature_fn(self, trained_steps):
"""
Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
The smaller it is, the more likely the best action (ie with the highest visit count) is chosen.

Returns:
Positive float.
"""
if trained_steps < 0.5 * self.training_steps:
return 1.0
elif trained_steps < 0.75 * self.training_steps:
return 0.5
else:
return 0.25


class Game(AbstractGame):
"""
Game wrapper.
"""

def __init__(self, seed=None):
self.env = gym.make("CartPole-v1")
if seed is not None:
self.env.seed(seed)

def step(self, action):
"""
Apply action to the game.

Args:
action : action of the action_space to take.

Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
return numpy.array([[observation]]), reward, done

def legal_actions(self):
"""
Should return the legal actions at each turn, if it is not available, it can return
the whole action space. At each turn, the game have to be able to handle one of returned actions.

For complex game where calculating legal moves is too long, the idea is to define the legal actions
equal to the action space but to return a negative reward if the action is illegal.

Returns:
An array of integers, subset of the action space.
"""
return list(range(2))

def reset(self):
"""
Reset the game for a new game.

Returns:
Initial observation of the game.
"""
return numpy.array([[self.env.reset()]])

def close(self):
"""
Properly close the game.
"""
self.env.close()

def render(self):
"""
Display the game observation.
"""
self.env.render()
input("Press enter to take a step ")

def action_to_string(self, action_number):
"""
Convert an action number to a string representing the action.

Args:
action_number: an integer from the action space.

Returns:
String representing the action.
"""
actions = {
0: "Push cart to the left",
1: "Push cart to the right",
}
return f"{action_number}. {actions[action_number]}"
5 changes: 3 additions & 2 deletions games/connect4.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False


self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.

### Adjust the self play / training ratio to avoid over/underfitting
self.self_play_delay = 0 # Number of seconds to wait after each played game
Expand Down
5 changes: 3 additions & 2 deletions games/gomoku.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = False # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False


self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.

### Adjust the self play / training ratio to avoid over/underfitting
self.self_play_delay = 0 # Number of seconds to wait after each played game
Expand Down
4 changes: 3 additions & 1 deletion games/gridworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = False # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False

self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.


### Adjust the self play / training ratio to avoid over/underfitting
Expand Down
4 changes: 3 additions & 1 deletion games/lunarlander.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False

self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.


# Best known ratio for deterministic version: 0.8 --> 0.4 in 250 self played game (self_play_delay = 25 on GTX 1050Ti Max-Q).
Expand Down
4 changes: 3 additions & 1 deletion games/simple_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False

self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.


### Adjust the self play / training ratio to avoid over/underfitting
Expand Down
4 changes: 3 additions & 1 deletion games/tictactoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False

self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.


### Adjust the self play / training ratio to avoid over/underfitting
Expand Down
5 changes: 3 additions & 2 deletions games/twentyone.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def __init__(self):
# Reanalyze (See paper appendix Reanalyse)
self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
self.reanalyse_on_gpu = False


self.num_reanalyse_workers = 1
self.value_target_update_freq = 1 # Update frequency of the target model used to provide fresher value (and possibly policy) estimates
self.use_updated_mcts_value_targets = False # If True, root values targets are updated according to the re-execution of the MCTS (in this case, lagging parameters are used to run the MCTS to stabilize bootstrapping). Otherwise, a lagging value of the network (representation & value) is used to obtain the updated value targets.

### Adjust the self play / training ratio to avoid over/underfitting
self.self_play_delay = 0 # Number of seconds to wait after each played game
Expand Down
Loading