diff --git a/diambra/arena/arena_gym.py b/diambra/arena/arena_gym.py index 0710b6e4..56db53c5 100644 --- a/diambra/arena/arena_gym.py +++ b/diambra/arena/arena_gym.py @@ -16,16 +16,24 @@ class DiambraGymBase(gym.Env): _frame = None reward_normalization_value = 1.0 render_gui_started = False + render_mode = None def __init__(self, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]): self.logger = logging.getLogger(__name__) super(DiambraGymBase, self).__init__() self.env_settings = env_settings + assert env_settings.render_mode is None or env_settings.render_mode in self.metadata["render_modes"] + self.render_mode = env_settings.render_mode # Launch DIAMBRA Engine self.arena_engine = DiambraEngine(env_settings.env_address, env_settings.grpc_timeout) + # Splash Screen + if 'DISPLAY' in os.environ and env_settings.splash_screen is True: + from .utils.splash_screen import SplashScreen + SplashScreen() + # Send environment settings, retrieve environment info self.env_info = self.arena_engine.env_init(self.env_settings.get_pb_request(init=True)) self.env_settings.finalize_init(self.env_info) @@ -120,7 +128,7 @@ def reset(self, seed: int = None, options: Dict[str, Any] = None): # Rendering the environment def render(self, wait_key=1): - if self.env_settings.render_mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): + if self.render_mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): try: if (self.render_gui_started is False): self.window_name = "[{}] DIAMBRA Arena - {} - ({})".format( @@ -134,7 +142,7 @@ def render(self, wait_key=1): return True except: return False - elif self.env_settings.render_mode == "rgb_array": + elif self.render_mode == "rgb_array": return self._frame # Print observation details to the console @@ -281,4 +289,4 @@ def _map_action_spaces_to_agents(self, values_dict): for idx, action_space in enumerate(self.env_settings.action_space): out_dict["agent_{}".format(idx)] = values_dict[action_space] - return out_dict \ No newline at end of file + return out_dict diff --git a/diambra/arena/engine/interface.py b/diambra/arena/engine/interface.py index f5045503..0ac5f28e 100644 --- a/diambra/arena/engine/interface.py +++ b/diambra/arena/engine/interface.py @@ -27,11 +27,6 @@ def __init__(self, env_address, grpc_timeout=60): self.logger.info("... done.") - # Splash Screen - if 'DISPLAY' in os.environ: - from ..utils.splash_screen import SplashScreen - SplashScreen() - # Send env settings, retrieve env info and int variables list [pb low level] def env_init(self, env_settings_pb): try: diff --git a/diambra/arena/env_settings.py b/diambra/arena/env_settings.py index 73903242..b824e48f 100644 --- a/diambra/arena/env_settings.py +++ b/diambra/arena/env_settings.py @@ -54,6 +54,7 @@ class EnvironmentSettingsBase: disable_keyboard: bool = True disable_joystick: bool = True render_mode: Union[None, str] = None + splash_screen: bool = True rank: int = 0 env_address: str = None grpc_timeout: int = 600 @@ -184,6 +185,7 @@ def _sanity_check(self): check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1]) if self.render_mode is not None: check_val_in_list("render_mode", self.render_mode, ["human", "rgb_array"]) + check_type("splash_screen", self.splash_screen, bool, admit_none=False) check_num_in_range("rank", self.rank, [0, MAX_VAL]) check_type("env_address", self.env_address, str) check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600]) diff --git a/diambra/arena/make_env.py b/diambra/arena/make_env.py index d0b0778e..bcaae96d 100644 --- a/diambra/arena/make_env.py +++ b/diambra/arena/make_env.py @@ -8,7 +8,7 @@ def make(game_id, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]=EnvironmentSettings(), wrappers_settings: WrappersSettings=WrappersSettings(), episode_recording_settings: RecordingSettings=RecordingSettings(), - render_mode: str=None, rank: int=0, log_level=logging.INFO): + render_mode: str=None, rank: int=0, env_addresses=["localhost:50051"], log_level=logging.INFO): """ Create a wrapped environment. :param seed: (int) the initial seed for RNG @@ -24,20 +24,18 @@ def make(game_id, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMu env_settings.render_mode = render_mode # Check if DIAMBRA_ENVS var present - env_addresses = os.getenv("DIAMBRA_ENVS", "").split() - if len(env_addresses) >= 1: # If present + env_addresses_cli = os.getenv("DIAMBRA_ENVS", "").split() + if len(env_addresses_cli) >= 1: # If present # Check if there are at least n env_addresses as the prescribed rank - if len(env_addresses) < rank + 1: + if len(env_addresses_cli) < rank + 1: raise Exception("Rank of env client is higher than the available env_addresses servers:", - "# of env servers: {}".format(len(env_addresses)), + "# of env servers: {}".format(len(env_addresses_cli)), "# rank of client: {} (0-based index)".format(rank)) + env_addresses_list = env_addresses_cli else: # If not present, set default value - if env_settings.env_address is None: - env_addresses = ["localhost:50051"] - else: - env_addresses = [env_settings.env_address] + env_addresses_list = env_addresses - env_settings.env_address = env_addresses[rank] + env_settings.env_address = env_addresses_list[rank] env_settings.rank = rank # Make environment