Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix seed management for SB and SB3 #99

Merged
merged 2 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion diambra/arena/arena_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def get_cumulative_reward_bounds(self):
def reset(self, seed: int = None, options: Dict[str, Any] = None):
    """
    Reset the environment and start a new episode.

    :param seed: (int) base seed for the episode. A non-None seed is offset by
        ``self.env_settings.rank`` before being sent; None is passed through unchanged
    :param options: (dict) episode settings to update; the dictionary passed in
        is not modified. Any "seed" key it contains is overridden by ``seed``
    :return: (tuple) observation and info built from the engine reset response
    """
    # Work on a shallow copy so the caller's dictionary is never mutated.
    episode_options = {} if options is None else dict(options)

    # The seed entry is always written, even when seed is None.
    episode_options["seed"] = seed if seed is None else seed + self.env_settings.rank

    request = self.env_settings.update_episode_settings(episode_options)
    response = self.arena_engine.reset(request.episode_settings)
    return self._get_obs(response), self._get_info(response)
Expand Down
25 changes: 12 additions & 13 deletions diambra/arena/stable_baselines/make_sb_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti
"""
Create a wrapped, monitored VecEnv.
:param game_id: (str) the game environment ID
:param env_settings: (dict) parameters for DIAMBRA Arena environment
:param wrappers_settings: (dict) parameters for environment wrapping function
:param episode_recording_settings: (dict) parameters for environment recording wrapping function
:param env_settings: (EnvironmentSettings) parameters for DIAMBRA Arena environment
:param wrappers_settings: (WrappersSettings) parameters for environment wrapping function
:param episode_recording_settings: (RecordingSettings) parameters for environment recording wrapping function
:param start_index: (int) start rank index
:param allow_early_resets: (bool) allows early reset of the environment
:param start_method: (str) method used to start the subprocesses. See SubprocVecEnv doc for more information
Expand All @@ -36,35 +36,34 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti

num_envs = len(env_addresses)

# Seed management
if seed is None:
seed = int(time.time())
env_settings.seed = seed

# Add the conversion from gymnasium to gym
old_gym_wrapper = [OldGymWrapper, {}]
wrappers_settings.wrappers.insert(0, old_gym_wrapper)

def _make_sb_env(rank):
def _make_sb_env(rank, seed):
# Seed management
env_settings.seed = int(time.time()) if seed is None else seed
env_settings.seed += rank

def _init():
env = diambra.arena.make(game_id, env_settings, wrappers_settings,
episode_recording_settings, render_mode, rank=rank)

env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
allow_early_resets=allow_early_resets)
return env
set_global_seeds(seed)
set_global_seeds(env_settings.seed)
return _init

# If not wanting vectorized envs
if no_vec and num_envs == 1:
env = _make_sb_env(0)()
env = _make_sb_env(0, seed)()
else:
# When using one environment, no need to start subprocesses
if num_envs == 1 or not use_subprocess:
env = DummyVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)])
env = DummyVecEnv([_make_sb_env(i + start_index, seed) for i in range(num_envs)])
else:
env = SubprocVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)],
env = SubprocVecEnv([_make_sb_env(i + start_index, seed) for i in range(num_envs)],
start_method=start_method)

return env, num_envs
Expand Down
18 changes: 8 additions & 10 deletions diambra/arena/stable_baselines3/make_sb3_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,32 @@ def make_sb3_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSett

num_envs = len(env_addresses)

# Seed management
if seed is None:
seed = int(time.time())
env_settings.seed = seed
def _make_sb3_env(rank, seed):
# Seed management
env_settings.seed = int(time.time()) if seed is None else seed
env_settings.seed += rank

def _make_sb3_env(rank):
def _init():
env = diambra.arena.make(game_id, env_settings, wrappers_settings,
episode_recording_settings, render_mode, rank=rank)
env.reset(seed=seed + rank)

# Create log dir
log_dir = os.path.join(log_dir_base, str(rank))
os.makedirs(log_dir, exist_ok=True)
env = Monitor(env, log_dir, allow_early_resets=allow_early_resets)
return env
set_random_seed(seed)
set_random_seed(env_settings.seed)
return _init

# If not wanting vectorized envs
if no_vec and num_envs == 1:
env = _make_sb3_env(0)()
env = _make_sb3_env(0, seed)()
else:
# When using one environment, no need to start subprocesses
if num_envs == 1 or not use_subprocess:
env = DummyVecEnv([_make_sb3_env(i + start_index) for i in range(num_envs)])
env = DummyVecEnv([_make_sb3_env(i + start_index, seed) for i in range(num_envs)])
else:
env = SubprocVecEnv([_make_sb3_env(i + start_index) for i in range(num_envs)],
env = SubprocVecEnv([_make_sb3_env(i + start_index, seed) for i in range(num_envs)],
start_method=start_method)

return env, num_envs
Loading