Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional splash #92

Merged
merged 2 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ jobs:
- run: pytest tests/test_wrappers_settings.py
- run: pytest tests/test_recording_settings.py
- run: pytest tests/test_examples.py
- run: pytest tests/test_speed.py
- run: pytest -k "test_speed_gym_mock or test_speed_wrappers_mock" tests/test_speed.py # Run only mocked tests
- run: pytest -k "test_random_gym_mock or test_random_wrappers_mock" tests/test_random.py # Run only mocked tests
14 changes: 11 additions & 3 deletions diambra/arena/arena_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@ class DiambraGymBase(gym.Env):
_frame = None
reward_normalization_value = 1.0
render_gui_started = False
render_mode = None

def __init__(self, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]):
self.logger = logging.getLogger(__name__)
super(DiambraGymBase, self).__init__()

self.env_settings = env_settings
assert env_settings.render_mode is None or env_settings.render_mode in self.metadata["render_modes"]
self.render_mode = env_settings.render_mode

# Launch DIAMBRA Engine
self.arena_engine = DiambraEngine(env_settings.env_address, env_settings.grpc_timeout)

# Splash Screen
if 'DISPLAY' in os.environ and env_settings.splash_screen is True:
from .utils.splash_screen import SplashScreen
SplashScreen()

# Send environment settings, retrieve environment info
self.env_info = self.arena_engine.env_init(self.env_settings.get_pb_request(init=True))
self.env_settings.finalize_init(self.env_info)
Expand Down Expand Up @@ -120,7 +128,7 @@ def reset(self, seed: int = None, options: Dict[str, Any] = None):

