Skip to content

Commit

Permalink
Fix seed management for SB and SB3
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpalms committed Nov 28, 2023
1 parent 9d041df commit 519caec
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 22 deletions.
2 changes: 1 addition & 1 deletion diambra/arena/arena_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def get_cumulative_reward_bounds(self):
def reset(self, seed: int = None, options: Dict[str, Any] = None):
if options is None:
options = {}
options["seed"] = seed
options["seed"] = seed if seed is None else seed + self.env_settings.rank
request = self.env_settings.update_episode_settings(options)
response = self.arena_engine.reset(request.episode_settings)
return self._get_obs(response), self._get_info(response)
Expand Down
20 changes: 10 additions & 10 deletions diambra/arena/stable_baselines/make_sb_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,35 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti

num_envs = len(env_addresses)

# Seed management
if seed is None:
seed = int(time.time())
env_settings.seed = seed

# Add the conversion from gymnasium to gym
old_gym_wrapper = [OldGymWrapper, {}]
wrappers_settings.wrappers.insert(0, old_gym_wrapper)

def _make_sb_env(rank):
def _make_sb_env(rank, seed):
# Seed management
if seed is None:
env_settings.seed = int(time.time()) + rank
else:
env_settings.seed = seed + rank
def _init():
env = diambra.arena.make(game_id, env_settings, wrappers_settings,
episode_recording_settings, render_mode, rank=rank)

env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
allow_early_resets=allow_early_resets)
return env
set_global_seeds(seed)
set_global_seeds(env_settings.seed + rank)
return _init

# If not wanting vectorized envs
if no_vec and num_envs == 1:
env = _make_sb_env(0)()
env = _make_sb_env(0, seed)()
else:
# When using one environment, no need to start subprocesses
if num_envs == 1 or not use_subprocess:
env = DummyVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)])
env = DummyVecEnv([_make_sb_env(i + start_index, seed) for i in range(num_envs)])
else:
env = SubprocVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)],
env = SubprocVecEnv([_make_sb_env(i + start_index, seed) for i in range(num_envs)],
start_method=start_method)

return env, num_envs
Expand Down
21 changes: 10 additions & 11 deletions diambra/arena/stable_baselines3/make_sb3_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,33 @@ def make_sb3_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSett

num_envs = len(env_addresses)

# Seed management
if seed is None:
seed = int(time.time())
env_settings.seed = seed

def _make_sb3_env(rank):
def _make_sb3_env(rank, seed):
# Seed management
if seed is None:
env_settings.seed = int(time.time()) + rank
else:
env_settings.seed = seed + rank
def _init():
env = diambra.arena.make(game_id, env_settings, wrappers_settings,
episode_recording_settings, render_mode, rank=rank)
env.reset(seed=seed + rank)

# Create log dir
log_dir = os.path.join(log_dir_base, str(rank))
os.makedirs(log_dir, exist_ok=True)
env = Monitor(env, log_dir, allow_early_resets=allow_early_resets)
return env
set_random_seed(seed)
set_random_seed(env_settings.seed + rank)
return _init

# If not wanting vectorized envs
if no_vec and num_envs == 1:
env = _make_sb3_env(0)()
env = _make_sb3_env(0, seed)()
else:
# When using one environment, no need to start subprocesses
if num_envs == 1 or not use_subprocess:
env = DummyVecEnv([_make_sb3_env(i + start_index) for i in range(num_envs)])
env = DummyVecEnv([_make_sb3_env(i + start_index, seed) for i in range(num_envs)])
else:
env = SubprocVecEnv([_make_sb3_env(i + start_index) for i in range(num_envs)],
env = SubprocVecEnv([_make_sb3_env(i + start_index, seed) for i in range(num_envs)],
start_method=start_method)

return env, num_envs

0 comments on commit 519caec

Please sign in to comment.