Skip to content

Commit

Permalink
Add render_mode attribute to gym class, add optional splash screen, f…
Browse files Browse the repository at this point in the history
…ix make env for custom env_address
  • Loading branch information
alexpalms committed Oct 8, 2023
1 parent d17f791 commit b643cf1
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 18 deletions.
14 changes: 11 additions & 3 deletions diambra/arena/arena_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@ class DiambraGymBase(gym.Env):
_frame = None
reward_normalization_value = 1.0
render_gui_started = False
render_mode = None

def __init__(self, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]):
self.logger = logging.getLogger(__name__)
super(DiambraGymBase, self).__init__()

self.env_settings = env_settings
assert env_settings.render_mode is None or env_settings.render_mode in self.metadata["render_modes"]
self.render_mode = env_settings.render_mode

# Launch DIAMBRA Engine
self.arena_engine = DiambraEngine(env_settings.env_address, env_settings.grpc_timeout)

# Splash Screen
if 'DISPLAY' in os.environ and env_settings.splash_screen is True:
from .utils.splash_screen import SplashScreen
SplashScreen()

# Send environment settings, retrieve environment info
self.env_info = self.arena_engine.env_init(self.env_settings.get_pb_request(init=True))
self.env_settings.finalize_init(self.env_info)
Expand Down Expand Up @@ -120,7 +128,7 @@ def reset(self, seed: int = None, options: Dict[str, Any] = None):

# Rendering the environment
def render(self, wait_key=1):
if self.env_settings.render_mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ):
if self.render_mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ):
try:
if (self.render_gui_started is False):
self.window_name = "[{}] DIAMBRA Arena - {} - ({})".format(
Expand All @@ -134,7 +142,7 @@ def render(self, wait_key=1):
return True
except:
return False
elif self.env_settings.render_mode == "rgb_array":
elif self.render_mode == "rgb_array":
return self._frame

# Print observation details to the console
Expand Down Expand Up @@ -281,4 +289,4 @@ def _map_action_spaces_to_agents(self, values_dict):
for idx, action_space in enumerate(self.env_settings.action_space):
out_dict["agent_{}".format(idx)] = values_dict[action_space]

return out_dict
return out_dict
5 changes: 0 additions & 5 deletions diambra/arena/engine/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ def __init__(self, env_address, grpc_timeout=60):

self.logger.info("... done.")

# Splash Screen
if 'DISPLAY' in os.environ:
from ..utils.splash_screen import SplashScreen
SplashScreen()

# Send env settings, retrieve env info and int variables list [pb low level]
def env_init(self, env_settings_pb):
try:
Expand Down
2 changes: 2 additions & 0 deletions diambra/arena/env_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class EnvironmentSettingsBase:
disable_keyboard: bool = True
disable_joystick: bool = True
render_mode: Union[None, str] = None
splash_screen: bool = True
rank: int = 0
env_address: str = None
grpc_timeout: int = 600
Expand Down Expand Up @@ -184,6 +185,7 @@ def _sanity_check(self):
check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1])
if self.render_mode is not None:
check_val_in_list("render_mode", self.render_mode, ["human", "rgb_array"])
check_type("splash_screen", self.splash_screen, bool, admit_none=False)
check_num_in_range("rank", self.rank, [0, MAX_VAL])
check_type("env_address", self.env_address, str)
check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600])
Expand Down
18 changes: 8 additions & 10 deletions diambra/arena/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def make(game_id, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]=EnvironmentSettings(),
wrappers_settings: WrappersSettings=WrappersSettings(), episode_recording_settings: RecordingSettings=RecordingSettings(),
render_mode: str=None, rank: int=0, log_level=logging.INFO):
render_mode: str=None, rank: int=0, env_addresses=["localhost:50051"], log_level=logging.INFO):
"""
Create a wrapped environment.
:param seed: (int) the initial seed for RNG
Expand All @@ -24,20 +24,18 @@ def make(game_id, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMu
env_settings.render_mode = render_mode

# Check if DIAMBRA_ENVS var present
env_addresses = os.getenv("DIAMBRA_ENVS", "").split()
if len(env_addresses) >= 1: # If present
env_addresses_cli = os.getenv("DIAMBRA_ENVS", "").split()
if len(env_addresses_cli) >= 1: # If present
# Check if there are at least n env_addresses as the prescribed rank
if len(env_addresses) < rank + 1:
if len(env_addresses_cli) < rank + 1:
raise Exception("Rank of env client is higher than the available env_addresses servers:",
"# of env servers: {}".format(len(env_addresses)),
"# of env servers: {}".format(len(env_addresses_cli)),
"# rank of client: {} (0-based index)".format(rank))
env_addresses_list = env_addresses_cli
else: # If not present, set default value
if env_settings.env_address is None:
env_addresses = ["localhost:50051"]
else:
env_addresses = [env_settings.env_address]
env_addresses_list = env_addresses

env_settings.env_address = env_addresses[rank]
env_settings.env_address = env_addresses_list[rank]
env_settings.rank = rank

# Make environment
Expand Down

0 comments on commit b643cf1

Please sign in to comment.