From 987e6629e18d79835a3fa728d2a5ea7806509f7e Mon Sep 17 00:00:00 2001 From: Alessandro Palmas Date: Sun, 8 Oct 2023 00:55:55 -0400 Subject: [PATCH 1/2] Add integration test speed and adapt CICD --- .github/workflows/test.yaml | 2 +- tests/test_speed.py | 40 +++++++++++++++++++++---------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 51f0845..a3a3e7a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,5 +16,5 @@ jobs: - run: pytest tests/test_wrappers_settings.py - run: pytest tests/test_recording_settings.py - run: pytest tests/test_examples.py - - run: pytest tests/test_speed.py + - run: pytest -k "test_speed_gym_mock or test_speed_wrappers_mock" tests/test_speed.py # Run only mocked tests - run: pytest -k "test_random_gym_mock or test_random_wrappers_mock" tests/test_random.py # Run only mocked tests diff --git a/tests/test_speed.py b/tests/test_speed.py index 0ac53db..60245fb 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -13,8 +13,9 @@ def reject_outliers(data): filtered = [e for e in data if (u - 2 * s < e < u + 2 * s)] return filtered -def func(n_players, wrappers_settings, target_speed, mocker): - load_mocker(mocker) +def func(game_id, n_players, wrappers_settings, use_mocker, mocker): + if use_mocker is True: + load_mocker(mocker) try: # Settings if (n_players == 1): @@ -22,7 +23,10 @@ def func(n_players, wrappers_settings, target_speed, mocker): else: settings = EnvironmentSettingsMultiAgent() - env = diambra.arena.make("doapp", settings, wrappers_settings) + settings.step_ratio = 1 + settings.frame_shape = (128, 128, 1) + + env = diambra.arena.make(game_id, settings, wrappers_settings) observation, info = env.reset() n_step = 0 @@ -45,15 +49,7 @@ def func(n_players, wrappers_settings, target_speed, mocker): fps_val2 = reject_outliers(fps_val) avg_fps = np.mean(fps_val2) - print("Average speed = {} FPS, STD {} FPS".format(avg_fps, np.std(fps_val2))) - - if abs(avg_fps - target_speed) > target_speed * 0.025: - # TODO: restore when using a stable platform for testing with consistent measurement - #if avg_fps < target_speed: - # raise RuntimeError("Fps lower than expected: {} VS {}".format(avg_fps, target_speed)) - #else: - # warnings.warn(UserWarning("Fps higher than expected: {} VS {}".format(avg_fps, target_speed))) - warnings.warn(UserWarning("Fps different than expected: {} VS {}".format(avg_fps, target_speed))) + warnings.warn(UserWarning("Average speed = {} FPS, STD {} FPS".format(avg_fps, np.std(fps_val2)))) return 0 except Exception as e: @@ -61,14 +57,18 @@ def func(n_players, wrappers_settings, target_speed, mocker): return 1 n_players = [1, 2] -target_speeds = [400, 300] +game_ids = ["doapp", "sfiii3n", "tektagt", "umk3", "samsh5sp", "kof98umh"] @pytest.mark.parametrize("n_players", n_players) -def test_speed_gym(n_players, mocker): - assert func(n_players, WrappersSettings(), target_speeds[0], mocker) == 0 +def test_speed_gym_mock(n_players, mocker): + use_mocker = True + game_id = "doapp" + assert func(game_id, n_players, WrappersSettings(), use_mocker, mocker) == 0 @pytest.mark.parametrize("n_players", n_players) -def test_speed_wrappers(n_players, mocker): +def test_speed_wrappers_mock(n_players, mocker): + use_mocker = True + game_id = "doapp" # Env wrappers settings wrappers_settings = WrappersSettings() @@ -90,4 +90,10 @@ def test_speed_wrappers(n_players, mocker): wrappers_settings.filter_keys = ["stage", "timer", suffix + "own_side", suffix + "opp_side", suffix + "opp_character", suffix + "action"] - assert func(n_players, wrappers_settings, target_speeds[1], mocker) == 0 + assert func(game_id, n_players, wrappers_settings, use_mocker, mocker) == 0 + +@pytest.mark.parametrize("game_id", game_ids) +@pytest.mark.parametrize("n_players", n_players) +def test_speed_gym_integration(game_id, n_players, mocker): + use_mocker = False + assert func(game_id, n_players, WrappersSettings(), use_mocker, mocker) == 0 \ No newline at end of file From 8247a1a39a0ea3bc286285ab6a91507d0dee6e83 Mon Sep 17 00:00:00 2001 From: Alessandro Palmas Date: Sun, 8 Oct 2023 00:56:35 -0400 Subject: [PATCH 2/2] Add render_mode attribute to gym class, add optional splash screen, fix make env for custom env_address --- diambra/arena/arena_gym.py | 14 +++++++++++--- diambra/arena/engine/interface.py | 5 ----- diambra/arena/env_settings.py | 2 ++ diambra/arena/make_env.py | 18 ++++++++---------- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/diambra/arena/arena_gym.py b/diambra/arena/arena_gym.py index 0710b6e..56db53c 100644 --- a/diambra/arena/arena_gym.py +++ b/diambra/arena/arena_gym.py @@ -16,16 +16,24 @@ class DiambraGymBase(gym.Env): _frame = None reward_normalization_value = 1.0 render_gui_started = False + render_mode = None def __init__(self, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]): self.logger = logging.getLogger(__name__) super(DiambraGymBase, self).__init__() self.env_settings = env_settings + assert env_settings.render_mode is None or env_settings.render_mode in self.metadata["render_modes"] + self.render_mode = env_settings.render_mode # Launch DIAMBRA Engine self.arena_engine = DiambraEngine(env_settings.env_address, env_settings.grpc_timeout) + # Splash Screen + if 'DISPLAY' in os.environ and env_settings.splash_screen is True: + from .utils.splash_screen import SplashScreen + SplashScreen() + # Send environment settings, retrieve environment info self.env_info = self.arena_engine.env_init(self.env_settings.get_pb_request(init=True)) self.env_settings.finalize_init(self.env_info) @@ -120,7 +128,7 @@ def reset(self, seed: int = None, options: Dict[str, Any] = None): # Rendering the environment def render(self, wait_key=1): - if self.env_settings.render_mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): + if self.render_mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): try: if (self.render_gui_started is False): self.window_name = "[{}] DIAMBRA Arena - {} - ({})".format( @@ -134,7 +142,7 @@ def render(self, wait_key=1): return True except: return False - elif self.env_settings.render_mode == "rgb_array": + elif self.render_mode == "rgb_array": return self._frame # Print observation details to the console @@ -281,4 +289,4 @@ def _map_action_spaces_to_agents(self, values_dict): for idx, action_space in enumerate(self.env_settings.action_space): out_dict["agent_{}".format(idx)] = values_dict[action_space] - return out_dict \ No newline at end of file + return out_dict diff --git a/diambra/arena/engine/interface.py b/diambra/arena/engine/interface.py index f504550..0ac5f28 100644 --- a/diambra/arena/engine/interface.py +++ b/diambra/arena/engine/interface.py @@ -27,11 +27,6 @@ def __init__(self, env_address, grpc_timeout=60): self.logger.info("... done.") - # Splash Screen - if 'DISPLAY' in os.environ: - from ..utils.splash_screen import SplashScreen - SplashScreen() - # Send env settings, retrieve env info and int variables list [pb low level] def env_init(self, env_settings_pb): try: diff --git a/diambra/arena/env_settings.py b/diambra/arena/env_settings.py index 7390324..b824e48 100644 --- a/diambra/arena/env_settings.py +++ b/diambra/arena/env_settings.py @@ -54,6 +54,7 @@ class EnvironmentSettingsBase: disable_keyboard: bool = True disable_joystick: bool = True render_mode: Union[None, str] = None + splash_screen: bool = True rank: int = 0 env_address: str = None grpc_timeout: int = 600 @@ -184,6 +185,7 @@ def _sanity_check(self): check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1]) if self.render_mode is not None: check_val_in_list("render_mode", self.render_mode, ["human", "rgb_array"]) + check_type("splash_screen", self.splash_screen, bool, admit_none=False) check_num_in_range("rank", self.rank, [0, MAX_VAL]) check_type("env_address", self.env_address, str) check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600]) diff --git a/diambra/arena/make_env.py b/diambra/arena/make_env.py index d0b0778..bcaae96 100644 --- a/diambra/arena/make_env.py +++ b/diambra/arena/make_env.py @@ -8,7 +8,7 @@ def make(game_id, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]=EnvironmentSettings(), wrappers_settings: WrappersSettings=WrappersSettings(), episode_recording_settings: RecordingSettings=RecordingSettings(), - render_mode: str=None, rank: int=0, log_level=logging.INFO): + render_mode: str=None, rank: int=0, env_addresses=["localhost:50051"], log_level=logging.INFO): """ Create a wrapped environment. :param seed: (int) the initial seed for RNG @@ -24,20 +24,18 @@ def make(game_id, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMu env_settings.render_mode = render_mode # Check if DIAMBRA_ENVS var present - env_addresses = os.getenv("DIAMBRA_ENVS", "").split() - if len(env_addresses) >= 1: # If present + env_addresses_cli = os.getenv("DIAMBRA_ENVS", "").split() + if len(env_addresses_cli) >= 1: # If present # Check if there are at least n env_addresses as the prescribed rank - if len(env_addresses) < rank + 1: + if len(env_addresses_cli) < rank + 1: raise Exception("Rank of env client is higher than the available env_addresses servers:", - "# of env servers: {}".format(len(env_addresses)), + "# of env servers: {}".format(len(env_addresses_cli)), "# rank of client: {} (0-based index)".format(rank)) + env_addresses_list = env_addresses_cli else: # If not present, set default value - if env_settings.env_address is None: - env_addresses = ["localhost:50051"] - else: - env_addresses = [env_settings.env_address] + env_addresses_list = env_addresses - env_settings.env_address = env_addresses[rank] + env_settings.env_address = env_addresses_list[rank] env_settings.rank = rank # Make environment