Skip to content

Commit

Permalink
Update sb3 make function
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpalms committed Sep 22, 2023
1 parent b1e24d2 commit 292e5b0
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions diambra/arena/stable_baselines3/make_sb3_env.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import os
import time
import diambra.arena
from diambra.arena import EnvironmentSettings, WrappersSettings, RecordingSettings

from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed

# Make Stable Baselines3 Env function
def make_sb3_env(game_id: str, env_settings: dict={}, wrappers_settings: dict={},
episode_recording_settings: dict={}, render_mode: str="rgb_array", seed: int=None,
start_index: int=0, allow_early_resets: bool=True, start_method: str=None,
no_vec: bool=False, use_subprocess: bool=True, log_dir_base: str="/tmp/DIAMBRALog/"):
def make_sb3_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSettings(),
wrappers_settings: WrappersSettings=WrappersSettings(),
episode_recording_settings: RecordingSettings=RecordingSettings(),
render_mode: str="rgb_array", seed: int=None, start_index: int=0,
allow_early_resets: bool=True, start_method: str=None, no_vec: bool=False,
use_subprocess: bool=True, log_dir_base: str="/tmp/DIAMBRALog/"):
"""
Create a wrapped, monitored VecEnv.
:param game_id: (str) the game environment ID
:param env_settings: (dict) parameters for DIAMBRA Arena environment
:param wrappers_settings: (dict) parameters for environment wrapping function
:param episode_recording_settings: (dict) parameters for environment recording wrapping function
:param env_settings: (EnvironmentSettings) parameters for DIAMBRA Arena environment
:param wrappers_settings: (WrappersSettings) parameters for environment wrapping function
:param episode_recording_settings: (RecordingSettings) parameters for environment recording wrapping function
:param start_index: (int) start rank index
:param allow_early_resets: (bool) allows early reset of the environment
:param start_method: (str) method used to start the subprocesses. See SubprocVecEnv doc for more information
Expand All @@ -34,7 +37,7 @@ def make_sb3_env(game_id: str, env_settings: dict={}, wrappers_settings: dict={}
# Seed management
if seed is None:
seed = int(time.time())
env_settings["seed"] = seed
env_settings.seed = seed

def _make_sb3_env(rank):
def _init():
Expand Down

0 comments on commit 292e5b0

Please sign in to comment.