From 519caec6803e2943665e9a1f5fc188fdcecd3e8c Mon Sep 17 00:00:00 2001 From: Alessandro Palmas Date: Tue, 28 Nov 2023 01:15:29 -0500 Subject: [PATCH 1/2] Fix seed management for SB and SB3 --- diambra/arena/arena_gym.py | 2 +- diambra/arena/stable_baselines/make_sb_env.py | 20 +++++++++--------- .../arena/stable_baselines3/make_sb3_env.py | 21 +++++++++---------- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/diambra/arena/arena_gym.py b/diambra/arena/arena_gym.py index 56db53c5..58473e66 100644 --- a/diambra/arena/arena_gym.py +++ b/diambra/arena/arena_gym.py @@ -121,7 +121,7 @@ def get_cumulative_reward_bounds(self): def reset(self, seed: int = None, options: Dict[str, Any] = None): if options is None: options = {} - options["seed"] = seed + options["seed"] = seed if seed is None else seed + self.env_settings.rank request = self.env_settings.update_episode_settings(options) response = self.arena_engine.reset(request.episode_settings) return self._get_obs(response), self._get_info(response) diff --git a/diambra/arena/stable_baselines/make_sb_env.py b/diambra/arena/stable_baselines/make_sb_env.py index 9b26b847..294a95aa 100644 --- a/diambra/arena/stable_baselines/make_sb_env.py +++ b/diambra/arena/stable_baselines/make_sb_env.py @@ -36,16 +36,16 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti num_envs = len(env_addresses) - # Seed management - if seed is None: - seed = int(time.time()) - env_settings.seed = seed - # Add the conversion from gymnasium to gym old_gym_wrapper = [OldGymWrapper, {}] wrappers_settings.wrappers.insert(0, old_gym_wrapper) - def _make_sb_env(rank): + def _make_sb_env(rank, seed): + # Seed management + if seed is None: + env_settings.seed = int(time.time()) + rank + else: + env_settings.seed = seed + rank def _init(): env = diambra.arena.make(game_id, env_settings, wrappers_settings, episode_recording_settings, render_mode, rank=rank) @@ -53,18 +53,18 @@ def _init(): env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), allow_early_resets=allow_early_resets) return env - set_global_seeds(seed) + set_global_seeds(env_settings.seed + rank) return _init # If not wanting vectorized envs if no_vec and num_envs == 1: - env = _make_sb_env(0)() + env = _make_sb_env(0, seed)() else: # When using one environment, no need to start subprocesses if num_envs == 1 or not use_subprocess: - env = DummyVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)]) + env = DummyVecEnv([_make_sb_env(i + start_index, seed) for i in range(num_envs)]) else: - env = SubprocVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)], + env = SubprocVecEnv([_make_sb_env(i + start_index, seed) for i in range(num_envs)], start_method=start_method) return env, num_envs diff --git a/diambra/arena/stable_baselines3/make_sb3_env.py b/diambra/arena/stable_baselines3/make_sb3_env.py index b554cd40..5f010ac7 100644 --- a/diambra/arena/stable_baselines3/make_sb3_env.py +++ b/diambra/arena/stable_baselines3/make_sb3_env.py @@ -34,34 +34,33 @@ def make_sb3_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSett num_envs = len(env_addresses) - # Seed management - if seed is None: - seed = int(time.time()) - env_settings.seed = seed - - def _make_sb3_env(rank): + def _make_sb3_env(rank, seed): + # Seed management + if seed is None: + env_settings.seed = int(time.time()) + rank + else: + env_settings.seed = seed + rank def _init(): env = diambra.arena.make(game_id, env_settings, wrappers_settings, episode_recording_settings, render_mode, rank=rank) - env.reset(seed=seed + rank) # Create log dir log_dir = os.path.join(log_dir_base, str(rank)) os.makedirs(log_dir, exist_ok=True) env = Monitor(env, log_dir, allow_early_resets=allow_early_resets) return env - set_random_seed(seed) + set_random_seed(env_settings.seed + rank) return _init # If not wanting vectorized envs if no_vec and num_envs == 1: - env = _make_sb3_env(0)() + env = _make_sb3_env(0, seed)() else: # When using one environment, no need to start subprocesses if num_envs == 1 or not use_subprocess: - env = DummyVecEnv([_make_sb3_env(i + start_index) for i in range(num_envs)]) + env = DummyVecEnv([_make_sb3_env(i + start_index, seed) for i in range(num_envs)]) else: - env = SubprocVecEnv([_make_sb3_env(i + start_index) for i in range(num_envs)], + env = SubprocVecEnv([_make_sb3_env(i + start_index, seed) for i in range(num_envs)], start_method=start_method) return env, num_envs From 463ea5c51831f4b1de94343e8eaaed76bfd92ef3 Mon Sep 17 00:00:00 2001 From: Alessandro Palmas Date: Wed, 29 Nov 2023 11:20:10 -0500 Subject: [PATCH 2/2] Minor refactor --- diambra/arena/stable_baselines/make_sb_env.py | 15 +++++++-------- diambra/arena/stable_baselines3/make_sb3_env.py | 9 ++++----- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/diambra/arena/stable_baselines/make_sb_env.py b/diambra/arena/stable_baselines/make_sb_env.py index 294a95aa..d66b1ea3 100644 --- a/diambra/arena/stable_baselines/make_sb_env.py +++ b/diambra/arena/stable_baselines/make_sb_env.py @@ -19,9 +19,9 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti """ Create a wrapped, monitored VecEnv. :param game_id: (str) the game environment ID - :param env_settings: (dict) parameters for DIAMBRA Arena environment - :param wrappers_settings: (dict) parameters for environment wrapping function - :param episode_recording_settings: (dict) parameters for environment recording wrapping function + :param env_settings: (EnvironmentSettings) parameters for DIAMBRA Arena environment + :param wrappers_settings: (WrappersSettings) parameters for environment wrapping function + :param episode_recording_settings: (RecordingSettings) parameters for environment recording wrapping function :param start_index: (int) start rank index :param allow_early_resets: (bool) allows early reset of the environment :param start_method: (str) method used to start the subprocesses. See SubprocVecEnv doc for more information @@ -42,10 +42,9 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti def _make_sb_env(rank, seed): # Seed management - if seed is None: - env_settings.seed = int(time.time()) + rank - else: - env_settings.seed = seed + rank + env_settings.seed = int(time.time()) if seed is None else seed + env_settings.seed += rank + def _init(): env = diambra.arena.make(game_id, env_settings, wrappers_settings, episode_recording_settings, render_mode, rank=rank) @@ -53,7 +52,7 @@ def _init(): env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), allow_early_resets=allow_early_resets) return env - set_global_seeds(env_settings.seed + rank) + set_global_seeds(env_settings.seed) return _init # If not wanting vectorized envs diff --git a/diambra/arena/stable_baselines3/make_sb3_env.py b/diambra/arena/stable_baselines3/make_sb3_env.py index 5f010ac7..3c94a89b 100644 --- a/diambra/arena/stable_baselines3/make_sb3_env.py +++ b/diambra/arena/stable_baselines3/make_sb3_env.py @@ -36,10 +36,9 @@ def make_sb3_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSett def _make_sb3_env(rank, seed): # Seed management - if seed is None: - env_settings.seed = int(time.time()) + rank - else: - env_settings.seed = seed + rank + env_settings.seed = int(time.time()) if seed is None else seed + env_settings.seed += rank + def _init(): env = diambra.arena.make(game_id, env_settings, wrappers_settings, episode_recording_settings, render_mode, rank=rank) @@ -49,7 +48,7 @@ def _init(): os.makedirs(log_dir, exist_ok=True) env = Monitor(env, log_dir, allow_early_resets=allow_early_resets) return env - set_random_seed(env_settings.seed + rank) + set_random_seed(env_settings.seed) return _init # If not wanting vectorized envs