# Rendering the environment
def render(self, wait_key=1):
if self.env_settings.render_mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ):
if self.render_mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ):
try:
if (self.render_gui_started is False):
self.window_name = "[{}] DIAMBRA Arena - {} - ({})".format(
Expand All @@ -134,7 +142,7 @@ def render(self, wait_key=1):
return True
except:
return False
elif self.env_settings.render_mode == "rgb_array":
elif self.render_mode == "rgb_array":
return self._frame

# Print observation details to the console
Expand Down Expand Up @@ -281,4 +289,4 @@ def _map_action_spaces_to_agents(self, values_dict):
for idx, action_space in enumerate(self.env_settings.action_space):
out_dict["agent_{}".format(idx)] = values_dict[action_space]

return out_dict
return out_dict
5 changes: 0 additions & 5 deletions diambra/arena/engine/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ def __init__(self, env_address, grpc_timeout=60):

self.logger.info("... done.")

# Splash Screen
if 'DISPLAY' in os.environ:
from ..utils.splash_screen import SplashScreen
SplashScreen()

# Send env settings, retrieve env info and int variables list [pb low level]
def env_init(self, env_settings_pb):
try:
Expand Down
2 changes: 2 additions & 0 deletions diambra/arena/env_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class EnvironmentSettingsBase:
disable_keyboard: bool = True
disable_joystick: bool = True
render_mode: Union[None, str] = None
splash_screen: bool = True
rank: int = 0
env_address: str = None
grpc_timeout: int = 600
Expand Down Expand Up @@ -184,6 +185,7 @@ def _sanity_check(self):
check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1])
if self.render_mode is not None:
check_val_in_list("render_mode", self.render_mode, ["human", "rgb_array"])
check_type("splash_screen", self.splash_screen, bool, admit_none=False)
check_num_in_range("rank", self.rank, [0, MAX_VAL])
check_type("env_address", self.env_address, str)
check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600])
Expand Down
18 changes: 8 additions & 10 deletions diambra/arena/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def make(game_id, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]=EnvironmentSettings(),
wrappers_settings: WrappersSettings=WrappersSettings(), episode_recording_settings: RecordingSettings=RecordingSettings(),
render_mode: str=None, rank: int=0, log_level=logging.INFO):
render_mode: str=None, rank: int=0, env_addresses=["localhost:50051"], log_level=logging.INFO):
"""
Create a wrapped environment.
:param seed: (int) the initial seed for RNG
Expand All @@ -24,20 +24,18 @@ def make(game_id, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMu
env_settings.render_mode = render_mode

# Check if DIAMBRA_ENVS var present
env_addresses = os.getenv("DIAMBRA_ENVS", "").split()
if len(env_addresses) >= 1: # If present
env_addresses_cli = os.getenv("DIAMBRA_ENVS", "").split()
if len(env_addresses_cli) >= 1: # If present
# Check if there are at least n env_addresses as the prescribed rank
if len(env_addresses) < rank + 1:
if len(env_addresses_cli) < rank + 1:
raise Exception("Rank of env client is higher than the available env_addresses servers:",
"# of env servers: {}".format(len(env_addresses)),
"# of env servers: {}".format(len(env_addresses_cli)),
"# rank of client: {} (0-based index)".format(rank))
env_addresses_list = env_addresses_cli
else: # If not present, set default value
if env_settings.env_address is None:
env_addresses = ["localhost:50051"]
else:
env_addresses = [env_settings.env_address]
env_addresses_list = env_addresses

env_settings.env_address = env_addresses[rank]
env_settings.env_address = env_addresses_list[rank]
env_settings.rank = rank

# Make environment
Expand Down
40 changes: 23 additions & 17 deletions tests/test_speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@ def reject_outliers(data):
filtered = [e for e in data if (u - 2 * s < e < u + 2 * s)]
return filtered

def func(n_players, wrappers_settings, target_speed, mocker):
load_mocker(mocker)
def func(game_id, n_players, wrappers_settings, use_mocker, mocker):
if use_mocker is True:
load_mocker(mocker)
try:
# Settings
if (n_players == 1):
settings = EnvironmentSettings()
else:
settings = EnvironmentSettingsMultiAgent()

env = diambra.arena.make("doapp", settings, wrappers_settings)
settings.step_ratio = 1
settings.frame_shape = (128, 128, 1)

env = diambra.arena.make(game_id, settings, wrappers_settings)
observation, info = env.reset()

n_step = 0
Expand All @@ -45,30 +49,26 @@ def func(n_players, wrappers_settings, target_speed, mocker):

fps_val2 = reject_outliers(fps_val)
avg_fps = np.mean(fps_val2)
print("Average speed = {} FPS, STD {} FPS".format(avg_fps, np.std(fps_val2)))

if abs(avg_fps - target_speed) > target_speed * 0.025:
# TODO: restore when using a stable platform for testing with consistent measurement
#if avg_fps < target_speed:
# raise RuntimeError("Fps lower than expected: {} VS {}".format(avg_fps, target_speed))
#else:
# warnings.warn(UserWarning("Fps higher than expected: {} VS {}".format(avg_fps, target_speed)))
warnings.warn(UserWarning("Fps different than expected: {} VS {}".format(avg_fps, target_speed)))
warnings.warn(UserWarning("Average speed = {} FPS, STD {} FPS".format(avg_fps, np.std(fps_val2))))

return 0
except Exception as e:
print(e)
return 1

n_players = [1, 2]
target_speeds = [400, 300]
game_ids = ["doapp", "sfiii3n", "tektagt", "umk3", "samsh5sp", "kof98umh"]

@pytest.mark.parametrize("n_players", n_players)
def test_speed_gym(n_players, mocker):
assert func(n_players, WrappersSettings(), target_speeds[0], mocker) == 0
def test_speed_gym_mock(n_players, mocker):
use_mocker = True
game_id = "doapp"
assert func(game_id, n_players, WrappersSettings(), use_mocker, mocker) == 0

@pytest.mark.parametrize("n_players", n_players)
def test_speed_wrappers(n_players, mocker):
def test_speed_wrappers_mock(n_players, mocker):
use_mocker = True
game_id = "doapp"

# Env wrappers settings
wrappers_settings = WrappersSettings()
Expand All @@ -90,4 +90,10 @@ def test_speed_wrappers(n_players, mocker):
wrappers_settings.filter_keys = ["stage", "timer", suffix + "own_side", suffix + "opp_side",
suffix + "opp_character", suffix + "action"]

assert func(n_players, wrappers_settings, target_speeds[1], mocker) == 0
assert func(game_id, n_players, wrappers_settings, use_mocker, mocker) == 0

@pytest.mark.parametrize("game_id", game_ids)
@pytest.mark.parametrize("n_players", n_players)
def test_speed_gym_integration(game_id, n_players, mocker):
use_mocker = False
assert func(game_id, n_players, WrappersSettings(), use_mocker, mocker) == 0
Loading