Skip to content

Commit

Permalink
Minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpalms committed Nov 29, 2023
1 parent 519caec commit 463ea5c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
15 changes: 7 additions & 8 deletions diambra/arena/stable_baselines/make_sb_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti
"""
Create a wrapped, monitored VecEnv.
:param game_id: (str) the game environment ID
:param env_settings: (dict) parameters for DIAMBRA Arena environment
:param wrappers_settings: (dict) parameters for environment wrapping function
:param episode_recording_settings: (dict) parameters for environment recording wrapping function
:param env_settings: (EnvironmentSettings) parameters for DIAMBRA Arena environment
:param wrappers_settings: (WrappersSettings) parameters for environment wrapping function
:param episode_recording_settings: (RecordingSettings) parameters for environment recording wrapping function
:param start_index: (int) start rank index
:param allow_early_resets: (bool) allows early reset of the environment
:param start_method: (str) method used to start the subprocesses. See SubprocVecEnv doc for more information
Expand All @@ -42,18 +42,17 @@ def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSetti

def _make_sb_env(rank, seed):
# Seed management
if seed is None:
env_settings.seed = int(time.time()) + rank
else:
env_settings.seed = seed + rank
env_settings.seed = int(time.time()) if seed is None else seed
env_settings.seed += rank

def _init():
env = diambra.arena.make(game_id, env_settings, wrappers_settings,
episode_recording_settings, render_mode, rank=rank)

env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
allow_early_resets=allow_early_resets)
return env
set_global_seeds(env_settings.seed + rank)
set_global_seeds(env_settings.seed)
return _init

# If not wanting vectorized envs
Expand Down
9 changes: 4 additions & 5 deletions diambra/arena/stable_baselines3/make_sb3_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ def make_sb3_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSett

def _make_sb3_env(rank, seed):
# Seed management
if seed is None:
env_settings.seed = int(time.time()) + rank
else:
env_settings.seed = seed + rank
env_settings.seed = int(time.time()) if seed is None else seed
env_settings.seed += rank

def _init():
env = diambra.arena.make(game_id, env_settings, wrappers_settings,
episode_recording_settings, render_mode, rank=rank)
Expand All @@ -49,7 +48,7 @@ def _init():
os.makedirs(log_dir, exist_ok=True)
env = Monitor(env, log_dir, allow_early_resets=allow_early_resets)
return env
set_random_seed(env_settings.seed + rank)
set_random_seed(env_settings.seed)
return _init

# If not wanting vectorized envs
Expand Down

0 comments on commit 463ea5c

Please sign in to comment.