diff --git a/.github/workflows/docs-rebuild.yaml b/.github/workflows/docs-rebuild.yaml
index 826d7265..21c9a27d 100644
--- a/.github/workflows/docs-rebuild.yaml
+++ b/.github/workflows/docs-rebuild.yaml
@@ -4,6 +4,7 @@ on:
   push:
     branches:
       - main
+      - release-2.1
 
 jobs:
   docs-rebuild-deploy:
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index a71e673f..590c3776 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -9,10 +9,11 @@ jobs:
       - uses: actions/setup-python@v4
         with:
          python-version: "3.x"
-      - run: python3 -m pip install wheel==0.38.4
+      - run: python3 -m pip install wheel
       - run: python3 -m pip install .["tests"]
       - run: pytest tests/test_gym_settings.py
       - run: pytest tests/test_wrappers_settings.py
       - run: pytest tests/test_recording_settings.py
-      - run: pytest tests/test_random.py
+      - run: pytest tests/test_examples.py
       - run: pytest tests/test_speed.py
+      - run: pytest -k "test_random_gym_mock or test_random_wrappers_mock" tests/test_random.py # Run only mocked tests
diff --git a/.github/workflows/test_agents.yaml b/.github/workflows/test_agents.yaml
index e3abcf93..0fb84e1d 100644
--- a/.github/workflows/test_agents.yaml
+++ b/.github/workflows/test_agents.yaml
@@ -3,7 +3,7 @@ on:
 
 jobs:
   test:
-    uses: diambra/agents/.github/workflows/reusable-test.yaml@main
+    uses: diambra/agents/.github/workflows/reusable_unit_tests.yaml@main
     with:
       arena_requirement_specifier: 'git+https://github.com/diambra/arena.git@${{ github.ref }}#egg='
       agents_ref: 'main'
diff --git a/.vscode/launch.json b/.vscode/launch.json
index f53bbeaf..0659802d 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -10,17 +10,183 @@
                 "diambra_arena_gist.py",
                 "single_player_env.py",
                 "multi_player_env.py",
-                "wrappers_options.py"
+                "wrappers_options.py",
+                "episode_recording.py",
+                "episode_data_loader.py",
             ],
         },
+        {
+            "id": "pytest_script",
+            "type": "pickString",
+            "description": "Test script:",
+            "default": "test_gym_settings.py",
+            "options": [
+                "test_gym_settings.py",
+                "test_imitation_learning.py",
+                "test_integration.py",
+                "test_random.py",
+                "test_recording_settings.py",
+                "test_speed.py",
+                "test_wrappers_settings.py",
+            ],
+        },
+        {
+            "id": "game_id",
+            "type": "pickString",
+            "description": "Game ID:",
+            "default": "doapp",
+            "options": [
+                "doapp",
+                "sfiii3n",
+                "tektagt",
+                "umk3",
+                "samsh5sp",
+                "kof98umh",
+            ],
+        },
+        {
+            "id": "n_players",
+            "type": "pickString",
+            "description": "Number of Players:",
+            "default": "1",
+            "options": [
+                "1",
+                "2",
+            ],
+        },
+        {
+            "id": "role0",
+            "type": "pickString",
+            "description": "Role 0:",
+            "default": "Random",
+            "options": [
+                "P1",
+                "P2",
+                "Random",
+            ],
+        },
+        {
+            "id": "role1",
+            "type": "pickString",
+            "description": "Role 1:",
+            "default": "Random",
+            "options": [
+                "P1",
+                "P2",
+                "Random",
+            ],
+        },
+        {
+            "id": "continue_game",
+            "type": "pickString",
+            "description": "Continue Game:",
+            "default": "0.0",
+            "options": [
+                "0.0",
+                "-1.0",
+                "0.5",
+            ],
+        },
+        {
+            "id": "no_action",
+            "type": "pickString",
+            "description": "No Action:",
+            "default": "0",
+            "options": [
+                { "label": "True", "value": "1" },
+                { "label": "False", "value": "0" },
+
+            ],
+        },
+        {
+            "id": "interactive",
+            "type": "pickString",
+            "description": "Interactive:",
+            "default": "0",
+            "options": [
+                { "label": "True", "value": "1" },
+                { "label": "False", "value": "0" },
+
+            ],
+        },
+        {
+            "id": "wrappers",
+            "type": "pickString",
+            "description": "Wrappers Active:",
+            "default": "0",
+            "options": [
+                { "label": "True", "value": "1" },
+                { "label": "False", "value": "0" },
+            ],
+        },
     ],
     "configurations": [
         {
-            "name": "[Arena] (conda diambra-arena) Script",
+            "name": "[Arena] (conda diambra-arena) Examples",
             "type": "python",
             "request": "launch",
             "program": "${workspaceFolder}/examples/${input:example_script}",
             "console": "integratedTerminal",
+        },
+        {
+            "name": "[Arena] (conda diambra-arena) Tests",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/tests/${input:pytest_script}",
+            "console": "integratedTerminal",
+        },
+        {
+            "name": "[Arena] (conda diambra-arena) Run Engine Mock",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/tests/run_engine_mock.py",
+            "args": [
+                "--gameId", "${input:game_id}",
+                "--nPlayers", "${input:n_players}",
+                "--role0", "${input:role0}",
+                "--role1", "${input:role1}",
+                "--character0", "Random",
+                "--character1", "Random",
+                "--character0_2", "Random",
+                "--character1_2", "Random",
+                "--character0_3", "Random",
+                "--character1_3", "Random",
+                "--difficulty", "0",
+                "--stepRatio", "3",
+                //"--continueGame", "${input:continue_game}",
+                "--noAction", "${input:no_action}",
+                //"--interactive", "${input:interactive}",
+                "--render", "1",
+            ],
+            "console": "integratedTerminal",
+        },
+        {
+            "name": "[Arena] (conda diambra-arena) Manual Random Test",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/tests/man_test_random.py",
+            "args": [
+                "--gameId", "${input:game_id}",
+                "--nPlayers", "${input:n_players}",
+                "--role0", "${input:role0}",
+                "--role1", "${input:role1}",
+                "--character0", "Random",
+                "--character1", "Random",
+                "--character0_2", "Random",
+                "--character1_2", "Random",
+                "--character0_3", "Random",
+                "--character1_3", "Random",
+                "--difficulty", "0",
+                "--stepRatio", "3",
+                "--nEpisodes", "1",
+                "--continueGame", "${input:continue_game}",
+                "--actionSpace", "discrete",
+                "--noAction", "${input:no_action}",
+                "--interactive", "${input:interactive}",
+                "--render", "1",
+                "--wrappers", "${input:wrappers}",
+            ],
+            "console": "integratedTerminal",
         }
     ]
 }
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..6b8dbff2
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,32 @@
+{
+    "python.testing.cwd": "${workspaceFolder}/tests/",
+    "python.testing.pytestArgs": [
+        "--rootdir",
+        "${workspaceFolder}/tests",
+        "-s"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true,
+    "editor.tokenColorCustomizations": {
+        "textMateRules": [
+            {
+                "scope": "googletest.failed",
+                "settings": {
+                    "foreground": "#f00"
+                }
+            },
+            {
+                "scope": "googletest.passed",
+                "settings": {
+                    "foreground": "#0f0"
+                }
+            },
+            {
+                "scope": "googletest.run",
+                "settings": {
+                    "foreground": "#0f0"
+                }
+            }
+        ]
+    }
+}
diff --git a/diambra/arena/__init__.py b/diambra/arena/__init__.py
index 637c7e6d..f05349d4 100644
--- a/diambra/arena/__init__.py
+++ b/diambra/arena/__init__.py
@@ -1,3 +1,5 @@
+from diambra.engine import SpaceTypes, Roles
+from diambra.engine import model
+from .env_settings import EnvironmentSettings, EnvironmentSettingsMultiAgent, WrappersSettings, RecordingSettings, load_settings_flat_dict
 from .make_env import make
-from .arena_imitation_learning_gym import ImitationLearning, ImitationLearningHardcore
-from .utils.gym_utils import available_games, game_sha_256, check_game_sha_256, get_num_envs
+from .utils.gym_utils import available_games, game_sha_256, check_game_sha_256, get_num_envs
\ No newline at end of file
diff --git
a/diambra/arena/arena_gym.py b/diambra/arena/arena_gym.py index f029080d..0710b6e4 100644 --- a/diambra/arena/arena_gym.py +++ b/diambra/arena/arena_gym.py @@ -2,95 +2,97 @@ import os import sys import cv2 -import gym +import gymnasium as gym import logging -from gym import spaces from diambra.arena.utils.gym_utils import discrete_to_multi_discrete_action from diambra.arena.engine.interface import DiambraEngine -from diambra.arena.env_settings import EnvironmentSettings1P, EnvironmentSettings2P -from typing import Union - -# DIAMBRA Env Gym -class DiambraGymHardcoreBase(gym.Env): - """Diambra Environment gym interface""" - metadata = {'render.modes': ['human']} - - def __init__(self, env_settings: Union[EnvironmentSettings1P, EnvironmentSettings2P]): +from diambra.arena.env_settings import EnvironmentSettings, EnvironmentSettingsMultiAgent +from typing import Union, Any, Dict, List +from diambra.engine import model, SpaceTypes + +class DiambraGymBase(gym.Env): + """Diambra Environment gymnasium base interface""" + metadata = {"render_modes": ["human", "rgb_array"]} + _frame = None + reward_normalization_value = 1.0 + render_gui_started = False + + def __init__(self, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]): self.logger = logging.getLogger(__name__) - super(DiambraGymHardcoreBase, self).__init__() - - self.reward_normalization_value = 1.0 - self.attack_but_combination = env_settings.attack_but_combination + super(DiambraGymBase, self).__init__() self.env_settings = env_settings - self.render_gui_started = False - # Launch DIAMBRA Arena + # Launch DIAMBRA Engine self.arena_engine = DiambraEngine(env_settings.env_address, env_settings.grpc_timeout) # Send environment settings, retrieve environment info - env_info_dict = self.arena_engine.env_init(self.env_settings) - self.env_info_process(env_info_dict) - self.player_side = self.env_settings.player - self.difficulty = self.env_settings.difficulty + self.env_info = self.arena_engine.env_init(self.env_settings.get_pb_request(init=True)) + self.env_settings.finalize_init(self.env_info) # Settings log self.logger.info(self.env_settings) - # Image as input: - self.observation_space = spaces.Box(low=0, high=255, - shape=(self.hwc_dim[0], - self.hwc_dim[1], - self.hwc_dim[2]), - dtype=np.uint8) - - # Processing Environment info - def env_info_process(self, env_info_dict): # N actions - self.n_actions_but_comb = env_info_dict["n_actions"][0] - self.n_actions_no_but_comb = env_info_dict["n_actions"][1] - # N actions - if self.env_settings.player == "P1P2": - self.n_actions = [self.n_actions_but_comb, self.n_actions_but_comb] - for idx in range(2): - if self.attack_but_combination[idx] is False: - self.n_actions[idx] = self.n_actions_no_but_comb - else: - self.n_actions = self.n_actions_but_comb - if self.attack_but_combination is False: - self.n_actions = self.n_actions_no_but_comb - - # Frame height, width and channel dimensions - self.hwc_dim = env_info_dict["frame_shape"] + self.n_actions = [self.env_info.available_actions.n_moves, self.env_info.available_actions.n_attacks] - # Maximum difference in players health - self.max_delta_health = env_info_dict["delta_health"] + # Actions tuples and dict + move_tuple = () + move_dict = {} + attack_tuple= () + attack_dict = {} - # Maximum number of stages (1P game vs COM) - self.max_stage = env_info_dict["max_stage"] + for idx in range(len(self.env_info.available_actions.moves)): + move_tuple += (self.env_info.available_actions.moves[idx].key,) + move_dict[idx] = 
self.env_info.available_actions.moves[idx].label - # Min-Max reward - self.cumulative_reward_bounds = env_info_dict["cumulative_reward_bounds"] + for idx in range(len(self.env_info.available_actions.attacks)): + attack_tuple += (self.env_info.available_actions.attacks[idx].key,) + attack_dict[idx] = self.env_info.available_actions.attacks[idx].label - # Characters names list - self.char_names = env_info_dict["char_list"] + self.actions_tuples = (move_tuple, attack_tuple) + self.print_actions_dict = [move_dict, attack_dict] - # Action list - self.action_list = (tuple(env_info_dict["actions_list"][0]), tuple(env_info_dict["actions_list"][1])) + # Maximum difference in players health + for k in sorted(self.env_info.ram_states_categories[model.RamStatesCategories.P1].ram_states.keys()): + key_enum_name = model.RamStates.Name(k) + if "health" in key_enum_name: + self.max_delta_health = self.env_info.ram_states_categories[model.RamStatesCategories.P1].ram_states[k].max - \ + self.env_info.ram_states_categories[model.RamStatesCategories.P1].ram_states[k].min + break + + # Observation space + # Dictionary + observation_space_dict = {} + observation_space_dict['frame'] = gym.spaces.Box(low=0, high=255, shape=(self.env_info.frame_shape.h, + self.env_info.frame_shape.w, + self.env_info.frame_shape.c), + dtype=np.uint8) + + # Adding RAM states observations + for k, v in self.env_info.ram_states_categories.items(): + if k == model.RamStatesCategories.common: + target_dict = observation_space_dict + else: + observation_space_dict[model.RamStatesCategories.Name(k)] = {} + target_dict = observation_space_dict[model.RamStatesCategories.Name(k)] + + for k2, v2 in v.ram_states.items(): + if v2.type == SpaceTypes.BINARY or v2.type == SpaceTypes.DISCRETE: + target_dict[model.RamStates.Name(k2)] = gym.spaces.Discrete(v2.max + 1) + elif v2.type == SpaceTypes.BOX: + target_dict[model.RamStates.Name(k2)] = gym.spaces.Box(low=v2.min, high=v2.max, shape=(1,), dtype=np.int16) + else: + raise RuntimeError("Only Discrete (Binary/Categorical) | Box Spaces allowed") - # Action dict - self.print_actions_dict = env_info_dict["actions_dict"] + for space_key in [model.RamStatesCategories.P1, model.RamStatesCategories.P2]: + observation_space_dict[model.RamStatesCategories.Name(space_key)] = gym.spaces.Dict(observation_space_dict[model.RamStatesCategories.Name(space_key)]) - # Ram states map - self.ram_states = {} - for k in sorted(env_info_dict["ram_states"].keys()): - self.ram_states[k] = [env_info_dict["ram_states"][k].type, - env_info_dict["ram_states"][k].min, - env_info_dict["ram_states"][k].max] + self.observation_space = gym.spaces.Dict(observation_space_dict) # Return env action list - def action_list(self): - return self.action_list + def get_actions_tuples(self): + return self.actions_tuples # Print Actions def print_actions(self): @@ -104,24 +106,21 @@ def print_actions(self): # Return cumulative reward bounds for the environment def get_cumulative_reward_bounds(self): - return [self.cumulative_reward_bounds[0] / (self.reward_normalization_value), - self.cumulative_reward_bounds[1] / (self.reward_normalization_value)] - - # Step method to be implemented in derived classes - def step(self, action): - raise NotImplementedError() + return [self.env_info.cumulative_reward_bounds.min / (self.reward_normalization_value), + self.env_info.cumulative_reward_bounds.max / (self.reward_normalization_value)] - # Resetting the environment - def reset(self): - cv2.destroyAllWindows() - self.render_gui_started = False - 
self.frame, data, self.player_side = self.arena_engine.reset() - return self.frame + # Reset the environment + def reset(self, seed: int = None, options: Dict[str, Any] = None): + if options is None: + options = {} + options["seed"] = seed + request = self.env_settings.update_episode_settings(options) + response = self.arena_engine.reset(request.episode_settings) + return self._get_obs(response), self._get_info(response) # Rendering the environment - def render(self, mode='human', wait_key=1): - - if mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): + def render(self, wait_key=1): + if self.env_settings.render_mode == "human" and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): try: if (self.render_gui_started is False): self.window_name = "[{}] DIAMBRA Arena - {} - ({})".format( @@ -130,59 +129,39 @@ def render(self, mode='human', wait_key=1): self.render_gui_started = True wait_key = 100 - cv2.imshow(self.window_name, self.frame[:, :, ::-1]) + cv2.imshow(self.window_name, self._frame[:, :, ::-1]) cv2.waitKey(wait_key) return True except: return False - elif mode == "rgb_array": - return self.frame + elif self.env_settings.render_mode == "rgb_array": + return self._frame # Print observation details to the console - def show_obs(self, observation, wait_key=1, viz=True): - + def show_obs(self, observation, wait_key=1, viz=True, string="observation", key=None, outermost=True): if type(observation) == dict: - for k, v in observation.items(): - if k != "frame": - if type(v) == dict: - for k2, v2 in v.items(): - if k2 == "actions": - - for k3, v3 in v2.items(): - out_value = v3 - additional_string = ": " - if type(v3) != int: - if self.env_settings.player == "P1P2": - n_actions = self.n_actions[0] if k == "P1" else self.n_actions[1] - else: - n_actions = self.n_actions - n_actions_stack = int(self.observation_space[k][k2][k3].n / (n_actions[0] if k3 == "move" else n_actions[1])) - out_value = np.reshape(v3, [n_actions_stack, -1]) - additional_string = " (reshaped for visualization):\n" - print("observation[\"{}\"][\"{}\"][\"{}\"]{}{}".format(k, k2, k3, additional_string, out_value)) - elif "ownChar" in k2 or "oppChar" in k2: - char_idx = v2 if type(v2) == int else np.where(v2 == 1)[0][0] - print("observation[\"{}\"][\"{}\"]: {} / {}".format(k, k2, v2, self.char_names[char_idx])) - else: - print("observation[\"{}\"][\"{}\"]: {}".format(k, k2, v2)) - else: - print("observation[\"{}\"]: {}".format(k, v)) + for k, v in sorted(observation.items()): + self.show_obs(v, wait_key=wait_key, viz=viz, string=string + "[\"{}\"]".format(k), key=k, outermost=False) + else: + if key != "frame": + if key.startswith("character"): + char_idx = observation if type(observation) == int else np.where(observation == 1)[0][0] + print(string + ": {} / {}".format(observation, self.env_info.characters_info.char_list[char_idx])) else: - frame = observation["frame"] - print("observation[\"frame\"]: shape {} - min {} - max {}".format(frame.shape, np.amin(frame), np.amax(frame))) + print(string + ": {}".format(observation)) + else: + print(string + ": shape {} - min {} - max {}".format(observation.shape, np.amin(observation), np.amax(observation))) - if viz: - frame = observation["frame"] - else: - if viz: - frame = observation + if viz is True and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): + try: + norm_factor = 255 if np.amax(observation) > 1.0 else 1.0 + for idx in range(observation.shape[2]): + cv2.imshow("[{}] Frame channel 
{}".format(os.getpid(), idx), observation[:, :, idx] / norm_factor) + except: + pass - if viz is True and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): + if outermost is True and viz is True and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): try: - norm_factor = 255 if np.amax(frame) > 1.0 else 1.0 - for idx in range(frame.shape[2]): - cv2.imshow("[{}] Frame channel {}".format(os.getpid(), idx), frame[:, :, idx] / norm_factor) - cv2.waitKey(wait_key) except: pass @@ -193,348 +172,113 @@ def close(self): cv2.destroyAllWindows() self.arena_engine.close() -# DIAMBRA Gym base class for single player mode -class DiambraGymHardcore1P(DiambraGymHardcoreBase): - def __init__(self, env_settings): - super().__init__(env_settings) - - # Define action and observation space - # They must be gym.spaces objects - - if env_settings.action_space == "multi_discrete": - # MultiDiscrete actions: - # - Arrows -> One discrete set - # - Buttons -> One discrete set - # NB: use the convention NOOP = 0, and buttons combinations - # can be prescripted: - # e.g. NOOP = [0], ButA = [1], ButB = [2], ButA+ButB = [3] - # or ignored: - # e.g. NOOP = [0], ButA = [1], ButB = [2] - self.action_space = spaces.MultiDiscrete(self.n_actions) - self.logger.debug("Using MultiDiscrete action space") - elif env_settings.action_space == "discrete": - # Discrete actions: - # - Arrows U Buttons -> One discrete set - # NB: use the convention NOOP = 0, and buttons combinations - # can be prescripted: - # e.g. NOOP = [0], ButA = [1], ButB = [2], ButA+ButB = [3] - # or ignored: - # e.g. NOOP = [0], ButA = [1], ButB = [2] - self.action_space = spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1) - self.logger.debug("Using Discrete action space") - else: - raise Exception("Not recognized action space: {}".format(env_settings.action_space)) - - # Step the environment - def step_complete(self, action): - # Actions initialization - mov_act = 0 - att_act = 0 - - # Defining move and attack actions P1/P2 as a function of action_space - - if isinstance(self.action_space, gym.spaces.MultiDiscrete): - mov_act = action[0] - att_act = action[1] - else: - # Discrete to multidiscrete conversion - mov_act, att_act = discrete_to_multi_discrete_action( - action, self.n_actions[0]) - - self.frame, reward, data = self.arena_engine.step_1p(mov_act, att_act) - done = data["ep_done"] - - return self.frame, reward, done, data - - # Step the environment - def step(self, action): - - self.frame, reward, done, data = self.step_complete(action) - - return self.frame, reward, done,\ - {"round_done": data["round_done"], "stage_done": data["stage_done"], - "game_done": data["game_done"], "ep_done": data["ep_done"], "env_done": data["env_done"]} - -# DIAMBRA Gym base class for two players mode -class DiambraGymHardcore2P(DiambraGymHardcoreBase): - def __init__(self, env_settings): - super().__init__(env_settings) - - # Define action spaces, they must be gym.spaces objects - action_space_dict = {} - for idx in range(2): - if env_settings.action_space[idx] == "multi_discrete": - action_space_dict["P{}".format(idx + 1)] =\ - spaces.MultiDiscrete(self.n_actions[idx]) - self.logger.debug("Using MultiDiscrete action space for P{}".format(idx + 1)) - elif env_settings.action_space[idx] == "discrete": - action_space_dict["P{}".format(idx + 1)] =\ - spaces.Discrete( - self.n_actions[idx][0] + self.n_actions[idx][1] - 1) - self.logger.debug("Using Discrete action space for P{}".format(idx + 1)) - else: - raise Exception("Not 
recognized action space: {}".format(env_settings.action_space[idx])) - - self.action_space = spaces.Dict(action_space_dict) - - # Step the environment - def step_complete(self, action): - # Actions initialization - mov_act_p1 = 0 - att_act_p1 = 0 - mov_act_p2 = 0 - att_act_p2 = 0 - - # Defining move and attack actions P1/P2 as a function of action_space - if isinstance(self.action_space["P1"], gym.spaces.MultiDiscrete): - # P1 - mov_act_p1 = action[0] - att_act_p1 = action[1] - # P2 - # P2 MultiDiscrete Action Space - if isinstance(self.action_space["P2"], gym.spaces.MultiDiscrete): - mov_act_p2 = action[2] - att_act_p2 = action[3] - else: # P2 Discrete Action Space - mov_act_p2, att_act_p2 = discrete_to_multi_discrete_action(action[2], self.n_actions[1][0]) - - else: # P1 Discrete Action Space - # P2 - # P2 MultiDiscrete Action Space - if isinstance(self.action_space["P2"], gym.spaces.MultiDiscrete): - # P1 - # Discrete to multidiscrete conversion - mov_act_p1, att_act_p1 = discrete_to_multi_discrete_action(action[0], self.n_actions[0][0]) - mov_act_p2 = action[1] - att_act_p2 = action[2] - else: # P2 Discrete Action Space - # P1 - # Discrete to multidiscrete conversion - mov_act_p1, att_act_p1 = discrete_to_multi_discrete_action(action[0], self.n_actions[0][0]) - mov_act_p2, att_act_p2 = discrete_to_multi_discrete_action(action[1], self.n_actions[1][0]) - - self.frame, reward, data = self.arena_engine.step_2p(mov_act_p1, att_act_p1, mov_act_p2, att_act_p2) - done = data["game_done"] - # data["ep_done"] = done - - return self.frame, reward, done, data - - # Step the environment - def step(self, action): - - self.frame, reward, done, data = self.step_complete(action) - - return self.frame, reward, done,\ - {"round_done": data["round_done"], "stage_done": data["stage_done"], - "game_done": data["game_done"], "ep_done": data["ep_done"], "env_done": data["env_done"]} - -# DIAMBRA Gym base class providing frame and additional info as observations -class DiambraGym1P(DiambraGymHardcore1P): - def __init__(self, env_settings): - super().__init__(env_settings) - - # Dictionary observation space - observation_space_dict = {} - observation_space_dict['frame'] = spaces.Box(low=0, high=255, - shape=(self.hwc_dim[0], - self.hwc_dim[1], - self.hwc_dim[2]), - dtype=np.uint8) - player_spec_dict = {} - - # Adding env additional observations (side-specific) - for k, v in self.ram_states.items(): + # Get frame + def _get_frame(self, response): + self._frame = np.frombuffer(response.observation.frame, dtype='uint8').reshape(self.env_info.frame_shape.h, \ + self.env_info.frame_shape.w, \ + self.env_info.frame_shape.c) + return self._frame - if k == "stage": - continue - - if k[-2:] == "P1": - knew = "own" + k[:-2] - else: - knew = "opp" + k[:-2] - - # Discrete spaces (binary / categorical) - if v[0] == 0 or v[0] == 2: - player_spec_dict[knew] = spaces.Discrete(v[2] + 1) - elif v[0] == 1: # Box spaces - player_spec_dict[knew] = spaces.Box(low=v[1], high=v[2], - shape=(1,), dtype=np.int32) - else: - raise RuntimeError( - "Only Discrete (Binary/Categorical) | Box Spaces allowed") - - actions_dict = { - "move": spaces.Discrete(self.n_actions[0]), - "attack": spaces.Discrete(self.n_actions[1]) - } - - player_spec_dict["actions"] = spaces.Dict(actions_dict) - observation_space_dict["P1"] = spaces.Dict(player_spec_dict) - observation_space_dict["stage"] = spaces.Box(low=self.ram_states["stage"][1], - high=self.ram_states["stage"][2], - shape=(1,), dtype=np.int8) - - self.observation_space = 
spaces.Dict(observation_space_dict) - - def ram_states_integration(self, frame, data): + # Get info + def _get_info(self, response): + info = {model.GameStates.Name(k): v for k, v in response.info.game_states.items()} + info["settings"] = self.env_settings.pb_model + return info + def _get_obs(self, response): observation = {} - observation["frame"] = frame - observation["stage"] = np.array([data["stage"]], dtype=np.int8) + observation["frame"] = self._get_frame(response) - player_spec_dict = {} - - # Adding env additional observations (side-specific) - for k, v in self.ram_states.items(): - - if k == "stage": - continue - - if k[-2:] == self.player_side: - knew = "own" + k[:-2] + # Adding RAM states observations + for k, v in self.env_info.ram_states_categories.items(): + if k == model.RamStatesCategories.common: + target_dict = observation else: - knew = "opp" + k[:-2] - - # Box spaces - if v[0] == 1: - player_spec_dict[knew] = np.array([data[k]], dtype=np.int32) - else: # Discrete spaces (binary / categorical) - player_spec_dict[knew] = data[k] + observation[model.RamStatesCategories.Name(k)] = {} + target_dict = observation[model.RamStatesCategories.Name(k)] - actions_dict = { - "move": data["moveAction{}".format(self.player_side)], - "attack": data["attackAction{}".format(self.player_side)], - } - - player_spec_dict["actions"] = actions_dict - observation["P1"] = player_spec_dict - - return observation + category_ram_states = response.observation.ram_states_categories[k] - def step(self, action): - - self.frame, reward, done, data = self.step_complete(action) - - observation = self.ram_states_integration(self.frame, data) - - return observation, reward, done,\ - {"round_done": data["round_done"], "stage_done": data["stage_done"], - "game_done": data["game_done"], "ep_done": data["ep_done"], "env_done": data["env_done"]} + for k2, v2 in v.ram_states.items(): + # Box spaces + if v2.type == SpaceTypes.BOX: + target_dict[model.RamStates.Name(k2)] = np.array([category_ram_states.ram_states[k2]]) + else: # Discrete spaces (binary / categorical) + target_dict[model.RamStates.Name(k2)] = category_ram_states.ram_states[k2] - # Reset the environment - def reset(self): - self.frame, data, self.player_side = self.arena_engine.reset() - observation = self.ram_states_integration(self.frame, data) return observation -# DIAMBRA Gym base class providing frame and additional info as observations -class DiambraGym2P(DiambraGymHardcore2P): - def __init__(self, env_settings): +class DiambraGym1P(DiambraGymBase): + """Diambra Environment gymnasium single agent interface""" + def __init__(self, env_settings: EnvironmentSettings): super().__init__(env_settings) - # Dictionary observation space - observation_space_dict = {} - observation_space_dict['frame'] = spaces.Box(low=0, high=255, - shape=(self.hwc_dim[0], - self.hwc_dim[1], - self.hwc_dim[2]), - dtype=np.uint8) - player_spec_dict = {} - - # Adding env additional observations (side-specific) - for k, v in self.ram_states.items(): - - if k == "stage": - continue - - if k[-2:] == "P1": - knew = "own" + k[:-2] - else: - knew = "opp" + k[:-2] - - if v[0] == 0 or v[0] == 2: # Discrete spaces - player_spec_dict[knew] = spaces.Discrete(v[2] + 1) - elif v[0] == 1: # Box spaces - player_spec_dict[knew] = spaces.Box(low=v[1], high=v[2], - shape=(1,), dtype=np.int32) - - else: - raise RuntimeError("Only Discrete and Box Spaces allowed") - - actions_dict = { - "move": spaces.Discrete(self.n_actions[0][0]), - "attack": spaces.Discrete(self.n_actions[0][1]) - } - - 
player_spec_dict["actions"] = spaces.Dict(actions_dict) - player_dict_p1 = player_spec_dict.copy() - observation_space_dict["P1"] = spaces.Dict(player_dict_p1) - - actions_dict = { - "move": spaces.Discrete(self.n_actions[1][0]), - "attack": spaces.Discrete(self.n_actions[1][1]) - } - - player_spec_dict["actions"] = spaces.Dict(actions_dict) - player_dict_p2 = player_spec_dict.copy() - observation_space_dict["P2"] = spaces.Dict(player_dict_p2) - - observation_space_dict["stage"] = spaces.Box(low=self.ram_states["stage"][1], - high=self.ram_states["stage"][2], - shape=(1,), dtype=np.int8) - - self.observation_space = spaces.Dict(observation_space_dict) - - def ram_states_integration(self, frame, data): - - observation = {} - observation["frame"] = frame - observation["stage"] = np.array([data["stage"]], dtype=np.int8) - - for ielem, elem in enumerate(["P1", "P2"]): - - player_spec_dict = {} - - # Adding env additional observations (side-specific) - for k, v in self.ram_states.items(): - - if k == "stage": - continue + # Action space + # MultiDiscrete actions: + # - Arrows -> One discrete set + # - Buttons -> One discrete set + # Discrete actions: + # - Arrows U Buttons -> One discrete set + # NB: use the convention NOOP = 0 + if env_settings.action_space == SpaceTypes.MULTI_DISCRETE: + self.action_space = gym.spaces.MultiDiscrete(self.n_actions) + elif env_settings.action_space == SpaceTypes.DISCRETE: + self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1) + self.logger.debug("Using {} action space".format(SpaceTypes.Name(env_settings.action_space))) + + # Return the no-op action + def get_no_op_action(self): + if isinstance(self.action_space, gym.spaces.MultiDiscrete): + return [0, 0] + else: + return 0 - if k[-2:] == elem: - knew = "own" + k[:-2] - else: - knew = "opp" + k[:-2] + # Step the environment + def step(self, action: Union[int, List[int]]): + # Defining move and attack actions P1/P2 as a function of action_space + if isinstance(self.action_space, gym.spaces.Discrete): + action = list(discrete_to_multi_discrete_action(action, self.n_actions[0])) + response = self.arena_engine.step([action]) - # Box spaces - if v[0] == 1: - player_spec_dict[knew] = np.array([data[k]], dtype=np.int32) - else: # Discrete spaces (binary / categorical) - player_spec_dict[knew] = data[k] + return self._get_obs(response), response.reward, response.info.game_states[model.GameStates.episode_done], False, self._get_info(response) - actions_dict = { - "move": data["moveAction{}".format(elem)], - "attack": data["attackAction{}".format(elem)], - } +class DiambraGym2P(DiambraGymBase): + """Diambra Environment gymnasium multi-agent interface""" + def __init__(self, env_settings: EnvironmentSettingsMultiAgent): + super().__init__(env_settings) - player_spec_dict["actions"] = actions_dict - observation[elem] = player_spec_dict + # Action space + # Dictionary + action_spaces_values = {SpaceTypes.MULTI_DISCRETE: gym.spaces.MultiDiscrete(self.n_actions), + SpaceTypes.DISCRETE: gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)} + action_space_dict = self._map_action_spaces_to_agents(action_spaces_values) + self.logger.debug("Using the following action spaces: {}".format(action_space_dict)) + self.action_space = gym.spaces.Dict(action_space_dict) - return observation + # Return the no-op action + def get_no_op_action(self): + no_op_values = {SpaceTypes.MULTI_DISCRETE: [0, 0], + SpaceTypes.DISCRETE: 0} + return self._map_action_spaces_to_agents(no_op_values) # Step the 
environment - def step(self, action): - - self.frame, reward, done, data = self.step_complete(action) + def step(self, actions: Dict[str, Union[int, List[int]]]): + # NOTE: the assumption in current interface is that we have actions sorted as agent's order + actions = sorted(actions.items()) + action_list = [[],[]] + for idx, action in enumerate(actions): + # Defining move and attack actions P1/P2 as a function of action_space + if isinstance(self.action_space[action[0]], gym.spaces.MultiDiscrete): + action_list[idx] = action[1] + else: + action_list[idx] = list(discrete_to_multi_discrete_action(action[1], self.n_actions[0])) + response = self.arena_engine.step(action_list) - observation = self.ram_states_integration(self.frame, data) + return self._get_obs(response), response.reward, response.info.game_states[model.GameStates.game_done], False, self._get_info(response) - return observation, reward, done,\ - {"round_done": data["round_done"], "stage_done": data["stage_done"], - "game_done": data["game_done"], "ep_done": data["ep_done"], "env_done": data["env_done"]} + def _map_action_spaces_to_agents(self, values_dict): + out_dict = {} + for idx, action_space in enumerate(self.env_settings.action_space): + out_dict["agent_{}".format(idx)] = values_dict[action_space] - # Reset the environment - def reset(self): - self.frame, data, self.player_side = self.arena_engine.reset() - observation = self.ram_states_integration(self.frame, data) - return observation + return out_dict \ No newline at end of file diff --git a/diambra/arena/arena_imitation_learning_gym.py b/diambra/arena/arena_imitation_learning_gym.py deleted file mode 100644 index 8c389d3a..00000000 --- a/diambra/arena/arena_imitation_learning_gym.py +++ /dev/null @@ -1,422 +0,0 @@ -import numpy as np -import gym -from gym import spaces -import pickle -import bz2 -import copy -import cv2 -import sys -import os -import logging -from .utils.splash_screen import SplashScreen -from .utils.gym_utils import standard_dict_to_gym_obs_dict,\ - discrete_to_multi_discrete_action -from typing import List - -# Diambra imitation learning environment - - -class ImitationLearningBase(gym.Env): - """Diambra Environment that follows gym interface""" - metadata = {'render.modes': ['human']} - - def __init__(self, traj_files_list: List[str], rank: int=0, total_cpus: int=1): - self.logger = logging.getLogger(__name__) - super(ImitationLearningBase, self).__init__() - - # Check for number of files - if total_cpus > len(traj_files_list): - raise Exception( - "Number of requested CPUs > number of " - "recorded experience available files") - - # Splash Screen - SplashScreen() - - # List of RL trajectories files - self.traj_files_list = traj_files_list - - # CPU rank for this env instance - self.rank = rank - self.total_cpus = total_cpus - - # Idx of trajectory file to read - self.traj_idx = self.rank - self.rl_traj_dict = None - - # Open the first file to retrieve env info: --- - tmp_rl_traj_file = self.traj_files_list[self.traj_idx] - - # Read compressed RL Traj file - infile = bz2.BZ2File(tmp_rl_traj_file, 'r') - self.tmp_rl_traj_dict = pickle.load(infile) - infile.close() - - # Observation and action space - self.frame_h = self.tmp_rl_traj_dict["frame_shp"][0] - self.frame_w = self.tmp_rl_traj_dict["frame_shp"][1] - self.frame_n_channels = self.tmp_rl_traj_dict["frame_shp"][2] - self.n_actions = self.tmp_rl_traj_dict["n_actions"] - # --- - - # Define action and observation space - # They must be gym.spaces objects - if self.tmp_rl_traj_dict["action_space"] 
== "multi_discrete": - # MultiDiscrete actions: - # - Arrows -> One discrete set - # - Buttons -> One discrete set - # NB: use the convention NOOP = 0, and buttons combinations - # can be prescripted: - # e.g. NOOP = [0], ButA = [1], ButB = [2], ButA+ButB = [3] - # or ignored: - # e.g. NOOP = [0], ButA = [1], ButB = [2] - self.action_space = spaces.MultiDiscrete(self.n_actions) - self.logger.debug("Using MultiDiscrete action space") - elif self.tmp_rl_traj_dict["action_space"] == "discrete": - # Discrete actions: - # - Arrows U Buttons -> One discrete set - # NB: use the convention NOOP = 0, and buttons combinations - # can be prescripted: - # e.g. NOOP = [0], ButA = [1], ButB = [2], ButA+ButB = [3] - # or ignored: - # e.g. NOOP = [0], ButA = [1], ButB = [2] - self.action_space = spaces.Discrete( - self.n_actions[0] + self.n_actions[1] - 1) - self.logger.debug("Using Discrete action space") - else: - raise Exception( - "Not recognized action space: {}".format(self.tmp_rl_traj_dict["action_space"])) - - # If run out of examples - self.exhausted = False - - # Reset flag - self.n_reset = 0 - - # Observations shift counter (for new round/stage/game) - self.shift_counter = 1 - - # Print Episode summary - def traj_summary(self): - - self.logger.info(self.rl_traj_dict.keys()) - - self.logger.info("Ep. length = {}".format(self.rl_traj_dict["ep_len"])) - - for key, value in self.rl_traj_dict.items(): - if type(value) == list and len(value) > 2: - self.logger.info("len({}): {}".format(key, len(value))) - else: - self.logger.info("{} : {}".format(key, value)) - - # Step the environment - def step(self, dummy_action): - - # Done retrieval - done = False - if self.step_idx == self.rl_traj_dict["ep_len"] - 1: - done = True - - # Done flags retrieval - done_flags = self.rl_traj_dict["done_flags"][self.step_idx] - - if (done_flags[0] or done_flags[1] or done_flags[2]) and not done: - self.shift_counter += self.frame_n_channels - 1 - - # Observation retrieval - observation = self.obs_retrieval() - - # Reward retrieval - reward = self.rl_traj_dict["rewards"][self.step_idx] - - # Action retrieval - action = self.rl_traj_dict["actions"][self.step_idx] - if isinstance(self.action_space, gym.spaces.Discrete): - action_new = discrete_to_multi_discrete_action(action, self.n_actions[0]) - else: - action_new = action - - action = [action_new[0], action_new[1]] - info = {} - info["action"] = action - info["round_done"] = done_flags[0] - info["stage_done"] = done_flags[1] - info["game_done"] = done_flags[2] - info["episode_done"] = done_flags[3] - - if np.any(done): - self.logger.info("(Rank {}) Episode done".format(self.rank)) - - # Update step idx - self.step_idx += 1 - - return observation, reward, done, info - - # Resetting the environment - def reset(self): - - # Reset run step - self.step_idx = 0 - - # Observations shift counter (for new round/stage/game) - self.shift_counter = 1 - - # Manage ignoreP2 flag for recorded P1P2 trajectory (e.g. when HUMvsAI) - if self.n_reset != 0 and self.rl_traj_dict["ignore_p2"] == 1: - - self.logger.debug("Skipping P2 trajectory for 2P games (e.g. 
HUMvsAI)") - # Resetting n_reset - self.n_reset = 0 - # Move traj idx to the next to be read - self.traj_idx += self.total_cpus - - # Check if run out of traj files - if self.traj_idx >= len(self.traj_files_list): - self.logger.info("(Rank {}) Resetting env".format(self.rank)) - self.exhausted = True - observation = {} - observation = self.black_screen(observation) - return observation - - if self.n_reset == 0: - rl_traj_file = self.traj_files_list[self.traj_idx] - - # Read compressed RL Traj file - infile = bz2.BZ2File(rl_traj_file, 'r') - self.rl_traj_dict = pickle.load(infile) - infile.close() - - # Storing env info - self.n_chars = len(self.rl_traj_dict["char_names"]) - self.char_names = self.rl_traj_dict["char_names"] - self.n_actions_stack = self.rl_traj_dict["n_actions_stack"] - self.player_side = self.rl_traj_dict["player_side"] - assert self.n_actions == self.rl_traj_dict["n_actions"],\ - "Recorded episode has {} actions".format( - self.rl_traj_dict["n_actions"]) - if isinstance(self.action_space, gym.spaces.Discrete): - assert self.rl_traj_dict["action_space"] == "discrete",\ - "Recorded episode has {} action space".format( - self.rl_traj_dict["action_space"]) - else: - assert self.rl_traj_dict["action_space"] == "multi_discrete",\ - "Recorded episode has {} action space".format( - self.rl_traj_dict["action_space"]) - - if self.player_side == "P1P2": - - self.logger.debug("Two players RL trajectory") - - if self.n_reset == 0: - # First reset for this trajectory - - self.logger.debug("Loading P1 data for 2P trajectory") - - # Generate P2 Experience from P1 one - self.generate_p2_experience_from_p1() - - # For each step, isolate P1 actions from P1P2 experience - for idx in range(self.rl_traj_dict["ep_len"]): - # Actions (inverting sides) - if self.rl_traj_dict["action_space"] == "discrete": - self.rl_traj_dict["actions"][idx] = self.rl_traj_dict["actions"][idx][0] - else: - self.rl_traj_dict["actions"][idx] = [self.rl_traj_dict["actions"][idx][0], - self.rl_traj_dict["actions"][idx][1]] - - # Update reset counter - self.n_reset += 1 - - else: - # Second reset for this trajectory - - self.logger.debug("Loading P2 data for 2P trajectory") - - # OverWrite P1 RL trajectory with the one calculated for P2 - self.rl_traj_dict = self.rl_traj_dict_p2 - - # Reset reset counter - self.n_reset = 0 - - # Move traj idx to the next to be read - self.traj_idx += self.total_cpus - - else: - - self.logger.debug("One player RL trajectory") - - # Move traj idx to the next to be read - self.traj_idx += self.total_cpus - - # Observation retrieval - observation = self.obs_retrieval(reset_shift=1) - - return observation - - # Generate P2 Experience from P1 one - def generate_p2_experience_from_p1(self): - - # Copy P1 Trajectory - self.rl_traj_dict_p2 = copy.deepcopy(self.rl_traj_dict) - - # For each step, convert P1 into P2 experience - for idx in range(self.rl_traj_dict["ep_len"]): - - # Rewards (inverting sign) - self.rl_traj_dict_p2["rewards"][idx] = - \ - self.rl_traj_dict["rewards"][idx] - - # Actions (inverting sides) - if self.rl_traj_dict["action_space"] == "discrete": - self.rl_traj_dict_p2["actions"][idx] = self.rl_traj_dict["actions"][idx][1] - else: - self.rl_traj_dict_p2["actions"][idx] = [self.rl_traj_dict["actions"][idx][2], - self.rl_traj_dict["actions"][idx][3]] - - # Rendering the environment - def render(self, mode='human'): - - if mode == "human": - window_name = "Diambra Imitation Learning Environment - {}".format( - self.rank) - cv2.namedWindow(window_name, cv2.WINDOW_GUI_NORMAL) - 
cv2.imshow(window_name, self.last_obs) - cv2.waitKey(1) - elif mode == "rgb_array": - output = np.expand_dims(self.last_obs, axis=2) - return output - - # Print observation details to the console - def show_obs(self, observation, wait_key=1, viz=True): - - if type(observation) == dict: - for k, v in observation.items(): - if k != "frame": - if type(v) == dict: - for k2, v2 in v.items(): - if k2 == "actions": - - for k3, v3 in v2.items(): - out_value = v3 - additional_string = ": " - if type(v3) != int: - n_actions_stack = int(self.observation_space[k][k2][k3].n / (self.n_actions[0] if k3 == "move" else self.n_actions[1])) - out_value = np.reshape(v3, [n_actions_stack, -1]) - additional_string = " (reshaped for visualization):\n" - print("observation[\"{}\"][\"{}\"][\"{}\"]{}{}".format(k, k2, k3, additional_string, out_value)) - elif "ownChar" in k2 or "oppChar" in k2: - char_idx = v2 if type(v2) == int else np.where(v2 == 1)[0][0] - print("observation[\"{}\"][\"{}\"]: {} / {}".format(k, k2, v2, self.char_names[char_idx])) - else: - print("observation[\"{}\"][\"{}\"]: {}".format(k, k2, v2)) - else: - print("observation[\"{}\"]: {}".format(k, v)) - else: - frame = observation["frame"] - print("observation[\"frame\"]: shape {} - min {} - max {}".format(frame.shape, np.amin(frame), np.amax(frame))) - - if viz: - frame = observation["frame"] - else: - if viz: - frame = observation - - if viz is True and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): - try: - norm_factor = 255 if np.amax(frame) > 1.0 else 1.0 - for idx in range(frame.shape[2]): - cv2.imshow("[{}] Frame channel {}".format(os.getpid(), idx), frame[:, :, idx] / norm_factor) - - cv2.waitKey(wait_key) - except: - pass - -# Diambra imitation learning environment -class ImitationLearningHardcore(ImitationLearningBase): - def __init__(self, traj_files_list: List[str], rank: int=0, total_cpus: int=1): - super().__init__(traj_files_list, rank, total_cpus) - - # Observation space - obs_space_bounds = self.tmp_rl_traj_dict["obs_space_bounds"] - - # Create the observation space - self.observation_space = spaces.Box(low=obs_space_bounds[0], - high=obs_space_bounds[1], - shape=(self.frame_h, self.frame_w, - self.frame_n_channels), - dtype=np.float32) - - # Specific observation retrieval - def obs_retrieval(self, reset_shift=0): - # Observation retrieval - observation = np.zeros((self.frame_h, self.frame_w, self.frame_n_channels)) - for iframe in range(self.frame_n_channels): - observation[:, :, iframe] = self.rl_traj_dict["frames"][self.step_idx + - self.shift_counter + iframe - reset_shift] - # Storing last observation for rendering - self.last_obs = observation[:, :, self.frame_n_channels - 1] - - return observation - - # Black screen - def black_screen(self, observation): - - observation = np.zeros((self.frame_h, self.frame_w, self.frame_n_channels)) - - return observation - -# Diambra imitation learning environment -class ImitationLearning(ImitationLearningBase): - def __init__(self, traj_files_list: List[str], rank: int=0, total_cpus: int=1): - super().__init__(traj_files_list, rank, total_cpus) - - # Observation space - player_side = self.tmp_rl_traj_dict["player_side"] - self.observation_space_dict = self.tmp_rl_traj_dict["observation_space_dict"] - # Remove P2 sub space from Obs Space - if player_side == "P1P2": - self.observation_space_dict.pop("P2") - - # Create the observation space - self.observation_space = standard_dict_to_gym_obs_dict( - self.observation_space_dict) - - # Specific observation retrieval - 
def obs_retrieval(self, reset_shift=0): - # Observation retrieval - observation = self.rl_traj_dict["ram_states"][self.step_idx + 1 - reset_shift].copy() - - # Frame - observation["frame"] = np.zeros( - (self.frame_h, self.frame_w, self.frame_n_channels)) - for iframe in range(self.frame_n_channels): - observation["frame"][:, :, iframe] = self.rl_traj_dict["frames"][self.step_idx + - self.shift_counter + iframe - reset_shift] - # Storing last observation for rendering - self.last_obs = observation["frame"][:, :, self.frame_n_channels - 1] - - return observation - - # Black screen - def black_screen(self, observation): - - observation["frame"] = np.zeros( - (self.frame_h, self.frame_w, self.frame_n_channels)) - - return observation - - # Generate P2 Experience from P1 one - def generate_p2_experience_from_p1(self): - - super().generate_p2_experience_from_p1() - - # Process Additiona Obs for P2 (copy them in P1 position) - for ram_states in self.rl_traj_dict_p2["ram_states"]: - ram_states.pop("P1") - ram_states["P1"] = ram_states.pop("P2") - ram_states["stage"] = 0 - - # Remove P2 info from P1 Observation - for ram_states in self.rl_traj_dict["ram_states"]: - ram_states.pop("P2") - ram_states["stage"] = 0 diff --git a/diambra/arena/engine/interface.py b/diambra/arena/engine/interface.py index d312be34..f5045503 100644 --- a/diambra/arena/engine/interface.py +++ b/diambra/arena/engine/interface.py @@ -32,117 +32,8 @@ def __init__(self, env_address, grpc_timeout=60): from ..utils.splash_screen import SplashScreen SplashScreen() - # Transforming env settings dict to pb request - def env_settings_to_pb_request(self, env_settings): - - frame_shape = { - "h": env_settings.frame_shape[0], - "w": env_settings.frame_shape[1], - "c": env_settings.frame_shape[2] - } - - if env_settings.player == "P1P2": - characters = { - "p1": [env_settings.characters[0][0], env_settings.characters[0][1], env_settings.characters[0][2]], - "p2": [env_settings.characters[1][0], env_settings.characters[1][1], env_settings.characters[1][2]] - } - outfits = { - "p1": env_settings.char_outfits[0], - "p2": env_settings.char_outfits[1] - } - action_spaces = { - "p1": model.ACTION_SPACE_DISCRETE if env_settings.action_space[0] == "discrete" else model.ACTION_SPACE_MULTI_DISCRETE, - "p2": model.ACTION_SPACE_DISCRETE if env_settings.action_space[1] == "discrete" else model.ACTION_SPACE_MULTI_DISCRETE, - } - attack_buttons_combinations = { - "p1": env_settings.attack_but_combination[0], - "p2": env_settings.attack_but_combination[1] - } - super_arts = { - "p1": env_settings.super_art[0], - "p2": env_settings.super_art[1] - } - fighting_styles = { - "p1": env_settings.fighting_style[0], - "p2": env_settings.fighting_style[1] - } - ultimate_styles = { - "p1": { - "dash": env_settings.ultimate_style[0][0], - "evade": env_settings.ultimate_style[0][1], - "bar": env_settings.ultimate_style[0][2] - }, - "p2": { - "dash": env_settings.ultimate_style[1][0], - "evade": env_settings.ultimate_style[1][1], - "bar": env_settings.ultimate_style[1][2] - } - } - else: - characters = { - "p1": [env_settings.characters[0], env_settings.characters[1], env_settings.characters[2]], - "p2": [env_settings.characters[0], env_settings.characters[1], env_settings.characters[2]] - } - outfits = { - "p1": env_settings.char_outfits, - "p2": env_settings.char_outfits - } - action_spaces = { - "p1": model.ACTION_SPACE_DISCRETE if env_settings.action_space == "discrete" else model.ACTION_SPACE_MULTI_DISCRETE, - "p2": model.ACTION_SPACE_DISCRETE if 
env_settings.action_space == "discrete" else model.ACTION_SPACE_MULTI_DISCRETE, - } - attack_buttons_combinations = { - "p1": env_settings.attack_but_combination, - "p2": env_settings.attack_but_combination - } - super_arts = { - "p1": env_settings.super_art, - "p2": env_settings.super_art - } - fighting_styles = { - "p1": env_settings.fighting_style, - "p2": env_settings.fighting_style - } - ultimate_styles = { - "p1": { - "dash": env_settings.ultimate_style[0], - "evade": env_settings.ultimate_style[1], - "bar": env_settings.ultimate_style[2] - }, - "p2": { - "dash": env_settings.ultimate_style[0], - "evade": env_settings.ultimate_style[1], - "bar": env_settings.ultimate_style[2] - } - } - - request = model.EnvSettings( - game_id=env_settings.game_id, - continue_game=env_settings.continue_game, - show_final=env_settings.show_final, - step_ratio=env_settings.step_ratio, - player=env_settings.player, - difficulty=env_settings.difficulty, - characters=characters, - outfits=outfits, - frame_shape=frame_shape, - action_spaces=action_spaces, - attack_buttons_combinations=attack_buttons_combinations, - hardcore=env_settings.hardcore, - disable_keyboard=env_settings.disable_keyboard, - disable_joystick=env_settings.disable_joystick, - rank=env_settings.rank, - random_seed=env_settings.seed, - super_arts=super_arts, - tower=env_settings.tower, - fighting_styles=fighting_styles, - ultimate_styles=ultimate_styles - ) - - return request - # Send env settings, retrieve env info and int variables list [pb low level] - def _env_init(self, env_settings_pb): + def env_init(self, env_settings_pb): try: response = self.client.EnvInit(env_settings_pb) except: @@ -155,112 +46,17 @@ def _env_init(self, env_settings_pb): return response - # Send env settings, retrieve env info and int variables list - def env_init(self, env_settings): - env_settings_pb = self.env_settings_to_pb_request(env_settings) - response = self._env_init(env_settings_pb) - - move_dict = {} - for idx in range(0, len(response.button_mapping.moves), 2): - move_dict[int(response.button_mapping.moves[idx])] = response.button_mapping.moves[idx + 1] - att_dict = {} - for idx in range(0, len(response.button_mapping.attacks), 2): - att_dict[int(response.button_mapping.attacks[idx])] = response.button_mapping.attacks[idx + 1] - - env_info_dict = { - "n_actions": [[response.available_actions.with_buttons_combination.moves, - response.available_actions.with_buttons_combination.attacks], - [response.available_actions.without_buttons_combination.moves, - response.available_actions.without_buttons_combination.attacks]], - "frame_shape": [response.frame_shape.h, response.frame_shape.w, response.frame_shape.c], - "delta_health": response.delta_health, - "max_stage": response.max_stage, - "cumulative_reward_bounds": [response.cumulative_reward_bounds.min, response.cumulative_reward_bounds.max], - "char_list": list(response.char_list), - "actions_list": [list(response.buttons.moves), list(response.buttons.attacks)], - "actions_dict": [move_dict, att_dict], - "ram_states": response.ram_states - } - - # Set frame size - self.height = env_info_dict["frame_shape"][0] - self.width = env_info_dict["frame_shape"][1] - self.n_chan = env_info_dict["frame_shape"][2] - self.frame_dim = self.height * self.width * self.n_chan - - return env_info_dict - - # Read data - def read_data(self, response): - - # Adding boolean flags - data = {"round_done": response.game_state.round_done, - "stage_done": response.game_state.stage_done, - "game_done": 
response.game_state.game_done, - "ep_done": response.game_state.episode_done, - "env_done": response.game_state.env_done} - - # Adding int variables - # Actions - data["moveActionP1"] = response.actions.p1.move - data["attackActionP1"] = response.actions.p1.attack - data["moveActionP2"] = response.actions.p2.move - data["attackActionP2"] = response.actions.p2.attack - - # Ram states - for k, v in response.ram_states.items(): - data[k] = v.val - - return data - - # Read frame - def read_frame(self, frame): - # return cv2.imdecode(np.frombuffer(frame, dtype='uint8'), - # cv2.IMREAD_COLOR) - return np.frombuffer(frame, dtype='uint8').reshape(self.height, - self.width, - self.n_chan) - # Reset the environment [pb low level] - def _reset(self): - return self.client.Reset(model.Empty()) + def reset(self, episode_settings): + return self.client.Reset(episode_settings) - # Reset the environment - def reset(self): - response = self._reset() - data = self.read_data(response) - frame = self.read_frame(response.frame) - return frame, data, response.player - - # Step the environment (1P) [pb low level] - def _step_1p(self, mov_p1, att_p1): + # Step the environment [pb low level] + def step(self, action_list): actions = model.Actions() - actions.p1.move = mov_p1 - actions.p1.attack = att_p1 - return self.client.Step1P(actions) - - # Step the environment (1P) - def step_1p(self, mov_p1, att_p1): - response = self._step_1p(mov_p1, att_p1) - data = self.read_data(response) - frame = self.read_frame(response.frame) - return frame, response.reward, data - - # Step the environment (2P) [pb low level] - def _step_2p(self, mov_p1, att_p1, mov_p2, att_p2): - actions = model.Actions() - actions.p1.move = mov_p1 - actions.p1.attack = att_p1 - actions.p2.move = mov_p2 - actions.p2.attack = att_p2 - return self.client.Step2P(actions) - - # Step the environment (2P) - def step_2p(self, mov_p1, att_p1, mov_p2, att_p2): - response = self._step_2p(mov_p1, att_p1, mov_p2, att_p2) - data = self.read_data(response) - frame = self.read_frame(response.frame) - return frame, response.reward, data + for action in action_list: + action = model.Actions.Action(move=action[0], attack=action[1]) + actions.actions.append(action) + return self.client.Step(actions) # Closing DIAMBRA Arena def close(self): diff --git a/diambra/arena/env_settings.py b/diambra/arena/env_settings.py index 251829fd..73903242 100644 --- a/diambra/arena/env_settings.py +++ b/diambra/arena/env_settings.py @@ -1,7 +1,15 @@ -from dataclasses import dataclass -from typing import Union, List, Tuple +from dataclasses import dataclass, field +from typing import Union, List, Tuple, Any, Dict from diambra.arena.utils.gym_utils import available_games +from diambra.arena import SpaceTypes, Roles import numpy as np +import random +from diambra.engine import model +import time +from dacite import from_dict, Config + +def load_settings_flat_dict(target_class, settings: dict): + return from_dict(target_class, settings, config=Config(strict=True)) MAX_VAL = float("inf") MIN_VAL = float("-inf") @@ -9,192 +17,420 @@ MAX_STACK_VALUE = 48 def check_num_in_range(key, value, bounds): - assert (value >= bounds[0] and value <= bounds[1]), "\"{}\" ({}) value must be in the range {}".format(key, value, bounds) + error_message = "ERROR: \"{}\" ({}) value must be in the range {}".format(key, value, bounds) + assert (value >= bounds[0] and value <= bounds[1]), error_message + assert (type(value)==type(bounds[0])), error_message def check_val_in_list(key, value, valid_list): - assert 
(value in valid_list), "\"{}\" ({}) admissible values are {}".format(key, value, valid_list) + error_message = "ERROR: \"{}\" ({}) admissible values are {}".format(key, value, valid_list) + assert (value in valid_list), error_message + assert (type(value)==type(valid_list[valid_list.index(value)])), error_message +def check_type(key, value, expected_type, admit_none=True): + error_message = "ERROR: \"{}\" ({}) is of type {}, admissible types are {}".format(key, value, type(value), expected_type) + if value is not None: + assert isinstance(value, expected_type), error_message + else: + assert admit_none==True, "ERROR: \"{}\" ({}) cannot be NoneType".format(key, value) -@dataclass -class EnvironmentSettings: +def check_space_type(key, value, valid_list): + error_message = "ERROR: \"{}\" ({}) admissible values are {}".format(key, SpaceTypes.Name(value), [SpaceTypes.Name(elem) for elem in valid_list]) + assert (value in valid_list), error_message - game_id: str +def check_roles(key, value, valid_list): + error_message = "ERROR: \"{}\" ({}) admissible values are {}".format(key, Roles.Name(value), [Roles.Name(elem) for elem in valid_list]) + assert (value in valid_list), error_message + +@dataclass +class EnvironmentSettingsBase: + """Generic Environment Settings Class""" + env_info = None + games_dict = None - # System level + # Environment settings + game_id: str = "doapp" + frame_shape: Tuple[int, int, int] = (0, 0, 0) step_ratio: int = 6 - disable_keyboard:bool = True + disable_keyboard: bool = True disable_joystick: bool = True + render_mode: Union[None, str] = None rank: int = 0 - seed: int = -1 - env_address: str = "localhost:50051" + env_address: str = None grpc_timeout: int = 600 - # Game level - player: str = "Random" + # Episode settings + seed: Union[None, str] = None + difficulty: Union[None, int] = None continue_game: float = 0.0 - show_final: bool = True - difficulty: int = 3 - frame_shape: Tuple[int, int, int] = (0, 0, 0) - - tower: int = 3 # UMK3 Specific - - # Environment level - hardcore: bool = False - - def sanity_check(self): + show_final: bool = False + tower: Union[None, int] = 3 # UMK3 Specific + + # Bookkeeping variables + _last_seed: Union[None, int] = None + pb_model: model = None + + episode_settings = ["seed", "difficulty", "continue_game", "show_final", "tower", "role", + "characters", "outfits", "super_art", "fighting_style", "ultimate_style"] + + # Transforming env settings dict to pb request + def get_pb_request(self, init=False): + frame_shape = { + "h": self.frame_shape[0], + "w": self.frame_shape[1], + "c": self.frame_shape[2] + } + + if self.seed == None: + self.seed = int(time.time()) + + if self._last_seed != self.seed: + random.seed(self.seed) + np.random.seed(self.seed) + self._last_seed = self.seed + + action_spaces = self._get_action_spaces() + + if init is False: + self._process_random_values() + + player_settings = self._get_player_specific_values() + + episode_settings = model.EnvSettings.EpisodeSettings( + random_seed=self.seed, + difficulty=self.difficulty, + continue_game=self.continue_game, + show_final=self.show_final, + tower=self.tower, + player_settings=player_settings, + ) + else: + episode_settings = model.EnvSettings.EpisodeSettings() + + request = model.EnvSettings( + game_id=self.game_id, + frame_shape=frame_shape, + step_ratio=self.step_ratio, + n_players=self.n_players, + disable_keyboard=self.disable_keyboard, + disable_joystick=self.disable_joystick, + rank=self.rank, + action_spaces=action_spaces, + 
episode_settings=episode_settings, + ) + + self.pb_model = request + + return request + + def finalize_init(self, env_info): + self.env_info = env_info self.games_dict = available_games(False) + # Create list of valid characters + self.valid_characters = [character for character in self.env_info.characters_info.char_list \ + if character not in self.env_info.characters_info.char_forbidden_list] + + def update_episode_settings(self, options: Dict[str, Any] = None): + for k, v in options.items(): + if k in self.episode_settings: + setattr(self, k, v) + + self._sanity_check() + + # Storing original attributes before sampling random ones + original_settings_values = {} + for k in self.episode_settings: + original_settings_values[k] = getattr(self, k) + + request = self.get_pb_request() + + # Restoring original attributes after random sampling + for k, v in original_settings_values.items(): + setattr(self, k, v) + + return request + + def _sample_characters(self, n_characters=3): + random.shuffle(self.valid_characters) + sampled_characters = [] + for _ in range(n_characters): + for character in self.valid_characters: + valid = True + for sampled_character in sampled_characters: + if sampled_character == character: + valid = False + elif character in self.env_info.characters_info.char_homonymy_map.keys() and \ + sampled_character in self.env_info.characters_info.char_homonymy_map.keys(): + if self.env_info.characters_info.char_homonymy_map[character] == sampled_character: + valid = False + if valid is True: + sampled_characters.append(character) + break + + return sampled_characters + + def _sanity_check(self): + if self.env_info is None or self.games_dict is None: + raise Exception("EnvironmentSettings class not correctly initialized") + # TODO: consider using typing.Literal to specify lists of admissible values: NOTE: It requires Python 3.8+ + check_val_in_list("game_id", self.game_id, list(self.games_dict.keys())) check_num_in_range("step_ratio", self.step_ratio, [1, 6]) - check_num_in_range("rank", self.rank, [0, MAX_VAL]) - check_num_in_range("seed", self.seed, [-1, MAX_VAL]) - check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600]) - - check_val_in_list("game_id", self.game_id, self.games_dict.keys()) - check_val_in_list("player", self.player, ["P1", "P2", "Random", "P1P2"]) - check_num_in_range("continue_game", self.continue_game, [MIN_VAL, 1.0]) - check_num_in_range("difficulty", self.difficulty, self.games_dict[self.game_id]["difficulty"][:2]) - check_num_in_range("frame_shape[0]", self.frame_shape[0], [0, MAX_FRAME_RES]) check_num_in_range("frame_shape[1]", self.frame_shape[1], [0, MAX_FRAME_RES]) if (min(self.frame_shape[0], self.frame_shape[1]) == 0 and max(self.frame_shape[0], self.frame_shape[1]) != 0): raise Exception("\"frame_shape[0] and frame_shape[1]\" must be either both different from or equal to 0") - check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1]) + if self.render_mode is not None: + check_val_in_list("render_mode", self.render_mode, ["human", "rgb_array"]) + check_num_in_range("rank", self.rank, [0, MAX_VAL]) + check_type("env_address", self.env_address, str) + check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600]) - check_num_in_range("tower", self.tower, [1, 4]) - -@dataclass -class EnvironmentSettings1P(EnvironmentSettings): - - # Player level - characters: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]] = ("Random", "Random", "Random") - char_outfits: int = 1 - action_space: str = "multi_discrete" - 
attack_but_combination: bool = True - - super_art: int = 0 # SFIII Specific - - fighting_style: int = 0 # KOF Specific - ultimate_style: Tuple[int, int, int] = (0, 0, 0) # KOF Specific + if self.seed is not None: + check_num_in_range("seed", self.seed, [-1, MAX_VAL]) + difficulty_admissible_values = list(range(self.env_info.difficulty_bounds.min, self.env_info.difficulty_bounds.max + 1)) + difficulty_admissible_values.append(None) + check_val_in_list("difficulty", self.difficulty, difficulty_admissible_values) + check_num_in_range("continue_game", self.continue_game, [MIN_VAL, 1.0]) + check_type("show_final", self.show_final, bool) + check_val_in_list("tower", self.tower, [None, 1, 2, 3, 4]) - def sanity_check(self): - super().sanity_check() + def _process_random_values(self): + if self.difficulty is None: + self.difficulty = random.choice(list(range(self.env_info.difficulty_bounds.min, self.env_info.difficulty_bounds.max + 1))) + if self.tower is None: + self.tower = random.choice(list(range(1, 5))) - # Check for characters - if isinstance(self.characters, str): - self.characters = (self.characters, "Random", "Random") +@dataclass +class EnvironmentSettings(EnvironmentSettingsBase): + """Single Agent Environment Settings Class""" + # Env settings + n_players: int = 1 + action_space: int = SpaceTypes.MULTI_DISCRETE + + # Episode settings + role: Union[None, int] = None + characters: Union[None, str, Tuple[str], Tuple[str, str], Tuple[str, str, str]] = None + outfits: int = 1 + super_art: Union[None, int] = None # SFIII Specific + fighting_style: Union[None, int] = None # KOF Specific + ultimate_style: Union[None, Tuple[int, int, int]] = None # KOF Specific + + def _sanity_check(self): + super()._sanity_check() + + # Env settings + check_num_in_range("n_players", self.n_players, [1, 1]) + check_space_type("action_space", self.action_space, [SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE]) + + # Episode settings + if self.role is not None: + check_roles("role", self.role, [Roles.P1, Roles.P2]) + if isinstance(self.characters, str) or self.characters is None: + self.characters = (self.characters, None, None) else: - for idx in range(len(self.characters), 3): - self.characters += ("Random", ) - - check_num_in_range("char_outfits", self.char_outfits, self.games_dict[self.game_id]["outfits"]) + for _ in range(len(self.characters), 3): + self.characters += (None, ) + char_list = list(self.env_info.characters_info.char_list) + char_list.append(None) for idx in range(3): - check_val_in_list("characters[{}]".format(idx), self.characters[idx], - np.append(self.games_dict[self.game_id]["char_list"], "Random")) - check_val_in_list("action_space", self.action_space, ["discrete", "multi_discrete"]) + check_val_in_list("characters[{}]".format(idx), self.characters[idx], char_list) + check_num_in_range("outfits", self.outfits, self.games_dict[self.game_id]["outfits"]) + check_val_in_list("super_art", self.super_art, [None, 1, 2, 3]) + check_val_in_list("fighting_style", self.fighting_style, [None, 1, 2, 3]) + if self.ultimate_style is not None: + for idx in range(3): + check_val_in_list("ultimate_style[{}]".format(idx), self.ultimate_style[idx], [1, 2]) - check_num_in_range("super_art", self.super_art, [0, 3]) + def _get_action_spaces(self): + return [self.action_space] - check_num_in_range("fighting_style", self.fighting_style, [0, 3]) + def _process_random_values(self): + super()._process_random_values() + + sampled_characters = self._sample_characters() + characters_tmp = [] for idx in range(3): - 
check_num_in_range("ultimate_style[{}]".format(idx), self.ultimate_style[idx], [0, 2]) + if self.characters[idx] is None: + characters_tmp.append(sampled_characters[idx]) + else: + characters_tmp.append(self.characters[idx]) + self.characters = tuple(characters_tmp) + + if self.role is None: + self.role = random.choice([Roles.P1, Roles.P2]) + if self.super_art is None: + self.super_art = random.choice(list(range(1, 4))) + if self.fighting_style is None: + self.fighting_style = random.choice(list(range(1, 4))) + if self.ultimate_style is None: + self.ultimate_style = tuple([random.choice(list(range(1, 3))) for _ in range(3)]) + + def _get_player_specific_values(self): + player_settings = model.EnvSettings.EpisodeSettings.PlayerSettings( + role=self.role, + characters=[self.characters[idx] for idx in range(self.env_info.characters_info.chars_to_select)], + outfits=self.outfits, + super_art=self.super_art, + fighting_style=self.fighting_style, + ultimate_style={"dash": self.ultimate_style[0], "evade": self.ultimate_style[1], "bar": self.ultimate_style[2]} + ) + + return [player_settings] @dataclass -class EnvironmentSettings2P(EnvironmentSettings): - - # Player level - characters: Union[Tuple[str, str], Tuple[Tuple[str], Tuple[str]], - Tuple[Tuple[str, str], Tuple[str, str]], - Tuple[Tuple[str, str, str], Tuple[str, str, str]]] =\ - (("Random", "Random", "Random"), ("Random", "Random", "Random")) - char_outfits: Tuple[int, int] = (1, 1) - action_space: Tuple[str, str] = ("multi_discrete", "multi_discrete") - attack_but_combination: Tuple[bool, bool] = (True, True) - - super_art: Tuple[int, int] = (0, 0) # SFIII Specific - - fighting_style: Tuple[int, int] = (0, 0) # KOF Specific - ultimate_style: Tuple[Tuple[int, int, int], Tuple[int, int, int]] = ((0, 0, 0), (0, 0, 0)) # KOF Specific - - def sanity_check(self): - super().sanity_check() - - # Check for characters - if isinstance(self.characters[0], str): - self.characters = ((self.characters[0], "Random", "Random"), - (self.characters[1], "Random", "Random")) +class EnvironmentSettingsMultiAgent(EnvironmentSettingsBase): + """Multi Agent Environment Settings Class""" + # Env Settings + n_players: int = 2 + action_space: Tuple[int, int] = (SpaceTypes.MULTI_DISCRETE, SpaceTypes.MULTI_DISCRETE) + + # Episode Settings + role: Union[Tuple[None, None], Tuple[int, int]] = (None, None) + characters: Union[Tuple[None, None], Tuple[str, None], Tuple[None, str], Tuple[str, str], + Tuple[Tuple[str], Tuple[str]], Tuple[Tuple[str, str], Tuple[str, str]], + Tuple[Tuple[str, str, str], Tuple[str, str, str]]] = (None, None) + outfits: Tuple[int, int] = (1, 1) + super_art: Union[Tuple[None, None], Tuple[int, int]] = (None, None) # SFIII Specific + fighting_style: Union[Tuple[None, None], Tuple[int, int]] = (None, None) # KOF Specific + ultimate_style: Union[Tuple[None, None], Tuple[Tuple[int, int, int], Tuple[int, int, int]]] = (None, None) # KOF Specific + + def _sanity_check(self): + super()._sanity_check() + + # Env Settings + check_num_in_range("n_players", self.n_players, [2, 2]) + for idx in range(2): + check_space_type("action_space[{}]".format(idx), self.action_space[idx], [SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE]) + + # Episode Settings + if isinstance(self.characters[0], str) or self.characters[0] is None: + self.characters = ((self.characters[0], None, None), (self.characters[1], None, None)) else: tmp_chars = [self.characters[0], self.characters[1]] - for idx in range(len(self.characters[0]), 3): + for _ in range(len(self.characters[0]), 3): 
for jdx in range(2): - tmp_chars[jdx] += ("Random", ) + tmp_chars[jdx] += (None, ) self.characters = tuple(tmp_chars) + char_list = list(self.env_info.characters_info.char_list) + char_list.append(None) + for idx in range(2): + if self.role[idx] is not None: + check_roles("role[{}]".format(idx), self.role[idx], [Roles.P1, Roles.P2]) + for jdx in range(3): + check_val_in_list("characters[{}][{}]".format(idx, jdx), self.characters[idx][jdx], char_list) + check_num_in_range("outfits[{}]".format(idx), self.outfits[idx], self.games_dict[self.game_id]["outfits"]) + check_val_in_list("super_art[{}]".format(idx), self.super_art[idx], [None, 1, 2, 3]) + check_val_in_list("fighting_style[{}]".format(idx), self.fighting_style[idx], [None, 1, 2, 3]) + if self.ultimate_style[idx] is not None: + for jdx in range(3): + check_val_in_list("ultimate_style[{}][{}]".format(idx, jdx), self.ultimate_style[idx][jdx], [1, 2]) + + def _process_random_values(self): + super()._process_random_values() + + characters_tmp = [[],[]] + + for idx, characters in enumerate(self.characters): + sampled_characters = self._sample_characters() + for jdx in range(3): + if characters[jdx] is None: + characters_tmp[idx].append(sampled_characters[jdx]) + else: + characters_tmp[idx].append(characters[jdx]) + + self.characters = (tuple(characters_tmp[0]), tuple(characters_tmp[1])) + + if self.role[0] is None: + if self.role[1] is None: + coin = random.choice([True, False]) + self.role = (Roles.P1, Roles.P2) if coin is True else (Roles.P2, Roles.P1) + else: + self.role = (Roles.P1 if self.role[1] == Roles.P2 else Roles.P2, self.role[1]) + else: + if self.role[1] is None: + self.role = (self.role[0], Roles.P1 if self.role[0] == Roles.P2 else Roles.P2) - for jdx in range(2): - check_num_in_range("char_outfits[{}]".format(jdx), self.char_outfits[jdx], - self.games_dict[self.game_id]["outfits"]) - for idx in range(3): - check_val_in_list("characters[{}][{}]".format(jdx, idx), self.characters[jdx][idx], - np.append(self.games_dict[self.game_id]["char_list"], "Random")) - check_val_in_list("action_space[{}]".format(jdx), self.action_space[jdx], ["discrete", "multi_discrete"]) + self.super_art = tuple([random.choice(list(range(1, 4))) if self.super_art[idx] is None else self.super_art[idx] for idx in range(2)]) + self.fighting_style = tuple([random.choice(list(range(1, 4))) if self.fighting_style[idx] is None else self.fighting_style[idx] for idx in range(2)]) + self.ultimate_style = tuple([[random.choice(list(range(1, 3))) for _ in range(3)] if self.ultimate_style[idx] is None else self.ultimate_style[idx] for idx in range(2)]) - check_num_in_range("super_art[{}]".format(jdx), self.super_art[jdx], [0, 3]) + def _get_action_spaces(self): + return [action_space for action_space in self.action_space] - check_num_in_range("fighting_style[{}]".format(jdx), self.fighting_style[jdx], [0, 3]) - for idx in range(3): - check_num_in_range("ultimate_style[{}][{}]".format(jdx, idx), self.ultimate_style[jdx][idx], [0, 2]) + def _get_player_specific_values(self): + players_env_settings = [] + + for idx in range(2): + player_settings = model.EnvSettings.EpisodeSettings.PlayerSettings( + role=self.role[idx], + characters=[self.characters[idx][jdx] for jdx in range(self.env_info.characters_info.chars_to_select)], + outfits=self.outfits[idx], + super_art=self.super_art[idx], + fighting_style=self.fighting_style[idx], + ultimate_style={"dash": self.ultimate_style[idx][0], "evade": self.ultimate_style[idx][1], "bar": self.ultimate_style[idx][2]} + ) + + 
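Similarly, a brief sketch of the two-agent variant; the character choice is illustrative and any None entry is sampled at episode start:

    # Illustrative sketch for the 2-player settings (values are examples).
    from diambra.arena import EnvironmentSettingsMultiAgent, SpaceTypes, Roles

    settings = EnvironmentSettingsMultiAgent(game_id="doapp")
    settings.action_space = (SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE)
    settings.role = (Roles.P1, Roles.P2)
    settings.characters = ("Kasumi", None)   # second agent's character sampled randomly
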
players_env_settings.append(player_settings) + + return players_env_settings @dataclass class WrappersSettings: - no_op_max: int = 0 - sticky_actions: int = 1 - clip_rewards: bool = False - reward_normalization: bool = False - reward_normalization_factor: float = 0.5 - frame_stack: int = 1 - actions_stack: int = 1 + repeat_action: int = 1 + normalize_reward: bool = False + normalization_factor: float = 0.5 + clip_reward: bool = False + no_attack_buttons_combinations: bool = False + frame_shape: Tuple[int, int, int] = (0, 0, 0) + stack_frames: int = 1 + dilation: int = 1 + add_last_action: bool = False + stack_actions: int = 1 scale: bool = False exclude_image_scaling: bool = False process_discrete_binary: bool = False - scale_mod: int = 0 - hwc_obs_resize: Tuple[int, int, int] = (84, 84, 0) - dilation: int = 1 + role_relative: bool = False flatten: bool = False - filter_keys: List[str] = None + filter_keys: List[str] = field(default_factory=list) + wrappers: List[List[Any]] = field(default_factory=list) def sanity_check(self): - - no_op_max: int = 0 - sticky_actions: int = 1 - reward_normalization_factor: float = 0.5 - frame_stack: int = 1 - actions_stack: int = 1 - scale_mod: int = 0 - hwc_obs_resize: Tuple[int, int, int] = (84, 84, 0) - dilation: int = 1 - check_num_in_range("no_op_max", self.no_op_max, [0, 12]) - check_num_in_range("sticky_actions", self.sticky_actions, [1, 12]) - check_num_in_range("frame_stack", self.frame_stack, [1, MAX_STACK_VALUE]) - check_num_in_range("actions_stack", self.actions_stack, [1, MAX_STACK_VALUE]) + check_num_in_range("repeat_action", self.repeat_action, [1, 12]) + check_type("normalize_reward", self.normalize_reward, bool, admit_none=False) + check_num_in_range("normalization_factor", self.normalization_factor, [0.0, 1000000]) + check_type("clip_reward", self.clip_reward, bool, admit_none=False) + check_type("no_attack_buttons_combinations", self.no_attack_buttons_combinations, bool, admit_none=False) + check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1, 3]) + check_num_in_range("frame_shape[0]", self.frame_shape[0], [0, MAX_FRAME_RES]) + check_num_in_range("frame_shape[1]", self.frame_shape[1], [0, MAX_FRAME_RES]) + if (min(self.frame_shape[0], self.frame_shape[1]) == 0 and + max(self.frame_shape[0], self.frame_shape[1]) != 0): + raise Exception("\"frame_shape[0] and frame_shape[1]\" must be both different from 0") + check_num_in_range("stack_frames", self.stack_frames, [1, MAX_STACK_VALUE]) check_num_in_range("dilation", self.dilation, [1, MAX_STACK_VALUE]) - check_num_in_range("scale_mod", self.scale_mod, [0, 0]) - - check_val_in_list("hwc_obs_resize[2]", self.hwc_obs_resize[2], [0, 1, 3]) - if self.hwc_obs_resize[2] != 0: - check_num_in_range("hwc_obs_resize[0]", self.hwc_obs_resize[0], [1, MAX_FRAME_RES]) - check_num_in_range("hwc_obs_resize[1]", self.hwc_obs_resize[1], [1, MAX_FRAME_RES]) - if (min(self.hwc_obs_resize[0], self.hwc_obs_resize[1]) == 0 and - max(self.hwc_obs_resize[0], self.hwc_obs_resize[1]) != 0): - raise Exception("\"hwc_obs_resize[0] and hwc_obs_resize[1]\" must be both different from 0") - + check_type("add_last_action", self.add_last_action, bool, admit_none=False) + stack_actions_bounds = [1, 1] + if self.add_last_action is True: + stack_actions_bounds = [1, MAX_STACK_VALUE] + check_num_in_range("stack_actions", self.stack_actions, stack_actions_bounds) + check_type("scale", self.scale, bool, admit_none=False) + check_type("exclude_image_scaling", self.exclude_image_scaling, bool, admit_none=False) + 
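As a quick illustration of the renamed wrapper options (values are arbitrary examples, not recommended defaults; old field names noted in comments):

    # Illustrative sketch of the new WrappersSettings fields.
    from diambra.arena import WrappersSettings

    wrappers_settings = WrappersSettings()
    wrappers_settings.normalize_reward = True      # formerly reward_normalization
    wrappers_settings.stack_frames = 4             # formerly frame_stack
    wrappers_settings.add_last_action = True
    wrappers_settings.stack_actions = 12           # > 1 allowed only when add_last_action is True
    wrappers_settings.frame_shape = (128, 128, 1)  # grayscale resize, formerly hwc_obs_resize
    wrappers_settings.scale = True
    wrappers_settings.role_relative = True
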
check_type("process_discrete_binary", self.process_discrete_binary, bool, admit_none=False) + check_type("role_relative", self.role_relative, bool, admit_none=False) + check_type("flatten", self.flatten, bool, admit_none=False) + check_type("filter_keys", self.filter_keys, list, admit_none=False) + check_type("wrappers", self.wrappers, list, admit_none=False) @dataclass class RecordingSettings: + dataset_path: Union[None, str] = None + username: Union[None, str] = None - file_path: str - username: str = "username" - ignore_p2: bool = False + def sanity_check(self): + check_type("dataset_path", self.dataset_path, str) + check_type("username", self.username, str) diff --git a/diambra/arena/make_env.py b/diambra/arena/make_env.py index 0e8f6e9d..d0b0778e 100644 --- a/diambra/arena/make_env.py +++ b/diambra/arena/make_env.py @@ -1,12 +1,14 @@ import os import logging -from dacite import from_dict -from .arena_gym import DiambraGymHardcore1P, DiambraGym1P, DiambraGymHardcore2P, DiambraGym2P -from .wrappers.arena_wrappers import env_wrapping -from .env_settings import EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings, RecordingSettings - -def make(game_id, env_settings={}, wrappers_settings={}, - traj_rec_settings={}, seed=None, rank=0, log_level=logging.INFO): +from diambra.arena.arena_gym import DiambraGym1P, DiambraGym2P +from diambra.arena.wrappers.arena_wrappers import env_wrapping +from diambra.arena import EnvironmentSettings, EnvironmentSettingsMultiAgent, WrappersSettings, RecordingSettings +from diambra.arena.wrappers.episode_recording import EpisodeRecorder +from typing import Union + +def make(game_id, env_settings: Union[EnvironmentSettings, EnvironmentSettingsMultiAgent]=EnvironmentSettings(), + wrappers_settings: WrappersSettings=WrappersSettings(), episode_recording_settings: RecordingSettings=RecordingSettings(), + render_mode: str=None, rank: int=0, log_level=logging.INFO): """ Create a wrapped environment. 
:param seed: (int) the initial seed for RNG @@ -17,8 +19,9 @@ def make(game_id, env_settings={}, wrappers_settings={}, logging.basicConfig(level=log_level) logger = logging.getLogger(__name__) - # Include game_id in env_settings - env_settings["game_id"] = game_id + # Include game_id and render_mode in env_settings + env_settings.game_id = game_id + env_settings.render_mode = render_mode # Check if DIAMBRA_ENVS var present env_addresses = os.getenv("DIAMBRA_ENVS", "").split() @@ -29,48 +32,27 @@ def make(game_id, env_settings={}, wrappers_settings={}, "# of env servers: {}".format(len(env_addresses)), "# rank of client: {} (0-based index)".format(rank)) else: # If not present, set default value - if "env_address" not in env_settings: + if env_settings.env_address is None: env_addresses = ["localhost:50051"] else: - env_addresses = [env_settings["env_address"]] - - env_settings["env_address"] = env_addresses[rank] - env_settings["rank"] = rank - if seed is not None: - env_settings["seed"] = seed + env_addresses = [env_settings.env_address] - # Checking settings and setting up default ones - if "player" in env_settings.keys() and env_settings["player"] == "P1P2": - env_settings = from_dict(EnvironmentSettings2P, env_settings) - else: - env_settings = from_dict(EnvironmentSettings1P, env_settings) - env_settings.sanity_check() + env_settings.env_address = env_addresses[rank] + env_settings.rank = rank # Make environment - if env_settings.player != "P1P2": # 1P Mode - if env_settings.hardcore is True: - env = DiambraGymHardcore1P(env_settings) - else: - env = DiambraGym1P(env_settings) + if env_settings.n_players == 1: # 1P Mode + env = DiambraGym1P(env_settings) else: # 2P Mode - if env_settings.hardcore is True: - env = DiambraGymHardcore2P(env_settings) - else: - env = DiambraGym2P(env_settings) + env = DiambraGym2P(env_settings) + + # Apply episode recorder wrapper + if episode_recording_settings.dataset_path is not None: + episode_recording_settings.sanity_check() + env = EpisodeRecorder(env, episode_recording_settings) # Apply environment wrappers - wrappers_settings = from_dict(WrappersSettings, wrappers_settings) wrappers_settings.sanity_check() - env = env_wrapping(env, wrappers_settings, hardcore=env_settings.hardcore) - - # Apply trajectories recorder wrappers - if len(traj_rec_settings) != 0: - traj_rec_settings = from_dict(RecordingSettings, traj_rec_settings) - if env_settings.hardcore is True: - from diambra.arena.wrappers.traj_rec_wrapper_hardcore import TrajectoryRecorder - else: - from diambra.arena.wrappers.traj_rec_wrapper import TrajectoryRecorder - - env = TrajectoryRecorder(env, traj_rec_settings) + env = env_wrapping(env, wrappers_settings) return env diff --git a/diambra/arena/ray_rllib/make_ray_env.py b/diambra/arena/ray_rllib/make_ray_env.py index cba4ada1..f77cfd3e 100644 --- a/diambra/arena/ray_rllib/make_ray_env.py +++ b/diambra/arena/ray_rllib/make_ray_env.py @@ -1,15 +1,14 @@ import os import diambra.arena +from diambra.arena import EnvironmentSettings, WrappersSettings import logging -import gym +import gymnasium as gym from ray.rllib.env.env_context import EnvContext from copy import deepcopy import pickle class DiambraArena(gym.Env): - def __init__(self, config: EnvContext): - self.logger = logging.getLogger(__name__) # If to load environment spaces from a file @@ -31,16 +30,15 @@ def __init__(self, config: EnvContext): self.env_spaces_file_name = config["env_spaces_file_name"] if self.load_spaces_from_file is False: - if "is_rollout" not in config.keys(): 
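Putting it together, a sketch of the updated top-level API with the settings dataclasses; the game and option values are illustrative, and an engine instance (e.g. launched via the DIAMBRA CLI) is still required to actually run it:

    # Illustrative sketch of the new make() signature.
    import diambra.arena
    from diambra.arena import EnvironmentSettings, WrappersSettings, RecordingSettings

    settings = EnvironmentSettings(game_id="doapp", difficulty=3)
    wrappers_settings = WrappersSettings(normalize_reward=True, stack_frames=4)
    recording_settings = RecordingSettings(dataset_path="/tmp/dataset", username="player")

    env = diambra.arena.make("doapp", settings, wrappers_settings, recording_settings, render_mode="human")
    observation, info = env.reset()
    observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.close()
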
message = "Environment initialized without a preprocessed config file." message += " Make sure to call \"preprocess_ray_config\" before initializing Ray RL Algorithms." raise Exception(message) self.game_id = config["game_id"] - self.settings = config["settings"] if "settings" in config.keys() else {} - self.wrappers_settings = config["wrappers_settings"] if "wrappers_settings" in config.keys() else {} - self.seed = config["seed"] if "seed" in config.keys() else 0 + self.settings = config["settings"] if "settings" in config.keys() else EnvironmentSettings() + self.wrappers_settings = config["wrappers_settings"] if "wrappers_settings" in config.keys() else WrappersSettings() + self.render_mode = config["render_mode"] if "render_mode" in config.keys() else None num_rollout_workers = config["num_workers"] num_eval_workers = config["evaluation_num_workers"] @@ -68,8 +66,7 @@ def __init__(self, config: EnvContext): self.logger.debug("Rank: {}".format(self.rank)) - self.env = diambra.arena.make(self.game_id, self.settings, self.wrappers_settings, - seed=self.seed + self.rank, rank=self.rank) + self.env = diambra.arena.make(self.game_id, self.settings, self.wrappers_settings, render_mode=self.render_mode, rank=self.rank) env_spaces_dict = {} env_spaces_dict["action_space"] = self.env.action_space @@ -81,7 +78,6 @@ def __init__(self, config: EnvContext): env_spaces_file = open(self.env_spaces_file_name, "wb") pickle.dump(env_spaces_dict, env_spaces_file) env_spaces_file.close() - else: print("Loading environment spaces from: {}".format(self.env_spaces_file_name)) self.logger.info("Loading environment spaces from: {}".format(self.env_spaces_file_name)) @@ -93,8 +89,11 @@ def __init__(self, config: EnvContext): self.action_space = env_spaces_dict["action_space"] self.observation_space = env_spaces_dict["observation_space"] - def reset(self): - return self.env.reset() + def reset(self, seed=None, options=None): + if self.load_spaces_from_file is True: + return self.observation_space.sample(), {} + else: + return self.env.reset(seed=seed, options=options) def step(self, action): return self.env.step(action) @@ -103,7 +102,6 @@ def render(self): return self.env.render() def preprocess_ray_config(config): - logger = logging.getLogger(__name__) num_envs_required = 0 diff --git a/diambra/arena/stable_baselines/make_sb_env.py b/diambra/arena/stable_baselines/make_sb_env.py index 87e84ec2..9b26b847 100644 --- a/diambra/arena/stable_baselines/make_sb_env.py +++ b/diambra/arena/stable_baselines/make_sb_env.py @@ -1,86 +1,93 @@ import os +import time import diambra.arena -from .wrappers.add_obs_wrap import AdditionalObsToChannel -from .wrappers.p2_wrap import SelfPlayVsRL, VsHum, IntegratedSelfPlay +from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings, RecordingSettings +import gym from stable_baselines import logger from stable_baselines.bench import Monitor -from stable_baselines.common.misc_util import set_global_seeds from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv +from stable_baselines.common import set_global_seeds - -def make_sb_env(seed: int, env_settings: dict, wrappers_settings: dict={}, - traj_rec_settings: dict={}, custom_wrappers: list=None, - key_to_add: list=None, p2_mode: str=None, p2_policy=None, - start_index: int=0, allow_early_resets: bool=True, - start_method: str=None, no_vec: bool=False, - use_subprocess: bool=False): +# Make Stable Baselines Env function +def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSettings(), 
+ wrappers_settings: WrappersSettings=WrappersSettings(), + episode_recording_settings: RecordingSettings=RecordingSettings(), + render_mode: str="rgb_array", seed: int=None, start_index: int=0, + allow_early_resets: bool=True, start_method: str=None, + no_vec: bool=False, use_subprocess: bool=False): """ Create a wrapped, monitored VecEnv. - :param seed: (int) initial seed for RNG - :param env_settings: (dict) parameters for DIAMBRA environment - :param wrappers_settings: (dict) parameters for environment - wraping function - :param traj_rec_settings: (dict) parameters for environment recording - wraping function - :param key_to_add: (list) ordered parameters for environment stable - baselines converter wraping function + :param game_id: (str) the game environment ID + :param env_settings: (dict) parameters for DIAMBRA Arena environment + :param wrappers_settings: (dict) parameters for environment wrapping function + :param episode_recording_settings: (dict) parameters for environment recording wrapping function :param start_index: (int) start rank index :param allow_early_resets: (bool) allows early reset of the environment - :param start_method: (str) method used to start the subprocesses. - See SubprocVecEnv doc for more information - :param use_subprocess: (bool) Whether to use `SubprocVecEnv` or - `DummyVecEnv` when - :param no_vec: (bool) Whether to avoid usage of Vectorized Env or not. - Default: False + :param start_method: (str) method used to start the subprocesses. See SubprocVecEnv doc for more information + :param use_subprocess: (bool) Whether to use `SubprocVecEnv` or `DummyVecEnv` + :param no_vec: (bool) Whether to avoid usage of Vectorized Env or not. Default: False :return: (VecEnv) The diambra environment """ env_addresses = os.getenv("DIAMBRA_ENVS", "").split() if len(env_addresses) == 0: - print("WARNING: running script without diambra CLI, this is a development option only.") - env_addresses = ["0.0.0.0:50051"] + raise Exception("ERROR: Running script without DIAMBRA CLI.") num_envs = len(env_addresses) - hardcore = False - if "hardcore" in env_settings: - hardcore = env_settings["hardcore"] - - def _make_sb_env(rank): - def _thunk(): - env = diambra.arena.make(env_settings["game_id"], env_settings, - wrappers_settings, traj_rec_settings, - seed=seed + rank, rank=rank) - if not hardcore: + # Seed management + if seed is None: + seed = int(time.time()) + env_settings.seed = seed - # Applying custom wrappers - if custom_wrappers is not None: - for wrap in custom_wrappers: - env = wrap(env) + # Add the conversion from gymnasium to gym + old_gym_wrapper = [OldGymWrapper, {}] + wrappers_settings.wrappers.insert(0, old_gym_wrapper) - env = AdditionalObsToChannel(env, key_to_add) - if p2_mode is not None: - if p2_mode == "integratedSelfPlay": - env = IntegratedSelfPlay(env) - elif p2_mode == "selfPlayVsRL": - env = SelfPlayVsRL(env, p2_policy) - elif p2_mode == "vsHum": - env = VsHum(env, p2_policy) + def _make_sb_env(rank): + def _init(): + env = diambra.arena.make(game_id, env_settings, wrappers_settings, + episode_recording_settings, render_mode, rank=rank) env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), allow_early_resets=allow_early_resets) return env - return _thunk - set_global_seeds(seed) + set_global_seeds(seed) + return _init # If not wanting vectorized envs if no_vec and num_envs == 1: - return _make_sb_env(0)(), num_envs + env = _make_sb_env(0)() + else: + # When using one environment, no need to start subprocesses + if 
num_envs == 1 or not use_subprocess: + env = DummyVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)]) + else: + env = SubprocVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)], + start_method=start_method) + + return env, num_envs + +class OldGymWrapper(gym.Wrapper): + def __init__(self, env): + """ + Convert gymnasium to gym<=0.21 environment + :param env: (Gymnasium Environment) the environment to wrap + :param env: (Gym<=0.21 Environment) the resulting environment + """ + gym.Wrapper.__init__(self, env) + if self.env_settings.action_space == SpaceTypes.MULTI_DISCRETE: + self.action_space = gym.spaces.MultiDiscrete(self.n_actions) + elif self.env_settings.action_space == SpaceTypes.DISCRETE: + self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1) + self.logger.debug("Using {} action space".format(SpaceTypes.Name(self.env_settings.action_space))) + - # When using one environment, no need to start subprocesses - if num_envs == 1 or not use_subprocess: - return DummyVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)]), num_envs + def reset(self, **kwargs): + obs, _ = self.env.reset(**kwargs) + return obs - return SubprocVecEnv([_make_sb_env(i + start_index) for i in range(num_envs)], - start_method=start_method), num_envs + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + return obs, reward, terminated or truncated, info diff --git a/diambra/arena/stable_baselines/sb_utils.py b/diambra/arena/stable_baselines/sb_utils.py index bf231831..cb10467c 100644 --- a/diambra/arena/stable_baselines/sb_utils.py +++ b/diambra/arena/stable_baselines/sb_utils.py @@ -1,72 +1,48 @@ from stable_baselines.common.callbacks import BaseCallback import cv2 -import os -import time -import json import numpy as np from pathlib import Path # Visualize Obs content +def show_obs(observation, ram_states_list, n_actions, n_actions_stack, char_list, viz=False, wait_key=1): - -def show_obs(observation, key_to_add, key_to_add_count, actions_stack, - n_actions, wait_key, viz, char_list, hardcore, idx_list): - - if not hardcore: - shp = observation.shape - for idx in idx_list: - add_par = observation[:, :, shp[2] - 1] - add_par = np.reshape(add_par, (-1)) - - counter = 0 + idx * int((shp[0] * shp[1]) / 2) - - print("Additional Par P{} =".format(idx + 1), add_par[counter]) - + shp = observation.shape + ram_states = observation[:, :, shp[2] - 1] + ram_states = np.reshape(ram_states, (-1)) + + counter = 0 + + print("RAM states =", ram_states[counter]) + + counter += 1 + if "action_move" in ram_states_list and "action_attack" in ram_states_list: + n_values = n_actions_stack * (n_actions[0] + n_actions[1]) + var = ram_states[counter:counter + n_values] + counter += n_values + move_actions = var[0:n_actions_stack * n_actions[0]] + attack_actions = var[n_actions_stack * n_actions[0]:n_values] + move_actions = np.reshape(move_actions, (n_actions_stack, -1)) + attack_actions = np.reshape(attack_actions, (n_actions_stack, -1)) + print("Move actions =\n", move_actions) + print("Attack actions =\n ", attack_actions) + ram_states_list = [element for element in ram_states_list if element != "action_move" and element != "action_attack"] + + for ram_state_key in ram_states_list: + if "own_char" in ram_state_key or "opp_char" in ram_state_key: + var = ram_states[counter:counter + len(char_list)] + counter += len(char_list) + print("{} = {}".format(ram_state_key, char_list[list(var).index(1.0)])) + else: + var = 
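For the Stable Baselines 1 path, a sketch of the new entry point; it now requires launching through the DIAMBRA CLI and transparently prepends the OldGymWrapper shown above, so the returned vectorized env speaks the old gym API (settings values are illustrative):

    # Illustrative sketch (Stable Baselines 1 workflow, requires the DIAMBRA CLI).
    from diambra.arena import EnvironmentSettings, WrappersSettings
    from diambra.arena.stable_baselines.make_sb_env import make_sb_env

    settings = EnvironmentSettings(game_id="doapp")
    wrappers_settings = WrappersSettings(flatten=True, scale=True)
    env, num_envs = make_sb_env("doapp", settings, wrappers_settings)
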
ram_states[counter:counter + 1] counter += 1 - - for idk in range(len(key_to_add)): - - var = add_par[counter:counter + key_to_add_count[idk][idx]]\ - if key_to_add_count[idk][idx] > 1 else add_par[counter] - counter += key_to_add_count[idk][idx] - - if "actions" in key_to_add[idk]: - move_actions = var[0:actions_stack * n_actions[idx][0]] - attack_actions = var[actions_stack * n_actions[idx][0]:actions_stack * (n_actions[idx][0] + n_actions[idx][1])] - move_actions = np.reshape(move_actions, (actions_stack, -1)) - attack_actions = np.reshape(attack_actions, (actions_stack, -1)) - print("Move actions P{} =\n".format(idx + 1), move_actions) - print("Attack actions P{} =\n ".format(idx + 1), attack_actions) - elif "ownChar" in key_to_add[idk] or "oppChar" in key_to_add[idk]: - print("{}P{} =".format(key_to_add[idk], idx + 1), char_list[list(var).index(1.0)]) - else: - print("{}P{} =".format(key_to_add[idk], idx + 1), var) - - if viz: - obs = np.array(observation[:, :, 0:shp[2] - 1]).astype(np.float32) - else: - if viz: - obs = np.array(observation).astype(np.float32) + print("{} = {}".format(ram_state_key, var)) if viz: + obs = np.array(observation[:, :, 0:shp[2] - 1]).astype(np.float32) for idx in range(obs.shape[2]): cv2.imshow("image" + str(idx), obs[:, :, idx]) - cv2.wait_key(wait_key) - -# Util to copy P2 additional OBS into P1 position -# on last (add info dedicated) channel - - -def p2_to_p1_add_obs_move(observation): - shp = observation.shape - start_idx = int((shp[0] * shp[1]) / 2) - observation = np.reshape(observation, (-1)) - num_add_par_p2 = int(observation[start_idx]) - add_par_p2 = observation[start_idx:start_idx + num_add_par_p2 + 1] - observation[0:num_add_par_p2 + 1] = add_par_p2 - observation = np.reshape(observation, (shp[0], -1)) - return observation + cv2.waitKey(wait_key) # Linear scheduler for RL agent parameters def linear_schedule(initial_value, final_value=0.0): @@ -115,98 +91,3 @@ def _on_step(self) -> bool: self.model.save(self.save_path_base / (self.filename + str(self.n_calls * self.num_envs))) return True - -# Update p2Brain model Callback - - -class UpdateRLPolicyWeights(BaseCallback): - def __init__(self, check_freq: int, num_envs: int, save_path: str, - prev_agents_sampling={"probability": 0.0, "list": []}, verbose=1): - super(UpdateRLPolicyWeights, self).__init__(verbose) - self.check_freq = int(check_freq / num_envs) - self.num_envs = num_envs - self.save_path = os.path.join(save_path, 'lastModel') - self.sampling_probability = prev_agents_sampling["probability"] - self.prev_agents_list = prev_agents_sampling["list"] - time_dep_seed = int((time.time() - int(time.time() - 0.5)) * 1000) - np.random.seed(time_dep_seed) - - def _on_step(self) -> bool: - if self.n_calls % self.check_freq == 0: - # Selects if using previous agent or the last saved one - if np.random.rand() < self.sampling_probability: - # Sample an old model from the list - if self.verbose > 0: - print("Using an older model") - - # Sample one of the older models - idx = int(np.random.rand() * len(self.prev_agents_list)) - weights_paths_sampled = self.prev_agents_list[idx] - - # Load new weights - self.training_env.env_method("update_p2_policy_weights", - weights_path=weights_paths_sampled) - else: - # Use the last saved model - if self.verbose > 0: - print("Using last saved model") - - if self.verbose > 0: - print("Saving latest model to {}".format(self.save_path)) - - # Save the agent - self.model.save(self.save_path) - - # Load new weights - 
self.training_env.env_method("update_p2_policy_weights", - weights_path=self.save_path) - - return True - -# Model CFG save - - -def model_cfg_save(model_path, name, n_actions, char_list, - settings, wrappers_settings, key_to_add, params): - data = {} - _, model_name = os.path.split(model_path) - data["agentModel"] = model_name + ".zip" - data["name"] = name - data["n_actions"] = n_actions - data["char_list"] = char_list - data["settings"] = settings - data["wrappers_settings"] = wrappers_settings - data["key_to_add"] = key_to_add - data["params"] = params - - with open(model_path + ".json", 'w') as outfile: - json.dump(data, outfile, indent=4) - - -def key_to_add_count_calc(key_to_add, n_actions, n_actions_stack, char_list): - - key_to_add_count = [] - - for key in key_to_add: - if "actions" in key: - key_to_add_count.append([n_actions_stack * (n_actions[0] + n_actions[1])]) - elif "Char" in key: - key_to_add_count.append([len(char_list)]) - else: - key_to_add_count.append([1]) - - return key_to_add_count - -# Abort training when run out of recorded trajectories for imitation learning - - -class ImitationLearningExhaustedExamples(BaseCallback): - """ - Callback for aborting training when run out of Imitation Learning examples - """ - def __init__(self): - super(ImitationLearningExhaustedExamples, self).__init__() - - def _on_step(self) -> bool: - - return np.any(self.env.get_attr("exhausted")) diff --git a/diambra/arena/stable_baselines/wrappers/__init__.py b/diambra/arena/stable_baselines/wrappers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/diambra/arena/stable_baselines/wrappers/add_obs_wrap.py b/diambra/arena/stable_baselines/wrappers/add_obs_wrap.py deleted file mode 100644 index 14c23733..00000000 --- a/diambra/arena/stable_baselines/wrappers/add_obs_wrap.py +++ /dev/null @@ -1,153 +0,0 @@ -import gym -from gym import spaces -import numpy as np - -# KeysToDict from KeysToAdd - - -def keys_to_dict_calc(key_to_add, observation_space, player_to_skip="P2"): - keys_to_dict = {} - for key in key_to_add: - elem_to_add = [] - # Loop among all spaces - for k in observation_space.spaces: - # Skip frame and consider only a single player - if k == "frame" or k == player_to_skip: - continue - if isinstance(observation_space[k], gym.spaces.dict.Dict): - for l in observation_space.spaces[k].spaces: - if isinstance(observation_space[k][l], gym.spaces.dict.Dict): - if key == l: - elem_to_add.append("Px") - elem_to_add.append(l) - keys_to_dict[key] = elem_to_add - else: - if key == l: - elem_to_add.append("Px") - elem_to_add.append(l) - keys_to_dict[key] = elem_to_add - else: - if key == k: - elem_to_add.append(k) - keys_to_dict[key] = elem_to_add - - return keys_to_dict - -# Positioning element on last frame channel - - -def add_keys(counter, key_to_add, keys_to_dict, obs, new_data, player_id): - - data_pos = counter - - for key in key_to_add: - tmp_list = keys_to_dict[key] - if tmp_list[0] == "Px": - val = obs["P{}".format(player_id+1)] - - for idx in range(len(tmp_list)-1): - - if tmp_list[idx+1] == "actions": - val = np.concatenate((val["actions"]["move"], val["actions"]["attack"])) - else: - val = val[tmp_list[idx+1]] - - if isinstance(val, (float, int)) or val.size == 1: - val = [val] - else: - val = [obs[tmp_list[0]]] - - for elem in val: - counter = counter + 1 - new_data[counter] = elem - - new_data[data_pos] = counter - data_pos - - return counter - -# Observation modification (adding one channel to store additional info) - - -def process_obs(obs, dtype, 
box_high_bound, player_side, key_to_add, - keys_to_dict, imitation_learning=False): - - # Adding a channel to the standard image, it will be in last position and - # it will store additional obs - shp = obs["frame"].shape - obs_new = np.zeros((shp[0], shp[1], shp[2]+1), dtype=dtype) - - # Storing standard image in the first channel leaving the last one for - # additional obs - obs_new[:, :, 0:shp[2]] = obs["frame"] - - # Adding new info to the additional channel, on a very - # long line and then reshaping into the obs dim - new_data = np.zeros((shp[0] * shp[1])) - - # Adding new info for 1P - counter = 0 - add_keys(counter, key_to_add, keys_to_dict, obs, new_data, player_id=0) - - # Adding new info for P2 in 2P games - if player_side == "P1P2" and not imitation_learning: - counter = int((shp[0] * shp[1]) / 2) - add_keys(counter, key_to_add, keys_to_dict, obs, new_data, player_id=1) - - new_data = np.reshape(new_data, (shp[0], -1)) - - new_data = new_data * box_high_bound - - obs_new[:, :, shp[2]] = new_data - - return obs_new - -# Convert additional obs to fifth observation channel for stable baselines - - -class AdditionalObsToChannel(gym.ObservationWrapper): - def __init__(self, env, key_to_add, imitation_learning=False): - """ - Add to observations additional info - :param env: (Gym Environment) the environment to wrap - :param key_to_add: (list) ordered parameters for additional Obs - """ - gym.ObservationWrapper.__init__(self, env) - shp = self.env.observation_space["frame"].shape - self.key_to_add = key_to_add - self.imitation_learning = imitation_learning - - self.box_high_bound = self.env.observation_space["frame"].high.max() - self.box_low_bound = self.env.observation_space["frame"].low.min() - assert (self.box_high_bound == 1.0 or self.box_high_bound == 255),\ - "Observation space max bound must be either 1.0 or 255 to use Additional Obs" - assert (self.box_low_bound == 0.0 or self.box_low_bound == -1.0),\ - "Observation space min bound must be either 0.0 or -1.0 to use Additional Obs" - - # Build key_to_add - Observation Space dict connectivity - self.keys_to_dict = keys_to_dict_calc(self.key_to_add, self.env.observation_space) - - self.old_obs_space = self.observation_space - self.observation_space = spaces.Box(low=self.box_low_bound, high=self.box_high_bound, - shape=(shp[0], shp[1], shp[2] + 1), - dtype=np.float32) - self.shp = self.observation_space.shape - - # Return key_to_add count - self.key_to_add_count = [] - for key in self.key_to_add: - p1Val = add_keys(0, [key], self.keys_to_dict, self.old_obs_space.sample(), - np.zeros((shp[0] * shp[1])), 0) - if self.env.player_side == "P1P2": - p2Val = add_keys(0, [key], self.keys_to_dict, self.old_obs_space.sample(), - np.zeros((shp[0] * shp[1])), 1) - self.key_to_add_count.append([p1Val, p2Val]) - else: - self.key_to_add_count.append([p1Val]) - - # Process observation - def observation(self, obs): - - return process_obs(obs, self.observation_space.dtype, - self.box_high_bound, self.env.player_side, - self.key_to_add, self.keys_to_dict, - self.imitation_learning) diff --git a/diambra/arena/stable_baselines/wrappers/p2_wrap.py b/diambra/arena/stable_baselines/wrappers/p2_wrap.py deleted file mode 100644 index 961cf1bb..00000000 --- a/diambra/arena/stable_baselines/wrappers/p2_wrap.py +++ /dev/null @@ -1,89 +0,0 @@ -from ..sb_utils import p2_to_p1_add_obs_move -import gym -import numpy as np - -# Gym Env wrapper for two players mode to be used in integrated Self Play - - -class IntegratedSelfPlay(gym.Wrapper): - def __init__(self, 
env): - - gym.Wrapper.__init__(self, env) - - # Modify action space - assert self.action_space["P1"] == self.action_space["P2"],\ - "P1 and P2 action spaces are supposed to be identical: {} {}"\ - .format(self.action_space["P1"], self.action_space["P2"]) - self.action_space = self.action_space["P1"] - -# Gym Env wrapper for two players mode with RL algo on P2 - - -class SelfPlayVsRL(gym.Wrapper): - def __init__(self, env, p2_policy): - - gym.Wrapper.__init__(self, env) - - # Modify action space - self.action_space = self.action_space["P1"] - - # P2 action logic - self.p2_policy = p2_policy - - # Save last Observation - def update_last_obs(self, obs): - self.lastObs = obs - - # Update p2_policy RL policy weights - def update_p2_policy_weights(self, weights_path): - self.p2_policy.update_weights(weights_path) - - # Step the environment - def step(self, action): - - # Observation modification and P2 actions selected by the model - self.lastObs[:, :, -1] = p2_to_p1_add_obs_move(self.lastObs[:, :, -1]) - p2_policy_actions, _ = self.p2_policy.act(self.lastObs) - - obs, reward, done, info = self.env.step(np.hstack((action, p2_policy_actions))) - self.update_last_obs(obs) - - return obs, reward, done, info - - # Reset the environment - def reset(self): - - obs = self.env.reset() - self.update_last_obs(obs) - - return obs - -# Gym Env wrapper for two players mode with HUM+Gamepad on P2 - - -class VsHum(gym.Wrapper): - def __init__(self, env, p2_policy): - - gym.Wrapper.__init__(self, env) - - # Modify action space - self.action_space = self.action_space["P1"] - - # P2 action logic - self.p2_policy = p2_policy - - # If p2 action logic is gamepad, add it to self.gamepads (for char selection) - # Check action space is prescribed as "multi_discrete" - self.p2_policy.initialize(self.env.actionList()) - if self.actionsSpace[1] != "multi_discrete": - raise Exception("Action Space for P2 must be \"multi_discrete\" when using gamePad") - if not self.attackButCombination[1]: - raise Exception("Use attack buttons combinations for P2 must be \"True\" when using gamePad") - - # Step the environment - def step(self, action): - - # P2 actions selected by the Gamepad - p2_policy_actions, _ = self.p2_policy.act() - - return self.env.step(np.hstack((action, p2_policy_actions))) diff --git a/diambra/arena/stable_baselines/wrappers/tektag_rew_wrap.py b/diambra/arena/stable_baselines/wrappers/tektag_rew_wrap.py deleted file mode 100644 index 58243dea..00000000 --- a/diambra/arena/stable_baselines/wrappers/tektag_rew_wrap.py +++ /dev/null @@ -1,97 +0,0 @@ -import gym - -# Gym Env wrapper to penalize for char 2 health at round end - - -class TektagRoundEndChar2Penalty(gym.Wrapper): - def __init__(self, env): - - gym.Wrapper.__init__(self, env) - - # Check ownHealth2 is available - assert (("Health1P1" in self.add_obs.keys()) and - ("Health1P2" in self.add_obs.keys()) and - ("Health2P1" in self.add_obs.keys()) and - ("Health2P2" in self.add_obs.keys())),\ - "Both first and second char healths, for both P1 and P2, must be present in add_obs" +\ - " to use tektagRoundEndChar2Penalty wrapper {}".format(self.add_obs.keys()) - - # Check single player mode is on - assert (isinstance(self.action_space, gym.spaces.MultiDiscrete) or - isinstance(self.action_space, gym.spaces.Discrete)),\ - "Only single player environment are supported by" +\ - " tektagRoundEndChar2Penalty wrapper, {}".format(type(self.action_space)) - - print("Applying Background Char Health Penalty at Round End Wrapper") - - # Step the environment - def 
step(self, action): - - obs, reward, done, info = self.env.step(action) - - # When round ends - if info["round_done"] is True: - # If round lost - if reward < 0.0: - # Add penalty for background character health bar - # print("Applying end round penalty: original reward = {},"\ - # " reward with penalty = {}".format(round(reward, 2), round(- 2.0*self.oldHealths, 2))) - reward = - 2.0*self.oldHealths - - self.oldHealths = obs["P1"]["ownHealth1"] + obs["P1"]["ownHealth2"] - - return obs, reward, done, info - - # Reset the environment - def reset(self): - - obs = self.env.reset() - - # Variable to store previous step healths - self.oldHealths = obs["P1"]["ownHealth1"] + obs["P1"]["ownHealth2"] - - return obs - -# Gym Env wrapper to penalize when background char has health bar a lot higher -# than foreground char - - -class TektagHealthBarUnbalancePenalty(gym.Wrapper): - def __init__(self, env, unbalance_thresh=0.75): - - gym.Wrapper.__init__(self, env) - - # Check ownHealth2 is available - assert (("Health1P1" in self.add_obs.keys()) and - ("Health1P2" in self.add_obs.keys()) and - ("Health2P1" in self.add_obs.keys()) and - ("Health2P2" in self.add_obs.keys())),\ - "Both first and second char healths, for both P1 and P2, must be present in add_obs" +\ - " to use tektagRoundEndChar2Penalty wrapper {}".format(self.add_obs.keys()) - - # Check single player mode is on - assert (isinstance(self.action_space, gym.spaces.MultiDiscrete) or - isinstance(self.action_space, gym.spaces.Discrete)),\ - "Only single player environment are supported by" +\ - " tektagRoundEndChar2Penalty wrapper, {}".format(type(self.action_space)) - - print("Applying Char Health Unbalance Penalty Wrapper") - - self.unbalance_thresh = unbalance_thresh - self.penalty = 0.1*self.unbalance_thresh - self.charManagement = [["ownHealth1", "ownHealth2"], ["ownHealth2", "ownHealth1"]] - - # Step the environment - def step(self, action): - - obs, reward, done, info = self.env.step(action) - - # If background char health minus foreground one is - # higher than threshold - keys = self.charManagement[obs["P1"]["ownActiveChar"]] - if ((obs["P1"][keys[1]] - obs["P1"][keys[0]]) > (self.unbalance_thresh / 2.0)): - # Add penalty for background character health bar - # print("Applying Health unbalance penalty: penalty = {}".format(-round(self.penalty, 2))) - reward = -self.penalty - - return obs, reward, done, info diff --git a/diambra/arena/stable_baselines3/make_sb3_env.py b/diambra/arena/stable_baselines3/make_sb3_env.py index de07791d..b554cd40 100644 --- a/diambra/arena/stable_baselines3/make_sb3_env.py +++ b/diambra/arena/stable_baselines3/make_sb3_env.py @@ -1,53 +1,57 @@ import os -import sys +import time import diambra.arena +from diambra.arena import EnvironmentSettings, WrappersSettings, RecordingSettings +from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv from stable_baselines3.common.utils import set_random_seed -from stable_baselines3.common.monitor import Monitor -# Make Stable Baselines Env function -def make_sb3_env(game_id: str, env_settings: dict={}, wrappers_settings: dict={}, - use_subprocess: bool=True, seed: int=0, log_dir_base: str="/tmp/DIAMBRALog/", - start_index: int=0, allow_early_resets: bool=True, - start_method: str=None, no_vec: bool=False): +# Make Stable Baselines3 Env function +def make_sb3_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSettings(), + wrappers_settings: WrappersSettings=WrappersSettings(), + 
episode_recording_settings: RecordingSettings=RecordingSettings(), + render_mode: str="rgb_array", seed: int=None, start_index: int=0, + allow_early_resets: bool=True, start_method: str=None, no_vec: bool=False, + use_subprocess: bool=True, log_dir_base: str="/tmp/DIAMBRALog/"): """ Create a wrapped, monitored VecEnv. :param game_id: (str) the game environment ID - :param env_settings: (dict) parameters for DIAMBRA Arena environment - :param wrappers_settings: (dict) parameters for environment - wraping function + :param env_settings: (EnvironmentSettings) parameters for DIAMBRA Arena environment + :param wrappers_settings: (WrappersSettings) parameters for environment wrapping function + :param episode_recording_settings: (RecordingSettings) parameters for environment recording wrapping function :param start_index: (int) start rank index :param allow_early_resets: (bool) allows early reset of the environment - :param start_method: (str) method used to start the subprocesses. - See SubprocVecEnv doc for more information - :param use_subprocess: (bool) Whether to use `SubprocVecEnv` or - `DummyVecEnv` when - :param no_vec: (bool) Whether to avoid usage of Vectorized Env or not. - Default: False - :param seed: (int) initial seed for RNG + :param start_method: (str) method used to start the subprocesses. See SubprocVecEnv doc for more information + :param use_subprocess: (bool) Whether to use `SubprocVecEnv` or `DummyVecEnv` + :param no_vec: (bool) Whether to avoid usage of Vectorized Env or not. Default: False :return: (VecEnv) The diambra environment """ env_addresses = os.getenv("DIAMBRA_ENVS", "").split() if len(env_addresses) == 0: raise Exception("ERROR: Running script without DIAMBRA CLI.") - sys.exit(1) num_envs = len(env_addresses) + # Seed management + if seed is None: + seed = int(time.time()) + env_settings.seed = seed + def _make_sb3_env(rank): def _init(): env = diambra.arena.make(game_id, env_settings, wrappers_settings, - seed=seed + rank, rank=rank) + episode_recording_settings, render_mode, rank=rank) + env.reset(seed=seed + rank) # Create log dir log_dir = os.path.join(log_dir_base, str(rank)) os.makedirs(log_dir, exist_ok=True) env = Monitor(env, log_dir, allow_early_resets=allow_early_resets) return env + set_random_seed(seed) return _init - set_random_seed(seed) # If not wanting vectorized envs if no_vec and num_envs == 1: diff --git a/diambra/arena/stable_baselines3/sb3_utils.py b/diambra/arena/stable_baselines3/sb3_utils.py index 931b5fe7..fc1a0d2b 100644 --- a/diambra/arena/stable_baselines3/sb3_utils.py +++ b/diambra/arena/stable_baselines3/sb3_utils.py @@ -1,9 +1,4 @@ from stable_baselines3.common.callbacks import BaseCallback -import cv2 -import os -import time -import json -import numpy as np from pathlib import Path # Linear scheduler for RL agent parameters diff --git a/diambra/arena/utils/controller.py b/diambra/arena/utils/controller.py index 5e4da4bd..d0d63a65 100644 --- a/diambra/arena/utils/controller.py +++ b/diambra/arena/utils/controller.py @@ -5,7 +5,10 @@ from inputs import devices import pickle from os.path import expanduser +import logging + home_dir = expanduser("~") +CONFIG_FILE_PATH = os.path.join(home_dir, '.diambra/config/deviceConfig.cfg') # Create devices list def create_devices_list(): @@ -30,11 +33,8 @@ def create_devices_list(): # Function to retrieve available devices and select one def available_devices(): - devices_list = create_devices_list() - print("Available devices:\n") - for idx, device_dict in devices_list.items(): print("{} - {} 
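And the Stable Baselines 3 counterpart, sketched under the assumption that the function keeps returning the (env, num_envs) pair like its SB1 sibling; the PPO lines are ordinary SB3 usage, not part of this patch:

    # Illustrative sketch (requires launching through the DIAMBRA CLI).
    from diambra.arena import EnvironmentSettings, WrappersSettings
    from diambra.arena.stable_baselines3.make_sb3_env import make_sb3_env
    from stable_baselines3 import PPO

    settings = EnvironmentSettings(game_id="doapp")
    wrappers_settings = WrappersSettings(normalize_reward=True, stack_frames=4, flatten=True)
    env, num_envs = make_sb3_env("doapp", settings, wrappers_settings)
    agent = PPO("MultiInputPolicy", env, verbose=1)   # Dict observations -> MultiInputPolicy
    agent.learn(total_timesteps=10_000)
    env.close()
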
({}) [{}]".format(idx, device_dict["device"].name, device_dict["type"], device_dict["id"])) @@ -57,7 +57,6 @@ def get_diambra_controller(action_list, cfg=["But1", "But2", "But3", "But4", "Bu return DiambraKeyboard(device_dict["device"], action_list, cfg, force_configure, skip_configure) else: return DiambraGamepad(device_dict["device"], action_list, cfg, force_configure, skip_configure) - except: raise Exception("Unable to initialize device, have you unplugged it during execution?") else: @@ -68,9 +67,11 @@ class DiambraDevice(Thread): # def class type thread def __init__(self, device, action_list=(("NoMove", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight", "Down", "DownLeft"), ("But0", "But1", "But2", "But3", "But4", "But5", "But6", "But7", "But8")), cfg=["But1", "But2", "But3", "But4", "But5", "But6", "But7", "But8"], - force_configure=False, skip_configure=False): + force_configure=False, skip_configure=False, logging_level=logging.INFO): # thread init class (don't forget this) Thread.__init__(self, daemon=True) + self.logger = logging.getLogger(__name__) + self.logger.basicConfig(logging_level) self.stop_event = Event() @@ -80,7 +81,7 @@ def __init__(self, device, action_list=(("NoMove", "Left", "UpLeft", "Up", "UpRi self.select_but = 0 self.event_hash_move = np.zeros(4) self.event_hash_attack = np.zeros(8) - self.device_config_file_name = os.path.join(home_dir, '.diambra/config/deviceConfig.cfg') + self.device_config_file_path = CONFIG_FILE_PATH self.all_actions_list = (("NoMove", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight", "Down", "DownLeft"), @@ -119,7 +120,7 @@ def __init__(self, device, action_list=(("NoMove", "Left", "UpLeft", "Up", "UpRi ans = input("Want to reconfigure the device? (y/n): ") if ans == "y": - print("Restarting device configuration") + self.logger.info("Restarting device configuration") # Run device configuration self.configure() else: @@ -129,12 +130,11 @@ def __init__(self, device, action_list=(("NoMove", "Left", "UpLeft", "Up", "UpRi self.save_device_configuration() - print("Diambra device initialized on device {} [{}]".format(self.device.name, self.device_id)) + self.logger.info("Diambra device initialized on device {} [{}]".format(self.device.name, self.device_id)) # Show device events def show_device_events(self, event_codes_to_skip=[], event_code_to_show=None): - - print("Use device to see events") + self.logger.info("Use device to see events") while True: for event in self.device.read(): if event.ev_type != "Sync" and event.ev_type != "Misc": @@ -142,7 +142,7 @@ def show_device_events(self, event_codes_to_skip=[], event_code_to_show=None): if event_code_to_show is not None: if event.code != event_code_to_show: continue - print("Event type: {}, event code: {}, event state: {}".format(event.ev_type, event.code, event.state)) + self.logger.info("Event type: {}, event code: {}, event state: {}".format(event.ev_type, event.code, event.state)) # Prepare device config dict to be saved def process_device_dict_for_save(self): @@ -151,7 +151,7 @@ def process_device_dict_for_save(self): # Save device configuration def save_device_configuration(self): - print("Saving configuration in {}".format(self.device_config_file_name)) + self.logger.info("Saving configuration in {}".format(self.device_config_file_path)) # Convert device config dictionary cfg_dict_to_save = self.process_device_dict_for_save() @@ -169,27 +169,18 @@ def save_device_configuration(self): cfg_file_dict_list.append(cfg_dict_to_save) # Open file (new or overwrite previous one) - 
cfg_file = open(self.device_config_file_name, "wb") + cfg_file = open(self.device_config_file_path, "wb") pickle.dump(cfg_file_dict_list, cfg_file) cfg_file.close() # Load all devices configuration def load_all_device_configurations(self): - - try: - cfg_file = open(self.device_config_file_name, "rb") - - # Load Pickle Dict - cfg_file_dict_list = pickle.load(cfg_file) - - cfg_file.close() - - except OSError: - print("No device configuration file found in: {}".format(os.path.join(home_dir, '.diambra/config/'))) - config_file_folder = os.path.join(home_dir, '.diambra/') - os.makedirs(config_file_folder, exist_ok=True) - config_file_folder = os.path.join(home_dir, '.diambra/config/') - os.makedirs(config_file_folder, exist_ok=True) + if os.path.exists(self.device_config_file_path): + with open(self.device_config_file_path, "rb") as cfg_file: + cfg_file_dict_list = pickle.load(cfg_file) + else: + self.logger.info("No device configuration file found in: {}".format(os.path.dirname(self.device_config_file_path))) + os.makedirs(os.path.dirname(self.device_config_file_path), exist_ok=True) cfg_file_dict_list = [] self.cfg_file_dict_list = cfg_file_dict_list @@ -204,8 +195,7 @@ def configure(self): # Configuration Test def config_test(self): - - print("Press Start to end configuration test") + self.logger.info("Press Start to end configuration test") # Execute run function in thread mode thread = Thread(target=self.run, args=()) @@ -213,19 +203,17 @@ def config_test(self): thread.start() while True: - actions = self.get_all_actions() if actions[0] != 0: - print("Move action = {}. (Press START to end configuration test).".format(self.all_actions_list[0][actions[0]])) + self.logger.info("Move action = {}. (Press START to end configuration test).".format(self.all_actions_list[0][actions[0]])) if actions[1] != 0: - print("Attack action = {}. (Press START to end configuration test).".format(self.all_actions_list[1][actions[1]])) + self.logger.info("Attack action = {}. 
(Press START to end configuration test).".format(self.all_actions_list[1][actions[1]])) if actions[2] != 0: break # Creating hash dictionary def compose_hash_dict(self, dictionary, hash_elem): - key_val_list = [] key_val = "" @@ -391,40 +379,40 @@ def load_device_configuration(self): self.device_dict["Arrow"][item[0]] = item[1] config_found = True - print("Device configuration file found in: {}".format(os.path.join(home_dir, '.diambra/config/'))) - print("Device configuration file loaded.") + self.logger.info("Device configuration file found in: {}".format(os.path.dirname(self.device_config_file_path))) + self.logger.info("Device configuration file loaded.") except: - print("Invalid device configuration file found in: {}".format(os.path.join(home_dir, '.diambra/config/'))) + self.logger.info("Invalid device configuration file found in: {}".format(os.path.dirname(self.device_config_file_path))) if not config_found: - print("Configuration for this device not present in device configuration file") + self.logger.info("Configuration for this device not present in device configuration file") return config_found # Configure device buttons def configure(self): - print("") - print("") - print("Configuring device {}".format(self.device)) - print("") - print("# Buttons CFG file") - print(" _______ ") - print(" B7 __|digital|__ B8 ") - print(" B5 |buttons| B6 ") - print(" / \ ") - print(" SELECT START ") - print(" UP B1 ") - print(" | ") - print(" LEFT-- --RIGHT B4 B2") - print(" | ") - print(" DOWN B3 ") - print(" __/____ ") - print(" |digital| ") - print(" | move | ") - print(" ------- ") - print("") - print("") + self.logger.info("") + self.logger.info("") + self.logger.info("Configuring device {}".format(self.device)) + self.logger.info("") + self.logger.info("# Buttons CFG file") + self.logger.info(" _______ ") + self.logger.info(" B7 __|digital|__ B8 ") + self.logger.info(" B5 |buttons| B6 ") + self.logger.info(" / \ ") + self.logger.info(" SELECT START ") + self.logger.info(" UP B1 ") + self.logger.info(" | ") + self.logger.info(" LEFT-- --RIGHT B4 B2") + self.logger.info(" | ") + self.logger.info(" DOWN B3 ") + self.logger.info(" __/____ ") + self.logger.info(" |digital| ") + self.logger.info(" | move | ") + self.logger.info(" ------- ") + self.logger.info("") + self.logger.info("") self.code_to_group_map = defaultdict(lambda: "") self.device_dict = {} @@ -433,8 +421,8 @@ def configure(self): # Buttons configuration # Start and Select - print("Return/Enter key is not-allowed and would cause the program to stop.") - print("Press START button") + self.logger.info("Return/Enter key is not-allowed and would cause the program to stop.") + self.logger.info("Press START button") but_not_set = True start_set = False while but_not_set: @@ -445,13 +433,13 @@ def configure(self): raise Exception("Return/Enter not-allowed, aborting.") self.start_code = event.code start_set = True - print("Start associated with {}".format(event.code)) + self.logger.info("Start associated with {}".format(event.code)) else: if start_set == True: but_not_set = False break - print("Press SELECT button (Start to skip)") + self.logger.info("Press SELECT button (Start to skip)") but_not_set = True while but_not_set: for event in self.device.read(): @@ -460,11 +448,11 @@ def configure(self): if (event.code == "KEY_ENTER"): raise Exception("Return/Enter not-allowed, aborting.") self.select_code = event.code - print("Select associated with {}".format(event.code)) + self.logger.info("Select associated with {}".format(event.code)) 
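
# --- Usage sketch (illustrative, not part of the patch) ---
# Reading actions from a physical device with get_diambra_controller, shown earlier in this
# file. The sketch assumes the returned object is used as in config_test above: it is a Thread
# (so start() applies) and get_all_actions() yields [move_index, attack_index, start_flag].
from diambra.arena.utils.controller import get_diambra_controller

actions_list = (("NoMove", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight", "Down", "DownLeft"),
                ("But0", "But1", "But2", "But3", "But4", "But5", "But6", "But7", "But8"))
controller = get_diambra_controller(actions_list)
controller.start()
while True:
    actions = controller.get_all_actions()
    move_index, attack_index = actions[0], actions[1]
    if actions[2] != 0:  # Start button ends the loop, as in config_test
        break
# --- end of usage sketch ---
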
else: but_not_set = False if event.code == self.start_code: - print("Select association skipped") + self.logger.info("Select association skipped") break # Attack buttons @@ -474,7 +462,7 @@ def configure(self): if end_flag: break - print("Press B{} button (SELECT / START to end configuration)".format(idx+1)) + self.logger.info("Press B{} button (SELECT / START to end configuration)".format(idx+1)) but_not_set = True @@ -487,20 +475,20 @@ def configure(self): if event.ev_type == "Key": if (event.code != self.start_code and event.code != self.select_code): if event.state > 0: - print("Button B{}, event code = {}".format(idx+1, event.code)) + self.logger.info("Button B{}, event code = {}".format(idx+1, event.code)) self.device_dict["Key"][event.code] = idx self.code_to_group_map[event.code] = "Key" elif event.state == 0: but_not_set = False else: if event.state == 0: - print("Remaining buttons configuration skipped") + self.logger.info("Remaining buttons configuration skipped") end_flag = True break # Moves end_flag = False - print("Configuring moves") + self.logger.info("Configuring moves") moves_list = ["UP", "RIGHT", "DOWN", "LEFT"] for idx, move in enumerate(moves_list): @@ -508,7 +496,7 @@ def configure(self): if end_flag: break - print("Press {} arrow (SELECT / START to skip)".format(move)) + self.logger.info("Press {} arrow (SELECT / START to skip)".format(move)) but_not_set = True @@ -521,20 +509,20 @@ def configure(self): if event.ev_type == "Key": if (event.code != self.start_code and event.code != self.select_code): if event.state > 0: - print("Move {}, event code = {}".format(move, event.code)) + self.logger.info("Move {}, event code = {}".format(move, event.code)) self.device_dict["Arrow"][event.code] = idx self.code_to_group_map[event.code] = "Arrow" elif event.state == 0: but_not_set = False else: if event.state == 0: - print("Remaining buttons configuration skipped") + self.logger.info("Remaining buttons configuration skipped") end_flag = True break - print("Device dict : ") - print("Buttons (Keys) dict : ", self.device_dict["Key"]) - print("Arrows (Keys) dict : ", self.device_dict["Arrow"]) + self.logger.info("Device dict : ") + self.logger.info("Buttons (Keys) dict : {}".format(self.device_dict["Key"])) + self.logger.info("Arrows (Keys) dict : {}".format(self.device_dict["Arrow"])) input("Configuration completed, press Enter to continue.") @@ -661,40 +649,40 @@ def load_device_configuration(self): item[1], item[2], item[3], item[4]] config_found = True - print("Device configuration file found in: {}".format(os.path.join(home_dir, '.diambra/config/'))) - print("Device configuration file loaded.") + self.logger.info("Device configuration file found in: {}".format(os.path.dirname(self.device_config_file_path))) + self.logger.info("Device configuration file loaded.") except: - print("Invalid device configuration file found in: {}".format(os.path.join(home_dir, '.diambra/config/'))) + self.logger.info("Invalid device configuration file found in: {}".format(os.path.dirname(self.device_config_file_path))) if not config_found: - print("Configuration for this device not present in device configuration file") + self.logger.info("Configuration for this device not present in device configuration file") return config_found # Configure device buttons def configure(self): - print("") - print("") - print("Configuring device {}".format(self.device)) - print("") - print("# Buttons CFG file") - print(" _______ ") - print(" B7 __|digital|__ B8 ") - print(" B5 |buttons| B6 ") - print(" / \ ") - print(" B1 ") -
print(" | SELECT START ") - print(" -- -- B4 B2") - print(" | - ") - print(" __/____ ( + ) B3 ") - print("|digital| - ") - print("| move | \______ ") - print(" ------- |analog| ") - print(" | move | ") - print(" ------ ") - print("") - print("NB: Be sure to have your analog switch on before starting.") - print("") + self.logger.info("") + self.logger.info("") + self.logger.info("Configuring device {}".format(self.device)) + self.logger.info("") + self.logger.info("# Buttons CFG file") + self.logger.info(" _______ ") + self.logger.info(" B7 __|digital|__ B8 ") + self.logger.info(" B5 |buttons| B6 ") + self.logger.info(" / \ ") + self.logger.info(" B1 ") + self.logger.info(" | SELECT START ") + self.logger.info(" -- -- B4 B2") + self.logger.info(" | - ") + self.logger.info(" __/____ ( + ) B3 ") + self.logger.info("|digital| - ") + self.logger.info("| move | \______ ") + self.logger.info(" ------- |analog| ") + self.logger.info(" | move | ") + self.logger.info(" ------ ") + self.logger.info("") + self.logger.info("NB: Be sure to have your analog switch on before starting.") + self.logger.info("") self.device_dict = {} self.device_dict["Key"] = defaultdict(lambda: 7) @@ -702,30 +690,30 @@ def configure(self): # Buttons configuration # Start and Select - print("Press START button") + self.logger.info("Press START button") but_not_set = True while but_not_set: for event in self.device.read(): if event.ev_type == "Key": if event.state == 1: self.start_code = event.code - print("Start associated with {}".format(event.code)) + self.logger.info("Start associated with {}".format(event.code)) else: but_not_set = False break - print("Press SELECT button (Start to skip)") + self.logger.info("Press SELECT button (Start to skip)") but_not_set = True while but_not_set: for event in self.device.read(): if event.ev_type == "Key": if event.code != self.start_code and event.state == 1: self.select_code = event.code - print("Select associated with {}".format(event.code)) + self.logger.info("Select associated with {}".format(event.code)) else: but_not_set = False if event.code == self.start_code: - print("Select association skipped") + self.logger.info("Select association skipped") break # Attack buttons @@ -735,7 +723,7 @@ def configure(self): if end_flag: break - print("Press B{} button (SELECT / START to end configuration)".format(idx+1)) + self.logger.info("Press B{} button (SELECT / START to end configuration)".format(idx+1)) but_not_set = True @@ -748,20 +736,20 @@ def configure(self): if event.ev_type == "Key": if (event.code != self.start_code and event.code != self.select_code): if event.state == 1: - print("Button B{}, event code = {}".format(idx+1, event.code)) + self.logger.info("Button B{}, event code = {}".format(idx+1, event.code)) self.device_dict["Key"][event.code] = idx elif event.state == 0: but_not_set = False else: if event.state == 0: - print("Remaining buttons configuration skipped") + self.logger.info("Remaining buttons configuration skipped") end_flag = True break # Move sticks # Digital end_flag = False - print("Configuring digital move") + self.logger.info("Configuring digital move") moves_list = ["UP", "RIGHT", "DOWN", "LEFT"] event_codes_list = ["Y", "X", "Y", "X"] self.device_dict["Absolute"]["ABS_HAT0Y"] = defaultdict(lambda: [ @@ -776,7 +764,7 @@ def configure(self): if end_flag: break - print("Press {} arrow (SELECT / START to skip)".format(move)) + self.logger.info("Press {} arrow (SELECT / START to skip)".format(move)) but_not_set = True @@ -790,16 +778,16 @@ def 
configure(self): if event.ev_type == "Absolute": if event.code == "ABS_HAT0" + event_codes_list[idx]: if abs(event.state) == 1: - print("{} move event code = {}, event state = {}".format( + self.logger.info("{} move event code = {}, event state = {}".format( move, event.code, event.state)) self.device_dict["Absolute"][event.code][event.state] = [ idx, abs(event.state)] elif event.state == 0: but_not_set = False else: - print("Digital Move Stick assumes not admissible values: {}".format( + self.logger.info("Digital Move Stick assumes not admissible values: {}".format( event.state)) - print( + self.logger.info( "Digital Move Stick not supported, configuration skipped") end_flag = True break @@ -807,13 +795,13 @@ def configure(self): if (event.code == self.start_code or event.code == self.select_code): if event.state == 0: - print("Digital Move Stick configuration skipped") + self.logger.info("Digital Move Stick configuration skipped") end_flag = True break # Move sticks # Analog - print("Configuring analog move") + self.logger.info("Configuring analog move") moves_list = ["UP", "RIGHT", "DOWN", "LEFT"] event_codes_list = ["Y", "X", "Y", "X"] self.max_analog_val = {} @@ -821,8 +809,7 @@ def configure(self): for idx, move in enumerate(moves_list): - print( - "Move left analog in {} position, keep it there and press Start".format(move)) + self.logger.info("Move left analog in {} position, keep it there and press Start".format(move)) but_not_set = True @@ -856,17 +843,17 @@ def configure(self): self.origin_analog_val[moves_list[idx]]) * thresh_perc self.delta_perc[idx+2] = -self.delta_perc[idx] - print("Delta perc = ", self.delta_perc) + self.logger.info("Delta perc = ", self.delta_perc) # Addressing Y-X axis for idx in range(2): - print("{} move event code = ABS_{}, ".format(moves_list[idx], + self.logger.info("{} move event code = ABS_{}, ".format(moves_list[idx], event_codes_list[idx]) + "event state = {}".format(self.max_analog_val[moves_list[idx]])) - print("{} move event code = ABS_{}, ".format(moves_list[idx+2], + self.logger.info("{} move event code = ABS_{}, ".format(moves_list[idx+2], event_codes_list[idx+2]) + " event state = {}".format(self.max_analog_val[moves_list[idx+2]])) - print("NO {}-{} move event code =".format(moves_list[idx], + self.logger.info("NO {}-{} move event code =".format(moves_list[idx], moves_list[idx+2]) + " ABS_{}, ".format(event_codes_list[idx]) + "event state = {}".format(self.origin_analog_val[moves_list[idx]])) @@ -885,14 +872,13 @@ def configure(self): [[idx], 1], [[idx, idx+2], 0], [[idx+2], 1]] else: - print( - "Not admissible values found in analog stick configuration, skipping") + self.logger.info("Not admissible values found in analog stick configuration, skipping") - print("device dict : ") - print("Buttons (Keys) dict : ", self.device_dict["Key"]) - print("Moves (Absolute) dict : ", self.device_dict["Absolute"]) + self.logger.info("device dict : ") + self.logger.info("Buttons (Keys) dict : ", self.device_dict["Key"]) + self.logger.info("Moves (Absolute) dict : ", self.device_dict["Absolute"]) - print("Configuration completed.") + self.logger.info("Configuration completed.") return @@ -941,7 +927,6 @@ def run(self): # run is a default Thread function if __name__ == "__main__": - print("\nWhat do you want to do:") print(" 1 - Show device events") print(" 2 - Configure device") diff --git a/diambra/arena/utils/diambra_data_loader.py b/diambra/arena/utils/diambra_data_loader.py new file mode 100644 index 00000000..07e3308d --- /dev/null +++ 
b/diambra/arena/utils/diambra_data_loader.py @@ -0,0 +1,83 @@ +import pickle +import bz2 +import cv2 +import os +import logging +import numpy as np +import sys + +# Diambra dataloader +class DiambraDataLoader: + def __init__(self, dataset_path: str, log_level=logging.INFO): + logging.basicConfig(level=log_level) + self.logger = logging.getLogger(__name__) + + # List of RL trajectories files + self.dataset_path = dataset_path + self.episode_files = [] + + if not os.path.exists(self.dataset_path): + raise FileNotFoundError(f"The path '{self.dataset_path}' does not exist.") + + if not os.path.isdir(self.dataset_path): + raise NotADirectoryError(f"'{self.dataset_path}' is not a directory.") + + episode_files = [filename for filename in os.listdir(self.dataset_path) if filename.endswith(".diambra")] + + if not episode_files: + raise Exception("No '.diambra' files found in the specified directory.") + + self.episode_files = episode_files + + # Idx of trajectory file to read + self.file_idx = 0 + self.n_loops = 0 + self.frame = np.zeros((128, 128, 1), dtype=np.uint8) + + # Step the environment + def step(self): + + step_data = self.episode_data[self.step_idx] + self.frame = cv2.imdecode(np.frombuffer(step_data["obs"]["frame"], dtype=np.uint8), cv2.IMREAD_UNCHANGED) + step_data["obs"]["frame"] = self.frame + self.step_idx += 1 + + return step_data["obs"], step_data["action"], step_data["reward"], step_data["terminated"], step_data["truncated"], step_data["info"], + + # Resetting the environment + def reset(self): + + self.n_loops += int(self.file_idx / len(self.episode_files)) + self.file_idx = self.file_idx % len(self.episode_files) + + # Open the next episode file + episode_file = self.episode_files[self.file_idx] + # Idx of trajectory file to read + self.file_idx += 1 + self.frame = np.zeros((128, 128, 1), dtype=np.uint8) + + # Read compressed RL Traj file + in_file = bz2.BZ2File(os.path.join(self.dataset_path, episode_file), 'r') + self.episode = pickle.load(in_file) + in_file.close() + + self.logger.info("Episode summary = {}".format(self.episode["episode_summary"])) + self.episode_data = self.episode["data"] + + # Reset run step + self.step_idx = 0 + + return self.n_loops + + # Rendering the environment + def render(self, waitKey=1): + if (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): + try: + window_name = "Diambra Data Loader" + cv2.namedWindow(window_name, cv2.WINDOW_GUI_NORMAL) + cv2.imshow(window_name, self.frame) + cv2.waitKey(waitKey) + return True + except: + return False + diff --git a/diambra/arena/utils/engine_mock.py b/diambra/arena/utils/engine_mock.py index 33cd74f6..d20b2900 100644 --- a/diambra/arena/utils/engine_mock.py +++ b/diambra/arena/utils/engine_mock.py @@ -2,65 +2,35 @@ import random import numpy as np import diambra.arena -from diambra.engine import Client, model +from copy import deepcopy +from diambra.engine import model +from diambra.arena import Roles class DiambraEngineMock: - - def __init__(self, steps_per_round=20, fps=1000): + def __init__(self, fps=1000, override_perfect_probability=None): # Game features self.game_data = None - self.steps_per_round = steps_per_round self.fps = fps - # Random seed - time_dep_seed = int((time.time() - int(time.time() - 0.5)) * 1000) - random.seed(time_dep_seed) - # Class state variables initialization - self.n_steps = 0 - self.n_rounds_won = 0 - self.n_rounds_lost = 0 - self.n_stages = 0 + self.timer = 0 + self.current_stage_number = 1 self.n_continue = 0 - self.side_p1 = 0 - self.side_p2 = 1 - 
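
# --- Usage sketch (illustrative, not part of the patch) ---
# Replaying a folder of recorded ".diambra" episodes with the DiambraDataLoader added above.
# The dataset path is a placeholder; the loop performs one full pass over the episode files.
from diambra.arena.utils.diambra_data_loader import DiambraDataLoader

data_loader = DiambraDataLoader("/path/to/recorded/episodes")
n_loops = data_loader.reset()
while n_loops == 0:
    observation, action, reward, terminated, truncated, info = data_loader.step()
    data_loader.render()
    if terminated or truncated:
        n_loops = data_loader.reset()
# --- end of usage sketch ---
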
self.char_p1 = 0 - self.char_p2 = 0 - self.health_p1 = 0 - self.health_p2 = 0 + self.side = {Roles.P1: 0, Roles.P2: 1} + self.char = {Roles.P1: 0, Roles.P2: 0} + self.health = {Roles.P1: 0, Roles.P2: 0} + self.wins = {Roles.P1: 0, Roles.P2: 0} self.player = "" self.perfect = False + self.override_perfect_probability = override_perfect_probability - def _mock__init__(self, env_address, grpc_timeout=60): + def mock__init__(self, env_address, grpc_timeout=60): print("Trying to connect to DIAMBRA Engine server (timeout={}s)...".format(grpc_timeout)) print("... done (MOCKED!).") - def generate_ram_states(self): - - for k, v in self.ram_states.items(): - self.ram_states[k][3] = random.choices(range(v[1], v[2] + 1))[0] - - # Setting meaningful values to ram states - self.ram_states["stage"][3] = self.n_stages + 1 - self.ram_states["SideP1"][3] = self.side_p1 - self.ram_states["SideP2"][3] = self.side_p2 - self.ram_states["WinsP1"][3] = self.n_rounds_won - self.ram_states["WinsP2"][3] = self.n_rounds_lost - - self.ram_states["CharP1"][3] = self.char_p1 - self.ram_states["CharP2"][3] = self.char_p2 - - if self.game_data["number_of_chars_per_round"] == 1: - self.ram_states["HealthP1"][3] = self.health_p1 - self.ram_states["HealthP2"][3] = self.health_p2 - else: - for idx in range(self.game_data["number_of_chars_per_round"]): - self.ram_states["Health{}P1".format(idx+1)][3] = self.health_p1 - self.ram_states["Health{}P2".format(idx+1)][3] = self.health_p2 - # Send env settings, retrieve env info and int variables list [pb low level] - def _mock_env_init(self, env_settings_pb): + def mock_env_init(self, env_settings_pb): self.settings = env_settings_pb # Print settings @@ -77,10 +47,15 @@ def _mock_env_init(self, env_settings_pb): "Hard": [0.25, 0.1], } - difficulty_level = self.game_data["difficulty_to_cluster_map"][str(self.settings.difficulty)] + difficulty = self.settings.episode_settings.difficulty + if difficulty == 0: + difficulty = random.choice(range(self.game_data["difficulty"][0], self.game_data["difficulty"][1] + 1)) + difficulty_level = self.game_data["difficulty_to_cluster_map"][str(difficulty)] self.base_round_winning_probability = probability_maps[difficulty_level][0] ** (1.0/self.game_data["stages_per_game"]) self.perfect_probability = probability_maps[difficulty_level][1] + if self.override_perfect_probability is not None: + self.perfect_probability = self.override_perfect_probability self.frame_shape = self.game_data["frame_shape"] if (self.settings.frame_shape.h > 0 and self.settings.frame_shape.w > 0): @@ -89,84 +64,145 @@ def _mock_env_init(self, env_settings_pb): if (self.settings.frame_shape.c == 1): self.frame_shape[2] = self.settings.frame_shape.c - self.continue_per_episode = - int(self.settings.continue_game) if self.settings.continue_game < 0.0 else int(self.settings.continue_game*10) + continue_game_setting = self.settings.episode_settings.continue_game + self.continue_per_episode = - int(continue_game_setting) if continue_game_setting < 0.0 else int(continue_game_setting*10) self.delta_health = self.game_data["health"][1] - self.game_data["health"][0] - self.base_hit = int(self.delta_health * (self.game_data["n_actions"][0] + self.game_data["n_actions"][0]) / (self.game_data["n_actions"][1] * (self.steps_per_round - 1))) + self.base_hit = int(self.delta_health * self.game_data["n_actions"][1] / + ((self.game_data["n_actions"][0] + self.game_data["n_actions"][1]) * + (self.game_data["ram_states"]["common"]["timer"][2] / self.settings.step_ratio))) # Generate the ram 
states map - self.ram_states = self.game_data["ram_states"] + self.ram_states = {} + self.ram_states[model.RamStatesCategories.common] = self.game_data["ram_states"]["common"] + self.ram_states[model.RamStatesCategories.P1] = deepcopy(self.game_data["ram_states"]["Px"]) + self.ram_states[model.RamStatesCategories.P2] = deepcopy(self.game_data["ram_states"]["Px"]) for k, v in self.ram_states.items(): - self.ram_states[k].append(0) + for k2, v2 in v.items(): + self.ram_states[k][k2].append(0) # Build the response response = model.EnvInitResponse() + # Frame response.frame_shape.h = self.frame_shape[0] response.frame_shape.w = self.frame_shape[1] response.frame_shape.c = self.frame_shape[2] - response.available_actions.with_buttons_combination.moves = self.game_data["n_actions"][0] - response.available_actions.with_buttons_combination.attacks = self.game_data["n_actions"][2] - response.available_actions.without_buttons_combination.moves = self.game_data["n_actions"][0] - response.available_actions.without_buttons_combination.attacks = self.game_data["n_actions"][1] - - response.delta_health = self.delta_health - response.max_stage = self.game_data["stages_per_game"] - response.cumulative_reward_bounds.min = -((self.game_data["rounds_per_stage"] - 1) * (response.max_stage - 1) + self.game_data["rounds_per_stage"]) * response.delta_health - response.cumulative_reward_bounds.max = self.game_data["rounds_per_stage"] * response.max_stage * response.delta_health - response.char_list.extend(self.game_data["char_list"]) - - response.buttons.moves.extend(["NoMove", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight", "Down", "DownLeft"]) - response.buttons.attacks.extend(["But{}".format(i) for i in range(self.game_data["n_actions"][1])] +\ - ["But{}But{}".format(i - self.game_data["n_actions"][1] + 1, i - self.game_data["n_actions"][1] + 2)\ - for i in range(self.game_data["n_actions"][1], self.game_data["n_actions"][2])]) - response.button_mapping.moves.extend(["0", " ", "1", "\u2190", "2", "\u2196", "3", "\u2191", - "4", "\u2197", "5", "\u2192", "6", "\u2198", "7", "\u2193", "8", "\u2199"]) - attack_mapping = ["0", " "] - for i in range(1, self.game_data["n_actions"][2]): - attack_mapping += [str(i), "Attack{}".format(i)] - response.button_mapping.attacks.extend(attack_mapping) - - self.generate_ram_states() + # Available actions + response.available_actions.n_moves = self.game_data["n_actions"][0] + response.available_actions.n_attacks = self.game_data["n_actions"][1] + response.available_actions.n_attacks_no_comb = self.game_data["n_actions"][2] + + move_keys = ["NoMove", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight", "Down", "DownLeft"] + move_labels = [" ", "\u2190", "\u2196", "\u2191", "\u2197", "\u2192", "\u2198", "\u2193", "\u2199"] + for idx in range(self.game_data["n_actions"][0]): + button = model.EnvInitResponse.AvailableActions.Button() + button.key = move_keys[idx] + button.label = move_labels[idx] + response.available_actions.moves.append(button) + + attack_keys = ["But{}".format(i) for i in range(self.game_data["n_actions"][2])] +\ + ["But{}But{}".format(i - self.game_data["n_actions"][2] + 1, i - self.game_data["n_actions"][2] + 2)\ + for i in range(self.game_data["n_actions"][2], self.game_data["n_actions"][1])] + attack_labels = [" "] + for i in range(1, self.game_data["n_actions"][1]): + attack_labels += ["Attack{}".format(i)] + for idx in range(self.game_data["n_actions"][1]): + button = model.EnvInitResponse.AvailableActions.Button() + button.key = attack_keys[idx] 
+ button.label = attack_labels[idx] + response.available_actions.attacks.append(button) + + # Cumulative reward bounds + response.cumulative_reward_bounds.min = -((self.game_data["rounds_per_stage"] - 1) * (self.game_data["stages_per_game"] - 1) + self.game_data["rounds_per_stage"]) * self.delta_health + response.cumulative_reward_bounds.max = self.game_data["rounds_per_stage"] * self.game_data["stages_per_game"] * self.delta_health + + # Characters info + response.characters_info.char_list.extend(self.game_data["char_list"]) + response.characters_info.char_forbidden_list.extend(self.game_data["char_forbidden_list"]) + for key, value in self.game_data["char_homonymy_map"].items(): + response.characters_info.char_homonymy_map[key] = value + response.characters_info.chars_per_round = self.game_data["number_of_chars_per_round"] + response.characters_info.chars_to_select = self.game_data["number_of_chars_to_select"] + + # Difficulty bounds + response.difficulty_bounds.min = self.game_data["difficulty"][0] + response.difficulty_bounds.max = self.game_data["difficulty"][1] + + # RAM states + self._generate_ram_states() for k, v in self.ram_states.items(): - response.ram_states[k].type = v[0] - response.ram_states[k].min = v[1] - response.ram_states[k].max = v[2] - response.ram_states[k].val = v[3] + for k2, v2 in v.items(): + k2_enum =model.RamStates.Value(k2) + response.ram_states_categories[k].ram_states[k2_enum].type = model.SpaceTypes.Value(v2[0]) + response.ram_states_categories[k].ram_states[k2_enum].min = v2[1] + response.ram_states_categories[k].ram_states[k2_enum].max = v2[2] return response - def generate_frame(self): - frame = np.ones((self.frame_shape), dtype=np.int8) * ((self.n_stages * self.game_data["rounds_per_stage"] + self.n_steps) % 255) + # Reset the environment [pb low level] + def mock_reset(self, episode_settings): + # Update variable env settings + self.settings.episode_settings.CopyFrom(episode_settings) + + # Random seed + random.seed(self.settings.episode_settings.random_seed) + np.random.seed(self.settings.episode_settings.random_seed) + + self._reset_state() + + return self._update_step_reset_response() + + # Step the environment [pb low level] + def mock_step(self, actions): + # Update class state + self._new_game_state(actions) + + return self._update_step_reset_response() + + # Closing DIAMBRA Arena + def mock_close(self): + pass + + def _generate_ram_states(self): + for k, v in self.ram_states.items(): + for k2, v2 in v.items(): + self.ram_states[k][k2][3] = random.choice(range(v2[1], v2[2] + 1)) + + # Setting meaningful values to ram states + values = [self.char, self.health, self.wins, self.side] + + for idx, state in enumerate(["character", "health", "wins", "side"]): + for text in ["", "_1", "_2", "_3"]: + for player in [Roles.P1, Roles.P2]: + key = "{}{}".format(state, text) + if (key in self.ram_states[player]): + self.ram_states[player][key][3] = values[idx][player] + + self.ram_states[model.RamStatesCategories.common]["stage"][3] = int(self.current_stage_number) + self.ram_states[model.RamStatesCategories.common]["timer"][3] = int(self.timer) + + def _generate_frame(self): + frame = np.ones((self.frame_shape), dtype=np.int8) * ((self.current_stage_number * self.game_data["rounds_per_stage"] + int(self.timer)) % 255) return frame.tobytes() # Set delta health - def set_perfect_chance(self): + def _set_perfect_chance(self): self.perfect = random.choices([True, False], [self.perfect_probability, 1.0 - self.perfect_probability])[0] + # Force perfect to true in 
case of 2P games to avoid double update for own role (see health evolution in new_game_state) + self.perfect = self.perfect or self.settings.n_players == 2 # Reset game state - def reset_state(self): + def _reset_state(self): # Reset class state - self.n_steps = 0 - self.n_rounds_won = 0 - self.n_rounds_lost = 0 - self.n_stages = 0 + self.current_stage_number = 1 self.n_continue = 0 # Actions - self.mov_p1 = 0 - self.att_p1 = 0 - self.mov_p2 = 0 - self.att_p2 = 0 - - # Player - if self.settings.player != "Random": - self.player = self.settings.player - else: - self.player = random.choices(["P1", "P2"])[0] + self.player_actions = [[0, 0], [0, 0]] - # Set delta healths - self.set_perfect_chance() + # Set perfect chance + self._set_perfect_chance() # Done flags self.round_done_ = False @@ -175,50 +211,29 @@ def reset_state(self): self.episode_done_ = False self.env_done_ = False - self.side_p1 = 0 - self.side_p2 = 1 - self.health_p1 = self.game_data["health"][1] - self.health_p2 = self.game_data["health"][1] + self.side[Roles.P1] = 0 + self.side[Roles.P2] = 1 + self.health[Roles.P1] = self.game_data["health"][1] + self.health[Roles.P2] = self.game_data["health"][1] + self.wins[Roles.P1] = 0 + self.wins[Roles.P2] = 0 + self.timer = self.game_data["ram_states"]["common"]["timer"][2] self.reward = 0 # Characters - if self.player == "P1P2": - if (self.settings.characters.p1[0] == "Random"): - self.char_p1 = random.choices(range(len(self.game_data["char_list"])))[0] - else: - self.char_p1 = self.game_data["char_list"].index(self.settings.characters.p1[0]) - - if (self.settings.characters.p2[0] == "Random"): - self.char_p2 = random.choices(range(len(self.game_data["char_list"])))[0] - else: - self.char_p2 = self.game_data["char_list"].index(self.settings.characters.p2[0]) - - elif self.player == "P1": - self.char_p2 = self.n_stages - if (self.settings.characters.p1[0] == "Random"): - self.char_p1 = random.choices(range(len(self.game_data["char_list"])))[0] - else: - self.char_p1 = self.game_data["char_list"].index(self.settings.characters.p1[0]) - - else: - self.char_p1 = self.n_stages - if (self.settings.characters.p2[0] == "Random"): - self.char_p2 = random.choices(range(len(self.game_data["char_list"])))[0] - else: - self.char_p2 = self.game_data["char_list"].index(self.settings.characters.p2[0]) + for idx in range(self.settings.n_players): + self.char[self.settings.episode_settings.player_settings[idx].role] =\ + self.game_data["char_list"].index(self.settings.episode_settings.player_settings[idx].characters[0]) # Update game state - def new_game_state(self, mov_p1=0, att_p1=0, mov_p2=0, att_p2=0): - + def _new_game_state(self, actions): # Sleep to simulate computer time elapsed - time.sleep(1.0/self.fps) + time.sleep(1.0/(self.settings.step_ratio * self.fps)) # Actions - self.mov_p1 = mov_p1 - self.att_p1 = att_p1 - self.mov_p2 = mov_p2 - self.att_p2 = att_p2 + for idx, action in enumerate(actions): + self.player_actions[idx] = [action[0], action[1]] # Done flags self.round_done_ = False @@ -227,165 +242,129 @@ def new_game_state(self, mov_p1=0, att_p1=0, mov_p2=0, att_p2=0): self.episode_done_ = False self.env_done_ = False - self.n_steps += 1 + self.timer -= (1.0 * self.settings.step_ratio) / 60.0 - starting_health_p1 = self.health_p1 - starting_health_p2 = self.health_p2 + starting_health = { + Roles.P1: self.health[Roles.P1], + Roles.P2: self.health[Roles.P2] + } # Health evolution - hit_prob = self.base_round_winning_probability ** self.n_stages + hit_prob = 
self.base_round_winning_probability ** self.current_stage_number - if self.player == "P2": + for idx in range(self.settings.n_players): + role = self.settings.episode_settings.player_settings[idx].role + opponent_role = Roles.P2 if role == Roles.P1 else Roles.P1 + if self.player_actions[idx][1] != 0: + self.health[opponent_role] -= random.choices([self.base_hit, 0], [hit_prob, 1.0 - hit_prob])[0] if not self.perfect: - self.health_p2 -= random.choices([self.base_hit, 0], [1.0 - hit_prob, hit_prob])[0] - if att_p1 != 0: - self.health_p1 -= random.choices([self.base_hit, 0], [hit_prob, 1.0 - hit_prob])[0] - else: - self.health_p1 -= random.choices([self.base_hit, 0], [1.0 - hit_prob, hit_prob])[0] - if att_p1 != 0: - self.health_p2 -= random.choices([self.base_hit, 0], [hit_prob, 1.0 - hit_prob])[0] - if (self.player == "P1P2" and att_p2 == 0) or self.perfect: - self.health_p1 = starting_health_p1 + self.health[role] -= random.choices([self.base_hit, 0], [1.0 - hit_prob, hit_prob])[0] - self.health_p1 = max(self.health_p1, self.game_data["health"][0]) - self.health_p2 = max(self.health_p2, self.game_data["health"][0]) + for role in [Roles.P1, Roles.P2]: + self.health[role] = max(self.health[role], self.game_data["health"][0]) - if (min(self.health_p1, self.health_p2) == self.game_data["health"][0]) or ((self.n_steps % self.steps_per_round) == 0): + role_0 = self.settings.episode_settings.player_settings[0].role + opponent_role_0 = Roles.P2 if role_0 == Roles.P1 else Roles.P1 + + if (min(self.health[Roles.P1], self.health[Roles.P2]) == self.game_data["health"][0]) or (self.timer <= 0): self.round_done_ = True - if self.health_p1 > self.health_p2: - self.health_p2 = self.game_data["health"][0] - if self.player == "P2": - print("Round lost") - self.n_rounds_lost += 1 - else: - print("Round won") - self.n_rounds_won += 1 - - elif self.health_p2 > self.health_p1: - self.health_p1 = self.game_data["health"][0] - if self.player == "P2": - print("Round won") - self.n_rounds_won += 1 - else: - print("Round lost") - self.n_rounds_lost += 1 + if self.health[role_0] > self.health[opponent_role_0]: + self.health[opponent_role_0] = self.game_data["health"][0] + print("Round won") + self.wins[role_0] += 1 + elif self.health[role_0] < self.health[opponent_role_0]: + self.health[role_0] = self.game_data["health"][0] + print("Round lost") + self.wins[opponent_role_0] += 1 else: print("Draw, forcing lost") - self.n_rounds_lost += 1 - if self.player == "P2": - self.health_p2 = self.game_data["health"][0] - else: - self.health_p1 = self.game_data["health"][0] + self.wins[opponent_role_0] += 1 + self.health[role_0] = self.game_data["health"][0] - if self.n_rounds_won == self.game_data["rounds_per_stage"]: + if self.wins[role_0] == self.game_data["rounds_per_stage"]: self.stage_done_ = True - self.n_stages += 1 - self.n_rounds_won = 0 - self.n_rounds_lost = 0 - if self.player == "P1P2": + self.current_stage_number += 1 + self.wins[role_0] = 0 + self.wins[opponent_role_0] = 0 + if self.settings.n_players == 2: self.game_done_ = True self.episode_done_ = True - elif self.player == "P1": - self.char_p2 = self.n_stages else: - self.char_p1 = self.n_stages + self.char[opponent_role_0] = self.current_stage_number + print("Stage done") + print("Moving to stage {} of {}".format(self.current_stage_number, self.game_data["stages_per_game"])) - if self.n_rounds_lost == self.game_data["rounds_per_stage"]: + if self.wins[opponent_role_0] == self.game_data["rounds_per_stage"]: + print("Game done") self.game_done_ = True if 
self.n_continue >= self.continue_per_episode: + print("Episode done") self.episode_done_ = True else: + print("Continuing game") self.n_continue += 1 - self.n_rounds_won = 0 - self.n_rounds_lost = 0 + self.wins[role_0] = 0 + self.wins[opponent_role_0] = 0 - if self.n_stages == self.game_data["stages_per_game"]: + if self.current_stage_number == self.game_data["stages_per_game"]: + print("Episode done") + print("Game completed!") self.game_done_ = True self.episode_done_ = True self.env_done_ = self.episode_done_ - delta_p1 = starting_health_p1 - self.health_p1 - delta_p2 = starting_health_p2 - self.health_p2 - self.reward = delta_p1 - delta_p2 if self.player == "P2" else delta_p2 - delta_p1 + delta = { + Roles.P1: starting_health[Roles.P1] - self.health[Roles.P1], + Roles.P2: starting_health[Roles.P2] - self.health[Roles.P2] + } + self.reward = delta[opponent_role_0] - delta[role_0] if np.any([self.round_done_, self.stage_done_, self.game_done_]): - - self.n_steps = 0 - - self.side_p1 = 0 - self.side_p2 = 1 - self.health_p1 = self.game_data["health"][1] - self.health_p2 = self.game_data["health"][1] - - # Set delta healths - self.set_perfect_chance() + self.side[Roles.P1] = 0 + self.side[Roles.P2] = 1 + self.health[Roles.P1] = self.game_data["health"][1] + self.health[Roles.P2] = self.game_data["health"][1] + self.timer = self.game_data["ram_states"]["common"]["timer"][2] + + # Set perfect chance + self._set_perfect_chance() else: - self.side_p1 = random.choices([0, 1], [0.3, 0.7])[0] - self.side_p2 = random.choices([(self.side_p1 + 1) % 2, self.side_p1], [0.97, 0.03])[0] + self.side[Roles.P1] = random.choices([0, 1], [0.3, 0.7])[0] + self.side[Roles.P2] = random.choices([(self.side[Roles.P1] + 1) % 2, self.side[Roles.P2]], [0.97, 0.03])[0] - def update_observation(self): + def _update_step_reset_response(self): # Response - observation = model.Observation() - - # Actions - observation.actions.p1.move = self.mov_p1 - observation.actions.p1.attack = self.att_p1 - observation.actions.p2.move = self.mov_p2 - observation.actions.p2.attack = self.att_p2 + response = model.StepResetResponse() # Ram states - self.generate_ram_states() + self._generate_ram_states() for k, v in self.ram_states.items(): - observation.ram_states[k].type = v[0] - observation.ram_states[k].min = v[1] - observation.ram_states[k].max = v[2] - observation.ram_states[k].val = v[3] + for k2, v2 in v.items(): + response.observation.ram_states_categories[k].ram_states[model.RamStates.Value(k2)] = v2[3] # Game state - observation.game_state.round_done = self.round_done_ - observation.game_state.stage_done = self.stage_done_ - observation.game_state.game_done = self.game_done_ - observation.game_state.episode_done = self.episode_done_ - observation.game_state.env_done = self.env_done_ - - # Player - observation.player = self.player + response.info.game_states[model.GameStates.round_done] = self.round_done_ + response.info.game_states[model.GameStates.stage_done] = self.stage_done_ + response.info.game_states[model.GameStates.game_done] = self.game_done_ + response.info.game_states[model.GameStates.episode_done] = self.episode_done_ + response.info.game_states[model.GameStates.env_done] = self.env_done_ # Frame - observation.frame = self.generate_frame() + response.observation.frame = self._generate_frame() # Reward - observation.reward = self.reward + response.reward = self.reward - return observation - - - # Reset the environment [pb low level] - def _mock_reset(self): - - self.reset_state() - - return 
self.update_observation() - - # Step the environment (1P) [pb low level] - def _mock_step_1p(self, mov_p1, att_p1): - - # Update class state - self.new_game_state(mov_p1, att_p1) - - return self.update_observation() - - # Step the environment (2P) [pb low level] - def _mock_step_2p(self, mov_p1, att_p1, mov_p2, att_p2): - - # Update class state - self.new_game_state(mov_p1, att_p1, mov_p2, att_p2) + return response - return self.update_observation() +def load_mocker(mocker, **kwargs): + diambra_engine_mock = DiambraEngineMock(**kwargs) - # Closing DIAMBRA Arena - def _mock_close(self): - pass + mocker.patch("diambra.arena.engine.interface.DiambraEngine.__init__", diambra_engine_mock.mock__init__) + mocker.patch("diambra.arena.engine.interface.DiambraEngine.env_init", diambra_engine_mock.mock_env_init) + mocker.patch("diambra.arena.engine.interface.DiambraEngine.reset", diambra_engine_mock.mock_reset) + mocker.patch("diambra.arena.engine.interface.DiambraEngine.step", diambra_engine_mock.mock_step) + mocker.patch("diambra.arena.engine.interface.DiambraEngine.close", diambra_engine_mock.mock_close) \ No newline at end of file diff --git a/diambra/arena/utils/gym_utils.py b/diambra/arena/utils/gym_utils.py index 5301ea1b..2a5b2954 100644 --- a/diambra/arena/utils/gym_utils.py +++ b/diambra/arena/utils/gym_utils.py @@ -1,17 +1,13 @@ -import gym +import gymnasium as gym import os -import sys -from gym import spaces -import numpy as np import pickle import bz2 -import cv2 import json from threading import Thread import hashlib # Save compressed pickle files in parallel -class ParallelPickleWriter(Thread): # def class typr thread +class ParallelPickleWriter(Thread): # def class type thread def __init__(self, save_path, to_save): Thread.__init__(self) # thread init class (don't forget this) @@ -128,82 +124,6 @@ def discrete_to_multi_discrete_action(action, n_move_actions): return mov_act, att_act -# Visualize Gym Obs content -def show_gym_obs(observation, char_list, wait_key=1, viz=True): - - print("WARNING: Deprecated, use env.show_obs() instead, will be removed in future versions.") - - if type(observation) == dict: - for k, v in observation.items(): - if k != "frame": - if type(v) == dict: - for k2, v2 in v.items(): - if "ownChar" in k2 or "oppChar" in k2: - print("observation[\"{}\"][\"{}\"]: {}".format(k, k2, char_list[v2])) - else: - print( - "observation[\"{}\"][\"{}\"]: {}".format(k, k2, v2)) - else: - print("observation[\"{}\"]: {}".format(k, v)) - else: - print("observation[\"frame\"].shape:", - observation["frame"].shape) - - if viz: - obs = observation["frame"] / 255.0 - else: - if viz: - obs = observation / 255.0 - - if viz is True and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): - try: - cv2.imshow("[{}] Frame".format(os.getpid()), obs[:, :, ::-1]) # rgb2bgr - cv2.waitKey(wait_key) - except: - pass - -# Visualize Obs content -def show_wrap_obs(observation, n_actions_stack, char_list, wait_key=1, viz=True): - - print("WARNING: Deprecated, use env.show_obs() instead, will be removed in future versions.") - - if type(observation) == dict: - for k, v in observation.items(): - if k != "frame": - if type(v) == dict: - for k2, v2 in v.items(): - if type(v2) == dict: - for k3, v3 in v2.items(): - print("observation[\"{}\"][\"{}\"][\"{}\"] (reshaped for viz):\n{}" - .format(k, k2, k3, np.reshape(v3, [n_actions_stack, -1]))) - elif "ownChar" in k2 or "oppChar" in k2: - print("observation[\"{}\"][\"{}\"]: {} / {}".format(k, k2, v2, - char_list[np.where(v2 == 
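
# --- Usage sketch (illustrative, not part of the patch) ---
# Driving a mocked test with the load_mocker helper defined above and pytest-mock's "mocker"
# fixture. Setting DIAMBRA_ENVS by hand is an assumption made so that diambra.arena.make can
# resolve an address; the mocked engine never actually connects to it.
import os
import diambra.arena
from diambra.arena.utils.engine_mock import load_mocker

def test_mocked_random_agent(mocker):
    os.environ.setdefault("DIAMBRA_ENVS", "127.0.0.1:50051")  # assumption, see note above
    load_mocker(mocker)
    env = diambra.arena.make("doapp")
    observation, info = env.reset(seed=42)
    observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.close()
# --- end of usage sketch ---
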
1)[0][0]])) - else: - print( - "observation[\"{}\"][\"{}\"]: {}".format(k, k2, v2)) - else: - print("observation[\"{}\"]: {}".format(k, v)) - else: - obs = observation["frame"] - print("observation[\"frame\"]: shape {} - min {} - max {}".format(obs.shape, np.amin(obs), np.amax(obs))) - - if viz: - obs = observation["frame"] - else: - if viz: - obs = observation - - if viz is True and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ): - try: - norm_factor = 255 if np.amax(obs) > 1.0 else 1.0 - for idx in range(obs.shape[2]): - cv2.imshow("[{}] Frame-{}".format(os.getpid(), idx), obs[:, :, idx] / norm_factor) - - cv2.waitKey(wait_key) - except: - pass - # List all available games def available_games(print_out=True, details=False): base_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/diambra/arena/utils/integratedGames.json b/diambra/arena/utils/integratedGames.json index eeff1f30..689e1d6a 100644 --- a/diambra/arena/utils/integratedGames.json +++ b/diambra/arena/utils/integratedGames.json @@ -11,6 +11,8 @@ "char_list": ["Kasumi", "Zack", "Hayabusa", "Bayman", "Lei-Fang", "Raidou", "Gen-Fu", "Tina", "Bass", "Jann-Lee", "Ayane"], + "char_forbidden_list": [], + "char_homonymy_map": {}, "outfits": [1, 4], "difficulty": [1, 4, 3], "difficulty_to_cluster_map": { @@ -27,19 +29,21 @@ "rounds_per_stage": 2, "stages_per_game": 8, "number_of_chars_per_round": 1, - "n_actions": [9, 4, 8], + "number_of_chars_to_select": 1, + "n_actions": [9, 8, 4], "health": [0, 208], "frame_shape": [480, 512, 3], "ram_states": { - "SideP1": [0, 0, 1], - "SideP2": [0, 0, 1], - "WinsP1": [1, 0, 2], - "WinsP2": [1, 0, 2], - "stage": [1, 1, 8], - "CharP1": [2, 0, 10], - "CharP2": [2, 0, 10], - "HealthP1": [1, 0, 208], - "HealthP2": [1, 0, 208] + "common": { + "stage": ["BOX", 1, 8], + "timer": ["BOX", 0, 40] + }, + "Px": { + "side": ["BINARY", 0, 1], + "wins": ["BOX", 0, 2], + "character": ["DISCRETE", 0, 10], + "health": ["BOX", 0, 208] + } }, "cfg": {"H": "But6", "P": "But1", "K": "But2"} }, @@ -55,6 +59,8 @@ "char_list": ["Alex", "Twelve", "Hugo", "Sean", "Makoto", "Elena", "Ibuki", "Chun-Li", "Dudley", "Necro", "Q", "Oro", "Urien", "Remy", "Ryu", "Gouki", "Yun", "Yang", "Ken", "Gill"], + "char_forbidden_list": ["Gill"], + "char_homonymy_map": {}, "outfits": [1, 7], "difficulty": [1, 8, 6], "difficulty_to_cluster_map": { @@ -75,31 +81,27 @@ "rounds_per_stage": 2, "stages_per_game": 10, "number_of_chars_per_round": 1, - "n_actions": [9, 7, 10], + "number_of_chars_to_select": 1, + "n_actions": [9, 10, 7], "health": [-1, 160], "frame_shape": [224, 384, 3], "ram_states": { - "SideP1": [0, 0, 1], - "SideP2": [0, 0, 1], - "WinsP1": [1, 0, 2], - "WinsP2": [1, 0, 2], - "stage": [1, 1, 10], - "CharP1": [2, 0, 19], - "CharP2": [2, 0, 19], - "StunBarP1": [1, 0, 72], - "StunBarP2": [1, 0, 72], - "StunnedP1": [0, 0, 1], - "StunnedP2": [0, 0, 1], - "SuperBarP1": [1, 0, 128], - "SuperBarP2": [1, 0, 128], - "SuperTypeP1": [2, 0, 2], - "SuperTypeP2": [2, 0, 2], - "SuperCountP1": [1, 0, 3], - "SuperCountP2": [1, 0, 3], - "SuperMaxCountP1": [1, 1, 3], - "SuperMaxCountP2": [1, 1, 3], - "HealthP1": [1, -1, 160], - "HealthP2": [1, -1, 160] + "common": { + "stage": ["BOX", 1, 10], + "timer": ["BOX", 0, 99] + }, + "Px": { + "side": ["BINARY", 0, 1], + "wins": ["BOX", 0, 2], + "character": ["DISCRETE", 0, 19], + "stun_bar": ["BOX", 0, 72], + "stunned": ["BINARY", 0, 1], + "super_bar": ["BOX", 0, 128], + "super_type": ["DISCRETE", 0, 2], + "super_count": ["BOX", 0, 3], + "super_max_count": ["BOX", 1, 3], + 
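
# --- Reading note (illustrative, not part of the patch) ---
# Each ram state above is now declared as [space type, min, max], e.g. "stun_bar": ["BOX", 0, 72].
# A plausible gymnasium-style interpretation (the actual space construction lives outside this
# diff, so the helper below is an assumption) is:
import gymnasium as gym
import numpy as np

def space_from_ram_state(space_type, low, high):
    if space_type == "BOX":
        return gym.spaces.Box(low=low, high=high, shape=(1,), dtype=np.int32)
    # BINARY and DISCRETE entries enumerate a small set of integer values
    return gym.spaces.Discrete(high - low + 1)
# --- end of reading note ---
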
"health": ["BOX", -1, 160] + } }, "cfg": {"LP": "But4", "MP": "But1", "HP": "But5", "LK": "But3", "MK": "But2", "HK": "But6"} }, @@ -119,7 +121,20 @@ "Lee", "Wang", "P.Jack", "Devil", "True Ogre", "Ogre", "Roger", "Tetsujin", "Panda", "Tiger", "Angel", "Alex", "Mokujin", "Unknown"], - "outfits": [1, 2], + "char_forbidden_list": ["Unknown"], + "char_homonymy_map": { + "Tetsujin": "Mokujin", + "Mokujin": "Tetsujin", + "Tiger": "Eddy", + "Eddy": "Tiger", + "Angel": "Devil", + "Devil": "Angel", + "Panda": "Kuma", + "Kuma": "Panda", + "Alex": "Roger", + "Roger": "Alex" + }, + "outfits": [1, 5], "difficulty": [1, 9, 7], "difficulty_to_cluster_map": { "1": "Easy", @@ -140,29 +155,26 @@ "rounds_per_stage": 2, "stages_per_game": 8, "number_of_chars_per_round": 2, - "n_actions": [9, 6, 13], + "number_of_chars_to_select": 2, + "n_actions": [9, 13, 6], "health": [0, 182], "frame_shape": [240, 512, 3], "ram_states": { - "SideP1": [0, 0, 1], - "SideP2": [0, 0, 1], - "WinsP1": [1, 0, 2], - "WinsP2": [1, 0, 2], - "stage": [1, 1, 8], - "CharP1": [2, 0, 38], - "CharP2": [2, 0, 38], - "Char1P1": [2, 0, 38], - "Char1P2": [2, 0, 38], - "Char2P1": [2, 0, 38], - "Char2P2": [2, 0, 38], - "Health1P1": [1, 0, 182], - "Health1P2": [1, 0, 182], - "Health2P1": [1, 0, 182], - "Health2P2": [1, 0, 182], - "ActiveCharP1": [0, 0, 1], - "ActiveCharP2": [0, 0, 1], - "BarStatusP1": [2, 0, 4], - "BarStatusP2": [2, 0, 4] + "common": { + "stage": ["BOX", 1, 8], + "timer": ["BOX", 0, 60] + }, + "Px": { + "side": ["BINARY", 0, 1], + "wins": ["BOX", 0, 2], + "character": ["DISCRETE", 0, 38], + "character_1": ["DISCRETE", 0, 38], + "character_2": ["DISCRETE", 0, 38], + "health_1": ["BOX", 0, 227], + "health_2": ["BOX", 0, 227], + "active_character": ["BINARY", 0, 1], + "bar_status": ["DISCRETE", 0, 4] + } }, "cfg": {"LP": "But4", "RP": "But1", "LK": "But3", "RK": "But2", "TAG": "But6"} }, @@ -180,6 +192,8 @@ "Shang Tsung", "Nightwolf", "Sub-Zero-2", "Cyrax", "Liu Kang", "Jade", "Sub-Zero", "Kung Lao", "Smoke", "Skorpion", "Human Smoke", "Noob Saibot", "Motaro", "Shao Kahn"], + "char_forbidden_list": ["Noob Saibot", "Human Smoke", "Motaro", "Shao Kahn"], + "char_homonymy_map": {}, "outfits": [1, 1], "difficulty": [1, 5, 4], "difficulty_to_cluster_map": { @@ -197,21 +211,23 @@ "rounds_per_stage": 2, "stages_per_game": 11, "number_of_chars_per_round": 1, + "number_of_chars_to_select": 1, "n_actions": [9, 7, 7], "health": [0, 166], "frame_shape": [254, 500, 3], "ram_states": { - "SideP1": [0, 0, 1], - "SideP2": [0, 0, 1], - "WinsP1": [1, 0, 2], - "WinsP2": [1, 0, 2], - "stage": [1, 1, 11], - "CharP1": [2, 0, 25], - "CharP2": [2, 0, 25], - "AggressorBarP1": [1, 0, 48], - "AggressorBarP2": [1, 0, 48], - "HealthP1": [1, 0, 166], - "HealthP2": [1, 0, 166] + "common": { + "stage": ["BOX", 1, 11], + "timer": ["BOX", 0, 100] + }, + "Px": { + "side": ["BINARY", 0, 1], + "wins": ["BOX", 0, 2], + "character": ["DISCRETE", 0, 25], + "aggressor_bar": ["BOX", 0, 48], + "health": ["BOX", 0, 166] + } + }, "cfg": {"HP": "But1", "HK": "But2", "LK": "But3", "LP": "But4", "RUN": "But5", "BLK": "But6"} }, @@ -229,6 +245,8 @@ "Ukyo", "Yoshitora", "Gaoh", "Haohmaru", "Genjuro", "Shizumaru", "Kazuki", "Tamtam", "Rasetsumaru", "Rimururu", "Mina", "Zankuro", "Nakoruru", "Rera", "Yunfei", "Basara", "Mizuki"], + "char_forbidden_list": [], + "char_homonymy_map": {}, "outfits": [1, 4], "difficulty": [1, 8, 6], "difficulty_to_cluster_map": { @@ -249,33 +267,28 @@ "rounds_per_stage": 2, "stages_per_game": 7, "number_of_chars_per_round": 1, - "n_actions": [9, 5, 11], 
+ "number_of_chars_to_select": 1, + "n_actions": [9, 11, 5], "health": [0, 125], "frame_shape": [224, 320, 3], "ram_states": { - "SideP1": [0, 0, 1], - "SideP2": [0, 0, 1], - "WinsP1": [1, 0, 3], - "WinsP2": [1, 0, 3], - "stage": [1, 1, 7], - "CharP1": [2, 0, 27], - "CharP2": [2, 0, 27], - "RageOnP1": [0, 0, 1], - "RageOnP2": [0, 0, 1], - "WeaponLostP1": [0, 0, 1], - "WeaponLostP2": [0, 0, 1], - "WeaponFightP1": [0, 0, 1], - "WeaponFightP2": [0, 0, 1], - "RageUsedP1": [0, 0, 1], - "RageUsedP2": [0, 0, 1], - "RageBarP1": [1, 0, 164096], - "RageBarP2": [1, 0, 164096], - "WeaponBarP1": [1, 0, 120], - "WeaponBarP2": [1, 0, 120], - "PowerBarP1": [1, 0, 64], - "PowerBarP2": [1, 0, 64], - "HealthP1": [1, 0, 125], - "HealthP2": [1, 0, 125] + "common": { + "stage": ["BOX", 1, 7], + "timer": ["BOX", 0, 60] + }, + "Px": { + "side": ["BINARY", 0, 1], + "wins": ["BOX", 0, 3], + "character": ["DISCRETE", 0, 27], + "rage_on": ["BINARY", 0, 1], + "weapon_lost": ["BINARY", 0, 1], + "weapon_fight": ["BINARY", 0, 1], + "rage_used": ["BINARY", 0, 1], + "rage_bar": ["BOX", 0, 100], + "weapon_bar": ["BOX", 0, 120], + "power_bar": ["BOX", 0, 64], + "health": ["BOX", 0, 125] + } }, "cfg": {"WS": "But1", "MS": "But2", "K": "But3", "M": "But4"} }, @@ -293,6 +306,8 @@ "Kim","Chang","Choi","Yashiro","Shermie","Chris","Yamazaki","Blue","Billy", "Iori","Mature","Vice","Heidern","Takuma","Saisyu","Heavy-D!","Lucky","Brian", "Eiji","Kasumi","Shingo","Rugal","Geese","Krauser","Mr.Big","Goenitz","Orochi"], + "char_forbidden_list": ["Goenitz","Orochi"], + "char_homonymy_map": {}, "outfits": [1, 4], "difficulty": [1, 8, 6], "difficulty_to_cluster_map": { @@ -313,31 +328,27 @@ "rounds_per_stage": 3, "stages_per_game": 7, "number_of_chars_per_round": 1, - "n_actions": [9, 5, 9], + "number_of_chars_to_select": 3, + "n_actions": [9, 9, 5], "health": [-1, 119], "frame_shape": [240, 320, 3], "ram_states": { - "SideP1": [0, 0, 1], - "SideP2": [0, 0, 1], - "stage": [1, 1, 7], - "CharP1": [2, 0, 44], - "CharP2": [2, 0, 44], - "Char1P1": [2, 0, 44], - "Char1P2": [2, 0, 44], - "Char2P1": [2, 0, 44], - "Char2P2": [2, 0, 44], - "Char3P1": [2, 0, 44], - "Char3P2": [2, 0, 44], - "HealthP1": [1, -1, 119], - "HealthP2": [1, -1, 119], - "PowerBarP1": [1, 0, 100], - "PowerBarP2": [1, 0, 100], - "SpecialAttacksP1": [1, 0, 5], - "SpecialAttacksP2": [1, 0, 5], - "WinsP1": [1, 0, 3], - "WinsP2": [1, 0, 3], - "BarTypeP1": [2, 0, 7], - "BarTypeP2": [2, 0, 7] + "common": { + "stage": ["BOX", 1, 7], + "timer": ["BOX", 0, 60] + }, + "Px": { + "side": ["BINARY", 0, 1], + "character": ["DISCRETE", 0, 44], + "character_1": ["DISCRETE", 0, 44], + "character_2": ["DISCRETE", 0, 44], + "character_3": ["DISCRETE", 0, 44], + "health": ["BOX", -1, 119], + "power_bar": ["BOX", 0, 100], + "special_attacks": ["BOX", 0, 5], + "wins": ["BOX", 0, 3], + "bar_type": ["DISCRETE", 0, 7] + } }, "cfg": {"WP": "But1", "WK": "But2", "SP": "But3", "SK": "But4"} } diff --git a/diambra/arena/utils/splash_screen.py b/diambra/arena/utils/splash_screen.py index 68d74a69..a72492d2 100644 --- a/diambra/arena/utils/splash_screen.py +++ b/diambra/arena/utils/splash_screen.py @@ -4,7 +4,6 @@ gif_file_path = os.path.join(os.path.dirname(__file__), ".splash_screen.gif") - def get_monitor_from_coord(x, y): monitors = screeninfo.get_monitors() @@ -13,12 +12,8 @@ def get_monitor_from_coord(x, y): return m return monitors[0] -# Class to manage gampad - - class SplashScreen(): def __init__(self, time_interval=100, timeout=5000): - self.timeout = timeout self.time_interval = time_interval # 
self.labels = (t*"\u25AE" for t in range(int((timeout-750)/time_interval))) @@ -30,14 +25,13 @@ def __init__(self, time_interval=100, timeout=5000): self.window.wm_attributes("-topmost", True) # Get the screen which contains top - current_screen = get_monitor_from_coord( - self.window.winfo_x(), self.window.winfo_y()) + current_screen = get_monitor_from_coord(self.window.winfo_x(), self.window.winfo_y()) # Get the monitor's size width = current_screen.width height = current_screen.height - image = tk.PhotoImage(file=gif_file_path) + image = tk.PhotoImage(master=self.window, file=gif_file_path) hw_dim = [image.height(), image.width()] self.window.geometry( '%dx%d+%d+%d' % (hw_dim[1], hw_dim[0], diff --git a/diambra/arena/wrappers/arena_wrappers.py b/diambra/arena/wrappers/arena_wrappers.py index dfe6f311..ee6bc264 100644 --- a/diambra/arena/wrappers/arena_wrappers.py +++ b/diambra/arena/wrappers/arena_wrappers.py @@ -1,10 +1,48 @@ import random import numpy as np -import gym +import gymnasium as gym import logging -from diambra.arena.env_settings import WrappersSettings +from diambra.arena import SpaceTypes, WrappersSettings +from diambra.arena.wrappers.observation import WarpFrame, GrayscaleFrame, FrameStack, ActionsStack, \ + NormalizeObservation, FlattenFilterDictObs, \ + AddLastActionToObservation, RoleRelativeObservation -class NoopResetEnv(gym.Wrapper): +# Remove attack buttons combinations +class NoAttackButtonsCombinations(gym.Wrapper): + def __init__(self, env): + """ + Limit attack actions to single buttons removing attack buttons combinations + :param env: (Gym Environment) the environment to wrap + """ + gym.Wrapper.__init__(self, env) + # N actions + self.n_actions = [self.unwrapped.env_info.available_actions.n_moves, self.unwrapped.env_info.available_actions.n_attacks_no_comb] + if self.unwrapped.env_settings.n_players == 1: + if self.unwrapped.env_settings.action_space == SpaceTypes.MULTI_DISCRETE: + self.action_space = gym.spaces.MultiDiscrete(self.n_actions) + elif self.unwrapped.env_settings.action_space == SpaceTypes.DISCRETE: + self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1) + else: + raise Exception("Action space not recognized in \"NoAttackButtonsCombinations\" wrapper") + self.unwrapped.logger.debug("Using {} action space without attack buttons combinations".format(SpaceTypes.Name(self.unwrapped.env_settings.action_space))) + else: + self.unwrapped.logger.warning("Warning: \"NoAttackButtonsCombinations\" is by default applied on all agents actions space") + for idx in range(self.unwrapped.env_settings.n_players): + if self.unwrapped.env_settings.action_space[idx] == SpaceTypes.MULTI_DISCRETE: + self.action_space["agent_{}".format(idx)] = gym.spaces.MultiDiscrete(self.n_actions) + elif self.unwrapped.env_settings.action_space[idx] == SpaceTypes.DISCRETE: + self.action_space["agent_{}".format(idx)] = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1) + else: + raise Exception("Action space not recognized in \"NoAttackButtonsCombinations\" wrapper") + self.unwrapped.logger.debug("Using {} action space for agent_{} without attack buttons combinations".format(SpaceTypes.Name(self.unwrapped.env_settings.action_space[idx]), idx)) + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + def step(self, action): + return self.env.step(action) + +class NoopReset(gym.Wrapper): def __init__(self, env, no_op_max=6): """ Sample initial states by taking random number of no-ops on reset. 
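
# --- Worked example (illustrative, not part of the patch) ---
# Size of the action spaces built by NoAttackButtonsCombinations above, using the doapp values
# from integratedGames.json (n_actions = [9, 8, 4]: 9 moves, 8 attacks with button combinations,
# 4 single-button attacks).
n_moves, n_attacks_no_comb = 9, 4
discrete_size = n_moves + n_attacks_no_comb - 1  # = 12: "no move" and "no attack" share one no-op entry
multi_discrete_nvec = [n_moves, n_attacks_no_comb]  # = [9, 4]
# --- end of worked example ---
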
@@ -24,11 +62,8 @@ def reset(self, **kwargs): no_ops = random.randint(1, self.no_op_max + 1) assert no_ops > 0 obs = None - no_op_action = [0, 0, 0, 0] - if isinstance(self.action_space, gym.spaces.Discrete): - no_op_action = 0 for _ in range(no_ops): - obs, _, done, _ = self.env.step(no_op_action) + obs, _, done, _ = self.env.step(self.unwrapped.get_no_op_action()) if done: obs = self.env.reset(**kwargs) return obs @@ -36,8 +71,7 @@ def reset(self, **kwargs): def step(self, action): return self.env.step(action) - -class StickyActionsEnv(gym.Wrapper): +class StickyActions(gym.Wrapper): def __init__(self, env, sticky_actions): """ Apply sticky actions @@ -47,18 +81,12 @@ def __init__(self, env, sticky_actions): """ gym.Wrapper.__init__(self, env) self.sticky_actions = sticky_actions - - assert self.env.env_settings.step_ratio == 1, "sticky_actions can "\ - "be activated only "\ - "when stepRatio is "\ - "set equal to 1" + assert self.unwrapped.env_settings.step_ratio == 1, "StickyActions wrapper can be activated only "\ + "when step_ratio is set equal to 1" def step(self, action): - rew = 0.0 - for _ in range(self.sticky_actions): - obs, rew_step, done, info = self.env.step(action) rew += rew_step if info["round_done"] is True: @@ -66,8 +94,7 @@ def step(self, action): return obs, rew, done, info - -class ClipRewardEnv(gym.RewardWrapper): +class ClipReward(gym.RewardWrapper): def __init__(self, env): """ clips the reward to {+1, 0, -1} by its sign. @@ -82,114 +109,101 @@ def reward(self, reward): """ return np.sign(reward) - -class NormalizeRewardEnv(gym.RewardWrapper): +class NormalizeReward(gym.RewardWrapper): def __init__(self, env, reward_normalization_factor): """ Normalize the reward dividing it by the product of rewardNormalizationFactor multiplied by - the maximum character health variadtion (max - min). + the maximum character health variation (max - min). 
:param env: (Gym Environment) the environment :param rewardNormalizationFactor: multiplication factor """ gym.RewardWrapper.__init__(self, env) - self.env.reward_normalization_value = reward_normalization_factor * self.env.max_delta_health + self.unwrapped.reward_normalization_value = reward_normalization_factor * self.unwrapped.max_delta_health def reward(self, reward): """ - Nomralize reward dividing by reward normalization factor*max_delta_health + Normalize reward dividing by reward normalization factor*max_delta_health :param reward: (float) """ - return float(reward) / float(self.env.reward_normalization_value) + return float(reward) / float(self.unwrapped.reward_normalization_value) # Environment Wrapping (rewards normalization, resizing, grayscaling, etc) -def env_wrapping(env, wrappers_settings: WrappersSettings, hardcore: bool=False): +def env_wrapping(env, wrappers_settings: WrappersSettings): """ Typical standard environment wrappers :param env: (Gym Environment) the diambra environment - :param no_op_max: (int) wrap the environment to perform - no_op_max no action steps at reset - :param clipRewards: (bool) wrap the reward clipping wrapper - :param rewardNormalization: (bool) if to activate reward noramlization - :param rewardNormalizationFactor: (double) noramlization factor - for reward normalization wrapper - :param frameStack: (int) wrap the frame stacking wrapper - using #frameStack frames - :param dilation (frame stacking): (int) stack one frame every - #dilation frames, useful to assure - action every step considering - a dilated subset of previous frames - :param actionsStack: (int) wrap the frame stacking wrapper - using #frameStack frames - :param scale: (bool) wrap the scaling observation wrapper - :param scaleMod: (int) them scaling method: 0->[0,1] 1->[-1,1] + :param wrappers_settings: (WrappersSettings) settings for the wrappers :return: (Gym Environment) the wrapped diambra environment """ logger = logging.getLogger(__name__) + ### Generic wrappers(s) if wrappers_settings.no_op_max > 0: - env = NoopResetEnv(env, no_op_max=wrappers_settings.no_op_max) - - if wrappers_settings.sticky_actions > 1: - env = StickyActionsEnv(env, sticky_actions=wrappers_settings.sticky_actions) - - if hardcore is True: - from diambra.arena.wrappers.obs_wrapper_hardcore import WarpFrame,\ - WarpFrame3C, FrameStack, FrameStackDilated,\ - ScaledFloatObsNeg, ScaledFloatObs - else: - from diambra.arena.wrappers.obs_wrapper import WarpFrame, \ - WarpFrame3C, FrameStack, FrameStackDilated,\ - ActionsStack, ScaledFloatObsNeg, ScaledFloatObs, FlattenFilterDictObs - - if wrappers_settings.hwc_obs_resize[2] == 1: - # Resizing observation from H x W x 3 to - # hwObsResize[0] x hwObsResize[1] x 1 - env = WarpFrame(env, wrappers_settings.hwc_obs_resize) - elif wrappers_settings.hwc_obs_resize[2] == 3: - # Resizing observation from H x W x 3 to - # hwObsResize[0] x hwObsResize[1] x hwObsResize[2] - env = WarpFrame3C(env, wrappers_settings.hwc_obs_resize) - - # Normalize rewards - if wrappers_settings.reward_normalization is True: - env = NormalizeRewardEnv(env, wrappers_settings.reward_normalization_factor) - - # Clip rewards using sign function - if wrappers_settings.clip_rewards is True: - env = ClipRewardEnv(env) + env = NoopReset(env, no_op_max=wrappers_settings.no_op_max) - # Stack #frameStack frames together - if wrappers_settings.frame_stack > 1: - if wrappers_settings.dilation == 1: - env = FrameStack(env, wrappers_settings.frame_stack) + if wrappers_settings.repeat_action > 1: + env = 
StickyActions(env, sticky_actions=wrappers_settings.repeat_action) + + ### Reward wrappers(s) + if wrappers_settings.normalize_reward is True: + env = NormalizeReward(env, wrappers_settings.normalization_factor) + + if wrappers_settings.clip_reward is True: + env = ClipReward(env) + + ### Action space wrapper(s) + if wrappers_settings.no_attack_buttons_combinations is True: + env = NoAttackButtonsCombinations(env) + + ### Observation space wrappers(s) + if wrappers_settings.frame_shape[2] == 1: + if env.observation_space["frame"].shape[2] == 1: + env.logger.warning("Warning: skipping grayscaling as the frame is already single channel.") else: - logger.debug("Using frame stacking with dilation = {}".format(wrappers_settings.dilation)) - env = FrameStackDilated(env, wrappers_settings.frame_stack, wrappers_settings.dilation) + # Greyscaling frame to h x w x 1 + env = GrayscaleFrame(env) + + if wrappers_settings.frame_shape[0] != 0 and wrappers_settings.frame_shape[1] != 0: + # Resizing observation from H x W x C to + # frame_shape[0] x frame_shape[1] x C + # Check if frame shape is bigger than native shape + native_frame_size = env.observation_space["frame"].shape + if wrappers_settings.frame_shape[0] > native_frame_size[0] or wrappers_settings.frame_shape[1] > native_frame_size[1]: + warning_message = "Warning: \"frame_shape\" greater than game native frame shape.\n" + warning_message += " \"native frame shape\" = [" + str(native_frame_size[0]) + warning_message += " X " + str(native_frame_size[1]) + "]\n" + warning_message += " \"frame_shape\" = [" + str(wrappers_settings.frame_shape[0]) + warning_message += " X " + str(wrappers_settings.frame_shape[1]) + "]" + env.logger.warning(warning_message) + + env = WarpFrame(env, wrappers_settings.frame_shape[:2]) - # Stack #actionsStack actions together - if wrappers_settings.actions_stack > 1 and not hardcore: - env = ActionsStack(env, wrappers_settings.actions_stack) + # Stack #frameStack frames together + if wrappers_settings.stack_frames > 1: + env = FrameStack(env, wrappers_settings.stack_frames, wrappers_settings.dilation) + + # Add last action to observation + if wrappers_settings.add_last_action: + env = AddLastActionToObservation(env) - # Scales observations normalizing them + # Stack #actionsStack actions together + if wrappers_settings.stack_actions > 1: + env = ActionsStack(env, wrappers_settings.stack_actions) + + # Scales observations normalizing them between 0.0 and 1.0 if wrappers_settings.scale is True: - if wrappers_settings.scale_mod == 0: - # Between 0.0 and 1.0 - if hardcore is False: - env = ScaledFloatObs(env, wrappers_settings.exclude_image_scaling, wrappers_settings.process_discrete_binary) - else: - env = ScaledFloatObs(env) - elif wrappers_settings.scale_mod == -1: - # Between -1.0 and 1.0 - raise RuntimeError("Scaling between -1.0 and 1.0 currently not implemented") - env = ScaledFloatObsNeg(env) - else: - raise ValueError("Scale mod must be either 0 or -1") + env = NormalizeObservation(env, wrappers_settings.exclude_image_scaling, wrappers_settings.process_discrete_binary) + + # Convert base observation to role-relative observation + if wrappers_settings.role_relative is True: + env = RoleRelativeObservation(env) if wrappers_settings.flatten is True: - if hardcore is True: - logger.warning("Dictionary observation flattening is valid only for not hardcore mode, skipping it.") - else: - env = FlattenFilterDictObs(env, wrappers_settings.filter_keys) + env = FlattenFilterDictObs(env, wrappers_settings.filter_keys) + + # 
Apply all additional wrappers in sequence: + for wrapper in wrappers_settings.wrappers: + env = wrapper[0](env, **wrapper[1]) return env diff --git a/diambra/arena/wrappers/episode_recording.py b/diambra/arena/wrappers/episode_recording.py new file mode 100644 index 00000000..ed97b3e5 --- /dev/null +++ b/diambra/arena/wrappers/episode_recording.py @@ -0,0 +1,75 @@ +import os +import numpy as np +import datetime +import gymnasium as gym +from diambra.arena.utils.gym_utils import ParallelPickleWriter +from diambra.arena.env_settings import RecordingSettings +import copy +import cv2 + + +# Trajectory recorder wrapper +class EpisodeRecorder(gym.Wrapper): + def __init__(self, env, recording_settings: RecordingSettings): + """ + Record trajectories to use them for imitation learning + :param env: (Gym Environment) the environment to wrap + :param file_path: (str) file path specifying where to + store the trajectory file + """ + gym.Wrapper.__init__(self, env) + self.dataset_path = recording_settings.dataset_path + self.username = recording_settings.username + + self.compression_parameters = [int(cv2.IMWRITE_JPEG_QUALITY), 80] + + self.env.logger.info("Recording trajectories in \"{}\"".format(self.dataset_path)) + os.makedirs(self.dataset_path, exist_ok=True) + + def reset(self, **kwargs): + """ + Reset the environment and add requested info to the observation + :return: observation + """ + self.episode_data = [] + + obs, info = self.env.reset(**kwargs) + self._last_obs = copy.deepcopy(obs) + _, self._last_obs["frame"] = cv2.imencode('.jpg', obs["frame"], self.compression_parameters) + + return obs, info + + def step(self, action): + """ + Step the environment with the given action + and add requested info to the observation + :param action: ([int] or [float]) the action + :return: new observation, reward, done, information + """ + obs, reward, terminated, truncated, info = self.env.step(action) + + self.episode_data.append({ + "obs": self._last_obs, + "action": action, + "reward": reward, + "terminated": terminated, + "truncated": truncated, + "info": info}) + self._last_obs = copy.deepcopy(obs) + _, self._last_obs["frame"] = cv2.imencode('.jpg', obs["frame"], self.compression_parameters) + + if terminated or truncated: + to_save = {} + to_save["episode_summary"] = { + "steps": len(self.episode_data), + "username": self.username, + "env_settings": self.env.env_settings.pb_model, + } + to_save["data"] = self.episode_data + + # Save recording file + save_path = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + ".diambra" + pickle_writer = ParallelPickleWriter(os.path.join(self.dataset_path, save_path), to_save) + pickle_writer.start() + + return obs, reward, terminated, truncated, info diff --git a/diambra/arena/wrappers/obs_wrapper.py b/diambra/arena/wrappers/obs_wrapper.py deleted file mode 100644 index 03022574..00000000 --- a/diambra/arena/wrappers/obs_wrapper.py +++ /dev/null @@ -1,363 +0,0 @@ -from gym import spaces -import gym -from copy import deepcopy -import numpy as np -from collections import deque -from collections.abc import Mapping -import cv2 # pytype:disable=import-error -cv2.ocl.setUseOpenCL(False) - -# Env Wrappers classes -class WarpFrame(gym.ObservationWrapper): - def __init__(self, env, hw_obs_resize=[84, 84]): - """ - Warp frames to 84x84 as done in the Nature paper and later work. 
- :param env: (Gym Environment) the environment - """ - env.logger.warning("Warning: for speedup, avoid frame warping wrappers, use environment's " + - "native frame wrapping through \"frame_shape\" setting (see documentation for details)") - gym.ObservationWrapper.__init__(self, env) - self.width = hw_obs_resize[1] - self.height = hw_obs_resize[0] - self.observation_space.spaces["frame"] = spaces.Box(low=0, high=255, - shape=(self.height, self.width, 1), - dtype=self.observation_space["frame"].dtype) - - def observation(self, obs): - """ - returns the current observation from a obs - :param obs: environment obs - :return: the observation - """ - obs["frame"] = cv2.cvtColor(obs["frame"], cv2.COLOR_RGB2GRAY) - obs["frame"] = cv2.resize(obs["frame"], (self.width, self.height), - interpolation=cv2.INTER_LINEAR)[:, :, None] - return obs - -class WarpFrame3C(gym.ObservationWrapper): - def __init__(self, env, hw_obs_resize=[224, 224]): - """ - Warp frames to 84x84 as done in the Nature paper and later work. - :param env: (Gym Environment) the environment - """ - env.logger.warning("Warning: for speedup, avoid frame warping wrappers, use environment's " + - "native frame wrapping through \"frame_shape\" setting (see documentation for details)") - gym.ObservationWrapper.__init__(self, env) - self.width = hw_obs_resize[1] - self.height = hw_obs_resize[0] - self.observation_space.spaces["frame"] = spaces.Box(low=0, high=255, - shape=(self.height, self.width, 3), - dtype=self.observation_space["frame"].dtype) - - def observation(self, obs): - """ - returns the current observation from a obs - :param obs: environment obs - :return: the observation - """ - obs["frame"] = cv2.resize(obs["frame"], (self.width, self.height), - interpolation=cv2.INTER_LINEAR)[:, :, None] - return obs - - -class FrameStack(gym.Wrapper): - def __init__(self, env, n_frames): - """Stack n_frames last frames. - - :param env: (Gym Environment) the environment - :param n_frames: (int) the number of frames to stack - """ - gym.Wrapper.__init__(self, env) - self.n_frames = n_frames - self.frames = deque([], maxlen=n_frames) - shp = self.observation_space["frame"].shape - self.observation_space.spaces["frame"] = spaces.Box(low=0, high=255, - shape=(shp[0], shp[1], shp[2] * n_frames), - dtype=self.observation_space["frame"].dtype) - - def reset(self, **kwargs): - obs = self.env.reset(**kwargs) - # Fill the stack upon reset to avoid black frames - for _ in range(self.n_frames): - self.frames.append(obs["frame"]) - - obs["frame"] = self.get_ob() - return obs - - def step(self, action): - obs, reward, done, info = self.env.step(action) - self.frames.append(obs["frame"]) - - # Add last obs n_frames - 1 times in case of - # new round / stage / continueGame - if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not done): - for _ in range(self.n_frames - 1): - self.frames.append(obs["frame"]) - - obs["frame"] = self.get_ob() - return obs, reward, done, info - - def get_ob(self): - assert len(self.frames) == self.n_frames - return np.concatenate(self.frames, axis=2) - - -class FrameStackDilated(gym.Wrapper): - def __init__(self, env, n_frames, dilation): - """Stack n_frames last frames with dilation factor. 
- :param env: (Gym Environment) the environment - :param n_frames: (int) the number of frames to stack - :param dilation: (int) the dilation factor - """ - gym.Wrapper.__init__(self, env) - self.n_frames = n_frames - self.dilation = dilation - # Keeping all n_frames*dilation in memory, - # then extract the subset given by the dilation factor - self.frames = deque([], maxlen=n_frames * dilation) - shp = self.observation_space["frame"].shape - self.observation_space.spaces["frame"] = spaces.Box(low=0, high=255, - shape=(shp[0], shp[1], shp[2] * n_frames), - dtype=self.observation_space["frame"].dtype) - - def reset(self, **kwargs): - obs = self.env.reset(**kwargs) - for _ in range(self.n_frames * self.dilation): - self.frames.append(obs["frame"]) - obs["frame"] = self.get_ob() - return obs - - def step(self, action): - obs, reward, done, info = self.env.step(action) - self.frames.append(obs["frame"]) - - # Add last obs n_frames - 1 times in case of - # new round / stage / continueGame - if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not done): - for _ in range(self.n_frames * self.dilation - 1): - self.frames.append(obs["frame"]) - - obs["frame"] = self.get_ob() - return obs, reward, done, info - - def get_ob(self): - frames_subset = list(self.frames)[self.dilation - 1::self.dilation] - assert len(frames_subset) == self.n_frames - return np.concatenate(frames_subset, axis=2) - - -class ActionsStack(gym.Wrapper): - def __init__(self, env, n_actions_stack): - """Stack n_actions_stack last actions. - :param env: (Gym Environment) the environment - :param n_actions_stack: (int) the number of actions to stack - """ - gym.Wrapper.__init__(self, env) - self.n_actions_stack = n_actions_stack - self.n_players = 1 if self.env.env_settings.player != "P1P2" else 2 - self.move_action_stack = [] - self.attack_action_stack = [] - for iplayer in range(self.n_players): - self.move_action_stack.append(deque([0 for i in range(n_actions_stack)], maxlen=n_actions_stack)) - self.attack_action_stack.append(deque([0 for i in range(n_actions_stack)], maxlen=n_actions_stack)) - - if self.n_players == 1: - self.observation_space["P1"]["actions"]["move"] = spaces.MultiDiscrete([self.n_actions[0]] * n_actions_stack) - self.observation_space["P1"]["actions"]["attack"] = spaces.MultiDiscrete([self.n_actions[1]] * n_actions_stack) - else: - for iplayer in range(self.n_players): - self.observation_space["P{}".format(iplayer + 1)]["actions"]["move"] =\ - spaces.MultiDiscrete([self.n_actions[iplayer][0]] * n_actions_stack) - self.observation_space["P{}".format(iplayer + 1)]["actions"]["attack"] =\ - spaces.MultiDiscrete([self.n_actions[iplayer][1]] * n_actions_stack) - - def fill_stack(self, value=0): - # Fill the actions stack with no action after reset - for _ in range(self.n_actions_stack): - for iplayer in range(self.n_players): - self.move_action_stack[iplayer].append(value) - self.attack_action_stack[iplayer].append(value) - - def reset(self, **kwargs): - obs = self.env.reset(**kwargs) - self.fill_stack() - - for iplayer in range(self.n_players): - obs["P{}".format( - iplayer + 1)]["actions"]["move"] = np.array(self.move_action_stack[iplayer]) - obs["P{}".format( - iplayer + 1)]["actions"]["attack"] = np.array(self.attack_action_stack[iplayer]) - return obs - - def step(self, action): - obs, reward, done, info = self.env.step(action) - for iplayer in range(self.n_players): - self.move_action_stack[iplayer].append( - obs["P{}".format(iplayer + 1)]["actions"]["move"]) - 
self.attack_action_stack[iplayer].append( - obs["P{}".format(iplayer + 1)]["actions"]["attack"]) - - # Add noAction for n_actions_stack - 1 times - # in case of new round / stage / continueGame - if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not done): - self.fill_stack() - - for iplayer in range(self.n_players): - obs["P{}".format( - iplayer + 1)]["actions"]["move"] = np.array(self.move_action_stack[iplayer]) - obs["P{}".format( - iplayer + 1)]["actions"]["attack"] = np.array(self.attack_action_stack[iplayer]) - return obs, reward, done, info - -class ScaledFloatObsNeg(gym.ObservationWrapper): - def __init__(self, env): - gym.ObservationWrapper.__init__(self, env) - self.observation_space.spaces["frame"] = spaces.Box(low=-1.0, high=1.0, - shape=self.observation_space["frame"].shape, - dtype=np.float32) - - def observation(self, observation): - # careful! This undoes the memory optimization, use - # with smaller replay buffers only. - observation["frame"] = observation["frame"] / 127.5 - 1.0 - return observation - - -class ScaledFloatObs(gym.ObservationWrapper): - def __init__(self, env, exclude_image_scaling=False, process_discrete_binary=False): - gym.ObservationWrapper.__init__(self, env) - - self.exclude_image_scaling = exclude_image_scaling - self.process_discrete_binary = process_discrete_binary - - self.original_observation_space = deepcopy(self.observation_space) - self.scaled_float_obs_space_func(self.observation_space) - - # Recursive function to modify obs space dict - # FIXME: this can probably be dropped with gym >= 0.21 and only use the next one, here for SB2 compatibility - def scaled_float_obs_space_func(self, obs_dict): - # Updating observation space dict - for k, v in obs_dict.spaces.items(): - - if isinstance(v, spaces.dict.Dict): - self.scaled_float_obs_space_func(v) - else: - if isinstance(v, spaces.MultiDiscrete): - # One hot encoding x nStack - n_val = v.nvec.shape[0] - max_val = v.nvec[0] - obs_dict.spaces[k] = spaces.MultiBinary(n_val * max_val) - elif isinstance(v, spaces.Discrete) and (v.n > 2 or self.process_discrete_binary is True): - # One hot encoding - obs_dict.spaces[k] = spaces.MultiBinary(v.n) - elif isinstance(v, spaces.Box) and (self.exclude_image_scaling is False or len(v.shape) < 3): - obs_dict.spaces[k] = spaces.Box(low=0.0, high=1.0, shape=v.shape, dtype=np.float32) - - # Recursive function to modify obs dict - def scaled_float_obs_func(self, observation, observation_space): - - # Process all observations - for k, v in observation.items(): - - if isinstance(v, dict): - self.scaled_float_obs_func(v, observation_space.spaces[k]) - else: - v_space = observation_space.spaces[k] - if isinstance(v_space, spaces.MultiDiscrete): - n_act = observation_space.spaces[k].nvec[0] - buf_len = observation_space.spaces[k].nvec.shape[0] - actions_vector = np.zeros((buf_len * n_act), dtype=np.uint8) - for iact in range(buf_len): - actions_vector[iact * n_act + observation[k][iact]] = 1 - observation[k] = actions_vector - elif isinstance(v_space, spaces.Discrete) and (v_space.n > 2 or self.process_discrete_binary is True): - var_vector = np.zeros((observation_space.spaces[k].n), dtype=np.uint8) - var_vector[observation[k]] = 1 - observation[k] = var_vector - elif isinstance(v_space, spaces.Box) and (self.exclude_image_scaling is False or len(v_space.shape) < 3): - high_val = np.max(v_space.high) - low_val = np.min(v_space.low) - observation[k] = np.array((observation[k] - low_val) / (high_val - low_val), dtype=np.float32) - - return 
observation - - def observation(self, observation): - - return self.scaled_float_obs_func(observation, self.original_observation_space) - -def flatten_filter_obs_space_func(input_dictionary, filter_keys): - _FLAG_FIRST = object() - flattened_dict = {} - - def dummy_check(new_key): - return True - - def check_filter(new_key): - return new_key in filter_keys - - def visit(subdict, flattened_dict, partial_key, check_method): - for k, v in subdict.spaces.items(): - new_key = k if partial_key == _FLAG_FIRST else partial_key + "_" + k - if isinstance(v, Mapping) or isinstance(v, spaces.Dict): - visit(v, flattened_dict, new_key, check_method) - else: - if check_method(new_key): - flattened_dict[new_key] = v - - if filter_keys is not None: - visit(input_dictionary, flattened_dict, _FLAG_FIRST, check_filter) - else: - visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check) - - return flattened_dict - -def flatten_filter_obs_func(input_dictionary, filter_keys): - _FLAG_FIRST = object() - flattened_dict = {} - - def dummy_check(new_key): - return True - - def check_filter(new_key): - return new_key in filter_keys - - def visit(subdict, flattened_dict, partial_key, check_method): - for k, v in subdict.items(): - new_key = k if partial_key == _FLAG_FIRST else partial_key + "_" + k - if isinstance(v, Mapping) or isinstance(v, spaces.Dict): - visit(v, flattened_dict, new_key, check_method) - else: - if check_method(new_key): - flattened_dict[new_key] = v - - if filter_keys is not None: - visit(input_dictionary, flattened_dict, _FLAG_FIRST, check_filter) - else: - visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check) - - return flattened_dict - -class FlattenFilterDictObs(gym.ObservationWrapper): - def __init__(self, env, filter_keys): - gym.ObservationWrapper.__init__(self, env) - - self.filter_keys = filter_keys - if (filter_keys is not None): - self.filter_keys = list(set(filter_keys)) - if "frame" not in filter_keys: - self.filter_keys += ["frame"] - - original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, None)).keys() - self.observation_space = spaces.Dict(flatten_filter_obs_space_func(self.observation_space, self.filter_keys)) - - if filter_keys is not None: - if (sorted(self.observation_space.spaces.keys()) != sorted(self.filter_keys)): - raise Exception("Specified observation key(s) not found:", - " Available key(s):", sorted(original_obs_space_keys), - " Specified key(s):", sorted(self.filter_keys), - " Key(s) not found:", sorted([key for key in self.filter_keys if key not in original_obs_space_keys]), - ) - - def observation(self, observation): - - return flatten_filter_obs_func(observation, self.filter_keys) diff --git a/diambra/arena/wrappers/obs_wrapper_hardcore.py b/diambra/arena/wrappers/obs_wrapper_hardcore.py deleted file mode 100644 index ed786c50..00000000 --- a/diambra/arena/wrappers/obs_wrapper_hardcore.py +++ /dev/null @@ -1,166 +0,0 @@ -from gym import spaces -import gym -import numpy as np -from collections import deque -import cv2 # pytype:disable=import-error -cv2.ocl.setUseOpenCL(False) - -# Env Wrappers classes -class WarpFrame(gym.ObservationWrapper): - def __init__(self, env, hw_obs_resize=[84, 84]): - """ - Warp frames to 84x84 as done in the Nature paper and later work. 
- :param env: (Gym Environment) the environment - """ - env.logger.warning("Warning: for speedup, avoid frame warping wrappers, use environment's "\ - "native frame wrapping through \"frame_shape\" setting (see documentation for details)") - gym.ObservationWrapper.__init__(self, env) - self.width = hw_obs_resize[1] - self.height = hw_obs_resize[0] - self.observation_space = spaces.Box(low=0, high=255, - shape=(self.height, self.width, 1), - dtype=self.observation_space.dtype) - - def observation(self, frame): - """ - returns the current observation from a frame - :param frame: ([int] or [float]) environment frame - :return: ([int] or [float]) the observation - """ - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - frame = cv2.resize(frame, (self.width, self.height), - interpolation=cv2.INTER_LINEAR)[:, :, None] - return frame - - -class WarpFrame3C(gym.ObservationWrapper): - def __init__(self, env, hw_obs_resize=[224, 224]): - """ - Warp frames to 84x84 as done in the Nature paper and later work. - :param env: (Gym Environment) the environment - """ - env.logger.warning("Warning: for speedup, avoid frame warping wrappers, use environment's "\ - "native frame wrapping through \"frame_shape\" setting (see documentation for details)") - gym.ObservationWrapper.__init__(self, env) - self.width = hw_obs_resize[1] - self.height = hw_obs_resize[0] - self.observation_space = spaces.Box(low=0, high=255, - shape=(self.height, self.width, 3), - dtype=self.observation_space.dtype) - - def observation(self, frame): - """ - returns the current observation from a frame - :param frame: ([int] or [float]) environment frame - :return: ([int] or [float]) the observation - """ - frame = cv2.resize(frame, (self.width, self.height), - interpolation=cv2.INTER_LINEAR)[:, :, None] - return frame - - -class FrameStack(gym.Wrapper): - def __init__(self, env, n_frames): - """Stack n_frames last frames. - - :param env: (Gym Environment) the environment - :param n_frames: (int) the number of frames to stack - """ - gym.Wrapper.__init__(self, env) - self.n_frames = n_frames - self.frames = deque([], maxlen=n_frames) - shp = self.observation_space.shape - self.observation_space = spaces.Box(low=0, high=255, - shape=(shp[0], shp[1], shp[2] * n_frames), - dtype=self.observation_space.dtype) - - def reset(self, **kwargs): - obs = self.env.reset(**kwargs) - # Fill the stack upon reset to avoid black frames - for _ in range(self.n_frames): - self.frames.append(obs) - - return self.get_ob() - - def step(self, action): - obs, reward, done, info = self.env.step(action) - self.frames.append(obs) - - # Add last obs n_frames - 1 times in case of - # new round / stage / continueGame - if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not done): - for _ in range(self.n_frames - 1): - self.frames.append(obs) - - return self.get_ob(), reward, done, info - - def get_ob(self): - assert len(self.frames) == self.n_frames - return np.concatenate(self.frames, axis=2) - - -class FrameStackDilated(gym.Wrapper): - def __init__(self, env, n_frames, dilation): - """Stack n_frames last frames with dilation factor. 
- :param env: (Gym Environment) the environment - :param n_frames: (int) the number of frames to stack - :param dilation: (int) the dilation factor - """ - gym.Wrapper.__init__(self, env) - self.n_frames = n_frames - self.dilation = dilation - # Keeping all n_frames*dilation in memory, - # then extract the subset given by the dilation factor - self.frames = deque([], maxlen=n_frames * dilation) - shp = self.observation_space.shape - self.observation_space = spaces.Box(low=0, high=255, - shape=(shp[0], shp[1], shp[2] * n_frames), - dtype=self.observation_space.dtype) - - def reset(self, **kwargs): - obs = self.env.reset(**kwargs) - for _ in range(self.n_frames * self.dilation): - self.frames.append(obs) - return self.get_ob() - - def step(self, action): - obs, reward, done, info = self.env.step(action) - self.frames.append(obs) - - # Add last obs n_frames - 1 times in case of - # new round / stage / continueGame - if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not done): - for _ in range(self.n_frames * self.dilation - 1): - self.frames.append(obs) - - return self.get_ob(), reward, done, info - - def get_ob(self): - frames_subset = list(self.frames)[self.dilation - 1::self.dilation] - assert len(frames_subset) == self.n_frames - return np.concatenate(frames_subset, axis=2) - - -class ScaledFloatObsNeg(gym.ObservationWrapper): - def __init__(self, env): - gym.ObservationWrapper.__init__(self, env) - self.observation_space = spaces.Box(low=-1.0, high=1.0, - shape=self.observation_space.shape, dtype=np.float32) - - def observation(self, observation): - # careful! This undoes the memory optimization, use - # with smaller replay buffers only. - return observation / 127.5 - 1.0 - - -class ScaledFloatObs(gym.ObservationWrapper): - def __init__(self, env): - gym.ObservationWrapper.__init__(self, env) - - self.observation_space = spaces.Box(low=0, high=1.0, - shape=self.observation_space.shape, - dtype=np.float32) - - def observation(self, observation): - - return observation / 255.0 diff --git a/diambra/arena/wrappers/observation.py b/diambra/arena/wrappers/observation.py new file mode 100644 index 00000000..68475fc9 --- /dev/null +++ b/diambra/arena/wrappers/observation.py @@ -0,0 +1,410 @@ +import gymnasium as gym +from copy import deepcopy +import numpy as np +from collections import deque +from collections.abc import Mapping +import cv2 # pytype:disable=import-error +cv2.ocl.setUseOpenCL(False) +from diambra.engine import Roles + +# Env Wrappers classes +class GrayscaleFrame(gym.ObservationWrapper): + def __init__(self, env): + """ + :param env: (Gym Environment) the environment + """ + gym.ObservationWrapper.__init__(self, env) + self.unwrapped.logger.warning("Warning: for speedup, avoid frame warping wrappers, use environment's " + + "native frame grey scaling through \"frame_shape\" setting (see documentation for details)") + + self.width = self.observation_space.spaces["frame"].shape[1] + self.height = self.observation_space.spaces["frame"].shape[0] + self.observation_space.spaces["frame"] = gym.spaces.Box(low=0, high=255, shape=(self.height, self.width, 1), + dtype=self.observation_space["frame"].dtype) + + def observation(self, obs): + """ + returns the current observation from a obs + :param obs: environment obs + :return: the observation + """ + obs["frame"] = cv2.cvtColor(obs["frame"], cv2.COLOR_RGB2GRAY) + return obs + +class WarpFrame(gym.ObservationWrapper): + def __init__(self, env, frame_shape=[84, 84]): + """ + Warp frames to frame_shape resolution, not 
altering channels + :param env: (Gym Environment) the environment + """ + gym.ObservationWrapper.__init__(self, env) + self.unwrapped.logger.warning("Warning: for speedup, avoid frame warping wrappers, use environment's " + + "native frame wrapping through \"frame_shape\" setting (see documentation for details)") + + self.width = frame_shape[1] + self.height = frame_shape[0] + channels = self.observation_space.spaces["frame"].shape[2] + self.observation_space.spaces["frame"] = gym.spaces.Box(low=0, high=255, shape=(self.height, self.width, channels), + dtype=self.observation_space["frame"].dtype) + + def observation(self, obs): + """ + returns the current observation from a obs + :param obs: environment obs + :return: the observation + """ + obs["frame"] = cv2.resize(obs["frame"], (self.width, self.height), interpolation=cv2.INTER_LINEAR)[:, :, None] + return obs + +class FrameStack(gym.Wrapper): + def __init__(self, env, n_frames, dilation): + """Stack n_frames last frames with dilation factor. + :param env: (Gym Environment) the environment + :param n_frames: (int) the number of frames to stack + :param dilation: (int) the dilation factor + """ + gym.Wrapper.__init__(self, env) + self.n_frames = n_frames + self.dilation = dilation + # Keeping all n_frames*dilation in memory, + # then extract the subset given by the dilation factor + self.frames = deque([], maxlen=n_frames * dilation) + shp = self.observation_space["frame"].shape + self.observation_space.spaces["frame"] = gym.spaces.Box(low=0, high=255, + shape=(shp[0], shp[1], shp[2] * n_frames), + dtype=self.observation_space["frame"].dtype) + + def reset(self, **kwargs): + obs, info = self.env.reset(**kwargs) + for _ in range(self.n_frames * self.dilation): + self.frames.append(obs["frame"]) + obs["frame"] = self.get_ob() + return obs, info + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + self.frames.append(obs["frame"]) + + # Add last obs n_frames - 1 times in case of + # new round / stage / continueGame + if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not (terminated or truncated)): + for _ in range(self.n_frames * self.dilation - 1): + self.frames.append(obs["frame"]) + + obs["frame"] = self.get_ob() + return obs, reward, terminated, truncated, info + + def get_ob(self): + frames_subset = list(self.frames)[self.dilation - 1::self.dilation] + assert len(frames_subset) == self.n_frames + return np.concatenate(frames_subset, axis=2) + +class AddLastActionToObservation(gym.Wrapper): + def __init__(self, env): + """Add last performed action to observation space + :param env: (Gym Environment) the environment + """ + gym.Wrapper.__init__(self, env) + if self.unwrapped.env_settings.n_players == 1: + self.observation_space = gym.spaces.Dict({ + **self.observation_space.spaces, + "action": self.action_space, + }) + def _add_last_action_to_obs_1p(obs, last_action): + obs["action"] = last_action + return obs + self._add_last_action_to_obs = _add_last_action_to_obs_1p + else: + for idx in range(self.unwrapped.env_settings.n_players): + action_dictionary = {} + action_dictionary["action"] = self.action_space["agent_{}".format(idx)] + self.observation_space = gym.spaces.Dict({ + **self.observation_space.spaces, + "agent_{}".format(idx): gym.spaces.Dict(action_dictionary), + }) + def _add_last_action_to_obs_2p(obs, last_action): + for idx in range(self.unwrapped.env_settings.n_players): + action_dictionary = {} + action_dictionary["action"] = 
last_action["agent_{}".format(idx)] + obs["agent_{}".format(idx)] = action_dictionary + return obs + self._add_last_action_to_obs = _add_last_action_to_obs_2p + + def reset(self, **kwargs): + obs, info = self.env.reset(**kwargs) + return self._add_last_action_to_obs(obs, self.unwrapped.get_no_op_action()), info + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + return self._add_last_action_to_obs(obs, action), reward, terminated, truncated, info + +class ActionsStack(gym.Wrapper): + def __init__(self, env, n_actions_stack): + """Stack n_actions_stack last actions. + :param env: (Gym Environment) the environment + :param n_actions_stack: (int) the number of actions to stack + """ + gym.Wrapper.__init__(self, env) + + self.n_actions_stack = n_actions_stack + no_op_action = self.unwrapped.get_no_op_action() + + if self.unwrapped.env_settings.n_players == 1: + assert "action" in self.observation_space.spaces, "ActionsStack wrapper can be activated only "\ + "when \"action\" info is in the observation space" + + if isinstance(self.action_space, gym.spaces.MultiDiscrete): + self.action_stack = [deque(no_op_action * n_actions_stack, maxlen=n_actions_stack * 2)] + action_space_size = list(self.observation_space["action"].nvec) + else: + self.action_stack = [deque([no_op_action] * n_actions_stack, maxlen=n_actions_stack)] + action_space_size = [self.observation_space["action"].n] + self.observation_space["action"] = gym.spaces.MultiDiscrete(action_space_size * n_actions_stack) + + if isinstance(self.action_space, gym.spaces.MultiDiscrete): + def _add_action_to_stack_1p(action): + self.action_stack[0].append(action[0]) + self.action_stack[0].append(action[1]) + else: + def _add_action_to_stack_1p(action): + self.action_stack[0].append(action) + self._add_action_to_stack = _add_action_to_stack_1p + + def _process_obs_1p(obs): + obs["action"] = np.array(self.action_stack[0]) + return obs + self._process_obs = _process_obs_1p + else: + self.action_stack = [] + assert "action" in self.observation_space["agent_0"].keys(), "ActionsStack wrapper can be activated only "\ + "when \"action\" info is in the observation space" + for idx in range(self.unwrapped.env_settings.n_players): + if isinstance(self.action_space["agent_{}".format(idx)], gym.spaces.MultiDiscrete): + self.action_stack.append(deque(no_op_action["agent_{}".format(idx)] * n_actions_stack, maxlen=n_actions_stack * 2)) + action_space_size = list(self.observation_space["agent_{}".format(idx)]["action"].nvec) + else: + self.action_stack.append(deque([no_op_action["agent_{}".format(idx)]] * n_actions_stack, maxlen=n_actions_stack)) + action_space_size = [self.observation_space["agent_{}".format(idx)]["action"].n] + self.observation_space["agent_{}".format(idx)]["action"] = gym.spaces.MultiDiscrete(action_space_size * n_actions_stack) + + def _add_action_to_stack_2p(action): + for idx in range(self.unwrapped.env_settings.n_players): + if isinstance(self.action_space["agent_{}".format(idx)], gym.spaces.MultiDiscrete): + self.action_stack[idx].append(action["agent_{}".format(idx)][0]) + self.action_stack[idx].append(action["agent_{}".format(idx)][1]) + else: + self.action_stack[idx].append(action["agent_{}".format(idx)]) + self._add_action_to_stack = _add_action_to_stack_2p + + def _process_obs_2p(obs): + for idx in range(self.unwrapped.env_settings.n_players): + obs["agent_{}".format(idx)]["action"] = np.array(self.action_stack[idx]) + return obs + self._process_obs = _process_obs_2p + + def reset(self, 
**kwargs): + obs, info = self.env.reset(**kwargs) + self._fill_stack() + return self._process_obs(obs), info + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + self._add_action_to_stack(action) + + # Add noAction for n_actions_stack - 1 times + # in case of new round / stage / continueGame + if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not (terminated or truncated)): + self._fill_stack() + + return self._process_obs(obs), reward, terminated, truncated, info + + def _fill_stack(self): + no_op_action = self.unwrapped.get_no_op_action() + for _ in range(self.n_actions_stack): + self._add_action_to_stack(no_op_action) + +class NormalizeObservation(gym.ObservationWrapper): + def __init__(self, env, exclude_image_scaling=False, process_discrete_binary=False): + gym.ObservationWrapper.__init__(self, env) + + self.exclude_image_scaling = exclude_image_scaling + self.process_discrete_binary = process_discrete_binary + + self.original_observation_space = deepcopy(self.observation_space) + self._obs_space_normalization_func(self.observation_space) + + def observation(self, observation): + return self._obs_normalization_func(observation, self.original_observation_space) + + # Recursive function to modify obs space dict + # FIXME: this can probably be dropped with gym >= 0.21 and only use the next one, here for SB2 compatibility + def _obs_space_normalization_func(self, obs_dict): + # Updating observation space dict + for k, v in obs_dict.items(): + if isinstance(v, gym.spaces.Dict): + self._obs_space_normalization_func(v) + else: + if isinstance(v, gym.spaces.MultiDiscrete): + # One hot encoding x nStack + obs_dict[k] = gym.spaces.MultiBinary(np.sum(v.nvec)) + elif isinstance(v, gym.spaces.Discrete) and (v.n > 2 or self.process_discrete_binary is True): + # One hot encoding + obs_dict[k] = gym.spaces.MultiBinary(v.n) + elif isinstance(v, gym.spaces.Box) and (self.exclude_image_scaling is False or len(v.shape) < 3): + obs_dict[k] = gym.spaces.Box(low=0.0, high=1.0, shape=v.shape, dtype=np.float32) + + # Recursive function to modify obs dict + def _obs_normalization_func(self, observation, observation_space): + # Process all observations + for k, v in observation.items(): + if isinstance(v, dict): + self._obs_normalization_func(v, observation_space.spaces[k]) + else: + v_space = observation_space[k] + if isinstance(v_space, gym.spaces.MultiDiscrete): + actions_vector = np.zeros((np.sum(v_space.nvec)), dtype=np.uint8) + column_index = 0 + for iact in range(v_space.nvec.shape[0]): + actions_vector[column_index + observation[k][iact]] = 1 + column_index += v_space.nvec[iact] + observation[k] = actions_vector + elif isinstance(v_space, gym.spaces.Discrete) and (v_space.n > 2 or self.process_discrete_binary is True): + var_vector = np.zeros((v_space.n), dtype=np.uint8) + var_vector[observation[k]] = 1 + observation[k] = var_vector + elif isinstance(v_space, gym.spaces.Box) and (self.exclude_image_scaling is False or len(v_space.shape) < 3): + high_val = np.max(v_space.high) + low_val = np.min(v_space.low) + observation[k] = np.array((observation[k] - low_val) / (high_val - low_val), dtype=np.float32) + + return observation + +class RoleRelativeObservation(gym.Wrapper): + def __init__(self, env): + gym.Wrapper.__init__(self, env) + + new_observation_space = {} + if self.unwrapped.env_settings.n_players == 1: + for k, v in self.observation_space.items(): + if not isinstance(v, gym.spaces.Dict): + new_observation_space[k] = v + 
new_observation_space["own"] = self.observation_space["P1"] + new_observation_space["opp"] = self.observation_space["P1"] + else: + for k, v in self.observation_space.items(): + if not isinstance(v, gym.spaces.Dict) or k.startswith("agent_"): + new_observation_space[k] = v + for idx in range(self.unwrapped.env_settings.n_players): + new_observation_space["agent_{}".format(idx)]["own"] = self.observation_space["P1"] + new_observation_space["agent_{}".format(idx)]["opp"] = self.observation_space["P1"] + + self.observation_space = gym.spaces.Dict(new_observation_space) + + def reset(self, **kwargs): + obs, info = self.env.reset(**kwargs) + if self.unwrapped.env_settings.n_players == 1: + def _process_obs_1p(observation): + new_observation = {} + role_name = Roles.Name(info["settings"].episode_settings.player_settings[0].role) + opponent_role_name = "P2" if role_name == "P1" else "P1" + for k, v in observation.items(): + if not isinstance(v, dict): + new_observation[k] = v + new_observation["own"] = observation[role_name] + new_observation["opp"] = observation[opponent_role_name] + return new_observation + self._process_obs = _process_obs_1p + else: + def _process_obs_2p(observation): + new_observation = {} + for k, v in observation.items(): + if not isinstance(v, dict) or k.startswith("agent_"): + new_observation[k] = v + for idx in range(self.unwrapped.env_settings.n_players): + role_name = Roles.Name(info["settings"].episode_settings.player_settings[idx].role) + opponent_role_name = "P2" if role_name == "P1" else "P1" + new_observation["agent_{}".format(idx)]["own"] = observation[role_name] + new_observation["agent_{}".format(idx)]["opp"] = observation[opponent_role_name] + return new_observation + self._process_obs = _process_obs_2p + return self._process_obs(obs), info + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + return self._process_obs(obs), reward, terminated, truncated, info + +class FlattenFilterDictObs(gym.ObservationWrapper): + def __init__(self, env, filter_keys): + gym.ObservationWrapper.__init__(self, env) + + self.filter_keys = filter_keys + if len(filter_keys) != 0: + self.filter_keys = list(set(filter_keys)) + if "frame" not in filter_keys: + self.filter_keys += ["frame"] + + original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, [])).keys() + self.observation_space = gym.spaces.Dict(flatten_filter_obs_space_func(self.observation_space, self.filter_keys)) + + if len(filter_keys) != 0: + if (sorted(self.observation_space.spaces.keys()) != sorted(self.filter_keys)): + raise Exception("Specified observation key(s) not found:", + " Available key(s):", sorted(original_obs_space_keys), + " Specified key(s):", sorted(self.filter_keys), + " Key(s) not found:", sorted([key for key in self.filter_keys if key not in original_obs_space_keys]), + ) + + def observation(self, observation): + return flatten_filter_obs_func(observation, self.filter_keys) + +def flatten_filter_obs_space_func(input_dictionary, filter_keys): + _FLAG_FIRST = object() + flattened_dict = {} + + def dummy_check(new_key): + return True + + def check_filter(new_key): + return new_key in filter_keys + + def visit(subdict, flattened_dict, partial_key, check_method): + for k, v in subdict.spaces.items(): + new_key = k if partial_key == _FLAG_FIRST else partial_key + "_" + k + if isinstance(v, Mapping) or isinstance(v, gym.spaces.Dict): + visit(v, flattened_dict, new_key, check_method) + else: + if check_method(new_key): + 
flattened_dict[new_key] = v + + if len(filter_keys) != 0: + visit(input_dictionary, flattened_dict, _FLAG_FIRST, check_filter) + else: + visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check) + + return flattened_dict + +def flatten_filter_obs_func(input_dictionary, filter_keys): + _FLAG_FIRST = object() + flattened_dict = {} + + def dummy_check(new_key): + return True + + def check_filter(new_key): + return new_key in filter_keys + + def visit(subdict, flattened_dict, partial_key, check_method): + for k, v in subdict.items(): + new_key = k if partial_key == _FLAG_FIRST else partial_key + "_" + k + if isinstance(v, Mapping) or isinstance(v, gym.spaces.Dict): + visit(v, flattened_dict, new_key, check_method) + else: + if check_method(new_key): + flattened_dict[new_key] = v + + if len(filter_keys) != 0: + visit(input_dictionary, flattened_dict, _FLAG_FIRST, check_filter) + else: + visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check) + + return flattened_dict \ No newline at end of file diff --git a/diambra/arena/wrappers/traj_rec_wrapper.py b/diambra/arena/wrappers/traj_rec_wrapper.py deleted file mode 100644 index 3158217e..00000000 --- a/diambra/arena/wrappers/traj_rec_wrapper.py +++ /dev/null @@ -1,159 +0,0 @@ -import os -import numpy as np -import datetime - -import gym -from ..utils.gym_utils import gym_obs_dict_space_to_standard_dict,\ - ParallelPickleWriter -from diambra.arena.env_settings import RecordingSettings - -# Trajectory recorder wrapper - - -class TrajectoryRecorder(gym.Wrapper): - def __init__(self, env, recording_settings: RecordingSettings): - """ - Record trajectories to use them for imitation learning - :param env: (Gym Environment) the environment to wrap - :param file_path: (str) file path specifying where to - store the trajectory file - """ - gym.Wrapper.__init__(self, env) - self.file_path = recording_settings.file_path - self.username = recording_settings.username - self.ignore_p2 = recording_settings.ignore_p2 - self.frame_shp = self.env.observation_space["frame"].shape - - if (self.env.player_side == "P1P2"): - if ((self.env.attack_but_combination[0] != self.env.attack_but_combination[1]) - or (self.env.action_space["P1"] != self.env.action_space["P2"])): - raise Exception("Different attack buttons combinations and/or " - "different action spaces not supported for 2P " - "experience recordings" - "action space: {}, attack button combo: {}".format(self.env.action_space, self.env.attack_but_combination) - ) - - if ("P1" in self.observation_space.keys()) is False: - raise Exception("Trajectory recording for not hardcore mode does not work with Observation Dict flattening, please deactivate it.") - - env.logger.info("Recording trajectories in \"{}\"".format(self.file_path)) - os.makedirs(self.file_path, exist_ok=True) - - def reset(self, **kwargs): - """ - Reset the environment and add requested info to the observation - :return: observation - """ - - # Items to store - self.last_frame_hist = [] - self.ram_states_hist = [] - self.rewards_hist = [] - self.actions_hist = [] - self.flag_hist = [] - self.cumulative_rew = 0 - - obs = self.env.reset(**kwargs) - - for idx in range(self.frame_shp[2]): - self.last_frame_hist.append(obs["frame"][:, :, idx]) - - # Store the whole obs without the frame - tmp_obs = obs.copy() - tmp_obs.pop("frame") - self.ram_states_hist.append(tmp_obs) - - return obs - - def step(self, action): - """ - Step the environment with the given action - and add requested info to the observation - :param action: ([int] or 
[float]) the action - :return: new observation, reward, done, information - """ - - obs, reward, done, info = self.env.step(action) - - self.last_frame_hist.append(obs["frame"][:, :, self.frame_shp[2] - 1]) - - # Add last obs nFrames - 1 times in case of - # new round / stage / continue_game - if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not done): - for _ in range(self.frame_shp[2] - 1): - self.last_frame_hist.append(obs["frame"][:, :, self.frame_shp[2] - 1]) - - # Store the whole obs without the frame - tmp_obs = obs.copy() - tmp_obs.pop("frame") - self.ram_states_hist.append(tmp_obs) - - self.rewards_hist.append(reward) - self.actions_hist.append(action) - self.flag_hist.append([info["round_done"], info["stage_done"], - info["game_done"], info["ep_done"]]) - self.cumulative_rew += reward - - if done: - to_save = {} - n_actions = self.env.n_actions if self.env.player_side != "P1P2" else self.env.n_actions[0] - to_save["username"] = self.username - to_save["player_side"] = self.env.player_side - if self.env.player_side != "P1P2": - to_save["difficulty"] = self.env.difficulty - if isinstance(self.env.action_space, gym.spaces.Discrete): - to_save["action_space"] = "discrete" - else: - to_save["action_space"] = "multi_discrete" - to_save["attack_but_comb"] = self.env.attack_but_combination - else: - if isinstance(self.env.action_space["P1"], gym.spaces.Discrete): - to_save["action_space"] = "discrete" - else: - to_save["action_space"] = "multi_discrete" - to_save["attack_but_comb"] = self.env.attack_but_combination[0] - to_save["n_actions"] = n_actions - to_save["frame_shp"] = self.frame_shp - to_save["ignore_p2"] = self.ignore_p2 - to_save["char_names"] = self.env.char_names - - # Handle flattened and unflattened obs_dicts - if "P1" in self.env.observation_space.keys(): - tmp_elem = self.env.observation_space["P1"]["actions"]["move"] - else: - tmp_elem = self.env.observation_space["P1_actions_move"] - - if isinstance(tmp_elem, gym.spaces.MultiDiscrete): - to_save["n_actions_stack"] = int(tmp_elem.nvec.shape[0] / n_actions[0]) - else: - to_save["n_actions_stack"] = int(tmp_elem.n / n_actions[0]) - - to_save["ep_len"] = len(self.rewards_hist) - to_save["cum_rew"] = self.cumulative_rew - to_save["frames"] = self.last_frame_hist - to_save["ram_states"] = self.ram_states_hist - to_save["rewards"] = self.rewards_hist - to_save["actions"] = self.actions_hist - to_save["done_flags"] = self.flag_hist - to_save["observation_space_dict"] = gym_obs_dict_space_to_standard_dict(self.env.observation_space) - - # Characters name - # If 2P mode - if self.env.player_side == "P1P2" and self.ignore_p2 is False: - save_path = "mod_" + str(self.ignore_p2) + "_" +\ - self.env.player_side + "_rew" +\ - str(np.round(self.cumulative_rew, 3)) +\ - "_" + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - # If 1P mode - else: - save_path = "mod_" + str(self.ignore_p2) + "_" +\ - self.env.player_side + "_diff" +\ - str(self.env.difficulty) + "_rew" +\ - str(np.round(self.cumulative_rew, 3)) + "_" +\ - datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - - pickle_writer = ParallelPickleWriter( - os.path.join(self.file_path, save_path), to_save) - pickle_writer.start() - - return obs, reward, done, info diff --git a/diambra/arena/wrappers/traj_rec_wrapper_hardcore.py b/diambra/arena/wrappers/traj_rec_wrapper_hardcore.py deleted file mode 100644 index fd85281c..00000000 --- a/diambra/arena/wrappers/traj_rec_wrapper_hardcore.py +++ /dev/null @@ -1,129 +0,0 @@ -import os -import numpy as np 
-import datetime - -import gym -from ..utils.gym_utils import ParallelPickleWriter -from diambra.arena.env_settings import RecordingSettings - -# Trajectory recorder wrapper -class TrajectoryRecorder(gym.Wrapper): - def __init__(self, env, recording_settings: RecordingSettings): - """ - Record trajectories to use them for imitation learning - :param env: (Gym Environment) the environment to wrap - """ - gym.Wrapper.__init__(self, env) - self.file_path = recording_settings.file_path - self.username = recording_settings.username - self.ignore_p2 = recording_settings.ignore_p2 - self.frame_shp = self.env.observation_space.shape - - if (self.env.player_side == "P1P2"): - if ((self.env.attack_but_combination[0] != self.env.attack_but_combination[1]) - or (self.env.action_space["P1"] != self.env.action_space["P2"])): - raise Exception("Different attack buttons combinations and/or " - "different action spaces not supported for 2P " - "experience recordings" - "action space: {}, attack button combo: {}".format(self.env.action_space, self.env.attack_but_combination) - ) - - env.logger.info("Recording trajectories in \"{}\"".format(self.file_path)) - os.makedirs(self.file_path, exist_ok=True) - - def reset(self, **kwargs): - """ - Reset the environment and add requested info to the observation - :return: observation - """ - - # Items to store - self.last_frame_hist = [] - self.rewards_hist = [] - self.actions_hist = [] - self.flag_hist = [] - self.cumulative_rew = 0 - - obs = self.env.reset(**kwargs) - - for idx in range(self.frame_shp[2]): - self.last_frame_hist.append(obs[:, :, idx]) - - return obs - - def step(self, action): - """ - Step the environment with the given action - and add requested info to the observation - :param action: ([int] or [float]) the action - :return: new observation, reward, done, information - """ - - obs, reward, done, info = self.env.step(action) - - self.last_frame_hist.append(obs[:, :, self.frame_shp[2]-1]) - - # Add last obs nFrames - 1 times in case of - # new round / stage / continue_game - if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not done): - for _ in range(self.frame_shp[2]-1): - self.last_frame_hist.append(obs[:, :, self.frame_shp[2]-1]) - - self.rewards_hist.append(reward) - self.actions_hist.append(action) - self.flag_hist.append([info["round_done"], info["stage_done"], - info["game_done"], info["ep_done"]]) - self.cumulative_rew += reward - - if done: - to_save = {} - n_actions = self.env.n_actions if self.env.player_side != "P1P2" else self.env.n_actions[0] - to_save["username"] = self.username - to_save["player_side"] = self.env.player_side - if self.env.player_side != "P1P2": - to_save["difficulty"] = self.env.difficulty - if isinstance(self.env.action_space, gym.spaces.Discrete): - to_save["action_space"] = "discrete" - else: - to_save["action_space"] = "multi_discrete" - to_save["attack_but_comb"] = self.env.attack_but_combination - else: - if isinstance(self.env.action_space["P1"], gym.spaces.Discrete): - to_save["action_space"] = "discrete" - else: - to_save["action_space"] = "multi_discrete" - to_save["attack_but_comb"] = self.env.attack_but_combination[0] - to_save["n_actions"] = n_actions - to_save["frame_shp"] = self.frame_shp - to_save["ignore_p2"] = self.ignore_p2 - to_save["char_names"] = self.env.char_names - to_save["n_actions_stack"] = 0 - to_save["ep_len"] = len(self.rewards_hist) - to_save["cum_rew"] = self.cumulative_rew - to_save["frames"] = self.last_frame_hist - to_save["rewards"] = self.rewards_hist - 
to_save["actions"] = self.actions_hist - to_save["done_flags"] = self.flag_hist - to_save["obs_space_bounds"] = [self.env.observation_space.low[0][0][0], - self.env.observation_space.high[0][0][0]] - - # Characters name - # If 2P mode - if self.env.player_side == "P1P2" and self.ignore_p2 == 0: - save_path = "HC_mod" + str(self.ignore_p2) + "_" +\ - self.env.player_side + "_rew" +\ - str(np.round(self.cumulative_rew, 3)) + "_" +\ - datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - # If 1P mode - else: - save_path = "HC_mod" + str(self.ignore_p2) + "_" +\ - self.env.player_side + "_diff" +\ - str(self.env.difficulty) + "_rew" +\ - str(np.round(self.cumulative_rew, 3)) + "_" +\ - datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - - pickle_writer = ParallelPickleWriter( - os.path.join(self.file_path, save_path), to_save) - pickle_writer.start() - - return obs, reward, done, info diff --git a/examples/diambra_arena_gist.ipynb b/examples/diambra_arena_gist.ipynb deleted file mode 100644 index 48109aee..00000000 --- a/examples/diambra_arena_gist.ipynb +++ /dev/null @@ -1,105 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## DIAMBRA Arena Jupyter Notebook Example" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### DIAMBRA Arena module import" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import diambra.arena" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Environment creation and reset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "env = diambra.arena.make(\"doapp\")\n", - "observation = env.reset()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Agent-Environment interaction loop" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "while True:\n", - " env.render()\n", - "\n", - " actions = env.action_space.sample()\n", - "\n", - " observation, reward, done, info = env.step(actions)\n", - "\n", - " if done:\n", - " break" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Environment close" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "env.close()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.13" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/diambra_arena_gist.py b/examples/diambra_arena_gist.py index 63ea3fd3..5b497f9b 100755 --- a/examples/diambra_arena_gist.py +++ b/examples/diambra_arena_gist.py @@ -1,20 +1,34 @@ #!/usr/bin/env python3 import diambra.arena -if __name__ == '__main__': +def main(): + # Environment creation + env = diambra.arena.make("doapp", render_mode="human") - env = diambra.arena.make("doapp") - observation = env.reset() + # Environment reset + observation, info = env.reset(seed=42) + # Agent-Environment interaction loop while True: + # (Optional) Environment rendering env.render() + # Action random sampling actions = env.action_space.sample() - observation, reward, done, info = 
env.step(actions) + # Environment stepping + observation, reward, terminated, truncated, info = env.step(actions) - if done: - observation = env.reset() + # Episode end (Done condition) check + if terminated or truncated: + observation, info = env.reset() break + # Environment shutdown env.close() + + # Return success + return 0 + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/episode_data_loader.py b/examples/episode_data_loader.py new file mode 100644 index 00000000..f4c20a40 --- /dev/null +++ b/examples/episode_data_loader.py @@ -0,0 +1,38 @@ +from diambra.arena.utils.diambra_data_loader import DiambraDataLoader +import argparse +import os + +def main(dataset_path_input): + if dataset_path_input is not None: + dataset_path = dataset_path_input + else: + base_path = os.path.dirname(os.path.abspath(__file__)) + dataset_path = os.path.join(base_path, "dataset") + + data_loader = DiambraDataLoader(dataset_path) + + n_loops = data_loader.reset() + + while n_loops == 0: + observation, action, reward, terminated, truncated, info = data_loader.step() + print("Observation: {}".format(observation)) + print("Action: {}".format(action)) + print("Reward: {}".format(reward)) + print("Terminated: {}".format(terminated)) + print("Truncated: {}".format(truncated)) + print("Info: {}".format(info)) + data_loader.render() + + if terminated or truncated: + n_loops = data_loader.reset() + + # Return success + return 0 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--dataset_path', type=str, default=None, help='Path to dataset') + opt = parser.parse_args() + print(opt) + + main(opt.dataset_path) diff --git a/examples/episode_recording.py b/examples/episode_recording.py new file mode 100644 index 00000000..afacacf2 --- /dev/null +++ b/examples/episode_recording.py @@ -0,0 +1,56 @@ +import os +from os.path import expanduser +import diambra.arena +from diambra.arena import SpaceTypes, EnvironmentSettings, RecordingSettings +from diambra.arena.utils.controller import get_diambra_controller +import argparse + +def main(use_controller): + # Environment Settings + settings = EnvironmentSettings() + settings.step_ratio = 1 + settings.frame_shape = (256, 256, 1) + settings.action_space = SpaceTypes.MULTI_DISCRETE + + # Recording settings + home_dir = expanduser("~") + game_id = "doapp" + recording_settings = RecordingSettings() + recording_settings.dataset_path = os.path.join(home_dir, "DIAMBRA/episode_recording", game_id if use_controller else "mock") + recording_settings.username = "alexpalms" + + env = diambra.arena.make(game_id, settings, episode_recording_settings=recording_settings, render_mode="human") + + if use_controller is True: + # Controller initialization + controller = get_diambra_controller(env.get_actions_tuples()) + controller.start() + + observation, info = env.reset(seed=42) + + while True: + env.render() + if use_controller is True: + actions = controller.get_actions() + else: + actions = env.action_space.sample() + observation, reward, terminated, truncated, info = env.step(actions) + done = terminated or truncated + if done: + observation, info = env.reset() + break + + if use_controller is True: + controller.stop() + env.close() + + # Return success + return 0 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--use_controller', type=int, default=1, help='Flag to activate use of controller') + opt = parser.parse_args() + print(opt) + + main(bool(opt.use_controller)) diff 
--git a/examples/human_trajectory_recording_options_single_player.py b/examples/human_trajectory_recording_options_single_player.py deleted file mode 100644 index ca637703..00000000 --- a/examples/human_trajectory_recording_options_single_player.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -from os.path import expanduser -import diambra.arena -from diambra.arena.utils.controller import get_diambra_controller - -if __name__ == '__main__': - - # Environment Settings - settings = {} - settings["player"] = "Random" - settings["step_ratio"] = 1 - settings["frame_shape"] = (128, 128, 1) - settings["action_space"] = "multi_discrete" - settings["attack_but_combination"] = True - - # Gym wrappers settings - wrappers_settings = {} - wrappers_settings["reward_normalization"] = True - wrappers_settings["frame_stack"] = 4 - wrappers_settings["actions_stack"] = 12 - wrappers_settings["scale"] = True - - # Gym trajectory recording wrapper kwargs - traj_rec_settings = {} - home_dir = expanduser("~") - - # Username - traj_rec_settings["username"] = "Alex" - - # Path where to save recorderd trajectories - game_id = "doapp" - traj_rec_settings["file_path"] = os.path.join(home_dir, "diambraArena/trajRecordings", game_id) - - # If to ignore P2 trajectory (useful when collecting - # only human trajectories while playing as a human against a RL agent) - traj_rec_settings["ignore_p2"] = False - - env = diambra.arena.make(game_id, settings, wrappers_settings, traj_rec_settings) - - # Controller initialization - controller = get_diambra_controller(env.action_list) - controller.start() - - observation = env.reset() - - while True: - - env.render() - - actions = controller.get_actions() - - observation, reward, done, info = env.step(actions) - - if done: - observation = env.reset() - break - - controller.stop() - env.close() diff --git a/examples/imitation_learning.py b/examples/imitation_learning.py deleted file mode 100644 index 4acf8835..00000000 --- a/examples/imitation_learning.py +++ /dev/null @@ -1,54 +0,0 @@ -import diambra.arena -import os -import numpy as np - -if __name__ == '__main__': - - # Show files in folder - base_path = os.path.dirname(os.path.abspath(__file__)) - recorded_traj_folder = os.path.join(base_path, "recordedTrajectories") - recorded_traj_files = [os.path.join(recorded_traj_folder, f) - for f in os.listdir(recorded_traj_folder) - if os.path.isfile(os.path.join(recorded_traj_folder, f))] - print(recorded_traj_files) - - # Imitation learning settings - settings = {} - - # List of recorded trajectories files - settings["traj_files_list"] = recorded_traj_files - - # Number of parallel Imitation Learning environments that will be run - settings["total_cpus"] = 2 - - # Rank of the created environment - settings["rank"] = 0 - - env = diambra.arena.ImitationLearning(**settings) - - observation = env.reset() - env.render(mode="human") - env.show_obs(observation) - - # Show trajectory summary - env.traj_summary() - - while True: - - dummy_actions = 0 - observation, reward, done, info = env.step(dummy_actions) - env.render(mode="human") - env.show_obs(observation) - print("Reward: {}".format(reward)) - print("Done: {}".format(done)) - print("Info: {}".format(info)) - - if np.any(env.exhausted): - break - - if done: - observation = env.reset() - env.render(mode="human") - env.show_obs(observation) - - env.close() diff --git a/examples/multi_player_env.py b/examples/multi_player_env.py index 3d4aca5b..3e0b6553 100644 --- a/examples/multi_player_env.py +++ b/examples/multi_player_env.py @@ -1,48 +1,54 
@@ #!/usr/bin/env python3 import diambra.arena -import numpy as np - -if __name__ == '__main__': +from diambra.arena import SpaceTypes, EnvironmentSettingsMultiAgent +def main(): # Environment Settings - settings = {} + settings = EnvironmentSettingsMultiAgent() # Multi Agents environment - # 2 Players game - settings["player"] = "P1P2" + # --- Environment settings --- - # Characters to be used, automatically extended with "Random" for games - # required to select more than one character (e.g. Tekken Tag Tournament) - settings["characters"] = ("Random", "Random") + # If to use discrete or multi_discrete action space + settings.action_space = (SpaceTypes.DISCRETE, SpaceTypes.DISCRETE) - # Characters outfit - settings["char_outfits"] = (2, 2) + # --- Episode settings --- - # If to use discrete or multi_discrete action space - settings["action_space"] = ("discrete", "discrete") + # Characters to be used, automatically extended with None for games + # requiring to select more than one character (e.g. Tekken Tag Tournament) + settings.characters = ("Ryu", "Ken") - # If to use attack buttons combinations actions - settings["attack_but_combination"] = (True, True) + # Characters outfit + settings.outfits = (2, 2) - env = diambra.arena.make("doapp", settings) + env = diambra.arena.make("sfiii3n", settings, render_mode="human") - observation = env.reset() + observation, info = env.reset(seed=42) env.show_obs(observation) while True: - actions = env.action_space.sample() - actions = np.append(actions["P1"], actions["P2"]) print("Actions: {}".format(actions)) - observation, reward, done, info = env.step(actions) + observation, reward, terminated, truncated, info = env.step(actions) + done = terminated or truncated env.show_obs(observation) print("Reward: {}".format(reward)) print("Done: {}".format(done)) print("Info: {}".format(info)) if done: - observation = env.reset() + # Optionally, change episode settings here + options = {} + options["characters"] = (None, None) + options["char_outfits"] = (5, 5) + observation, info = env.reset(options=options) env.show_obs(observation) break env.close() + + # Return success + return 0 + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/recordedTrajectories/mod_False_P1P2_rew-2.0_2022-11-23_15-23-19 b/examples/recordedTrajectories/mod_False_P1P2_rew-2.0_2022-11-23_15-23-19 deleted file mode 100644 index fcd21eb5..00000000 Binary files a/examples/recordedTrajectories/mod_False_P1P2_rew-2.0_2022-11-23_15-23-19 and /dev/null differ diff --git a/examples/recordedTrajectories/mod_False_P1_diff4_rew-4.769_2022-11-23_15-25-19 b/examples/recordedTrajectories/mod_False_P1_diff4_rew-4.769_2022-11-23_15-25-19 deleted file mode 100644 index d6d841e3..00000000 Binary files a/examples/recordedTrajectories/mod_False_P1_diff4_rew-4.769_2022-11-23_15-25-19 and /dev/null differ diff --git a/examples/single_player_env.py b/examples/single_player_env.py index 8b8ccb5b..e0b3ec47 100644 --- a/examples/single_player_env.py +++ b/examples/single_player_env.py @@ -1,71 +1,83 @@ #!/usr/bin/env python3 import diambra.arena +from diambra.arena import SpaceTypes, Roles, EnvironmentSettings -if __name__ == '__main__': - +def main(): # Settings - settings = {} + settings = EnvironmentSettings() # Single agent environment - # Player side selection: P1 (left), P2 (right), Random (50% P1, 50% P2) - settings["player"] = "P2" + # --- Environment settings --- # Number of steps performed by the game # for every environment step, bounds: [1, 6] - 
settings["step_ratio"] = 6 + settings.step_ratio = 6 # Native frame resize operation - settings["frame_shape"] = (128, 128, 0) # RBG with 128x128 size - # settings["frame_shape"] = (0, 0, 1) # Grayscale with original size - # settings["frame_shape"] = (0, 0, 0) # Deactivated (Original size RBG) + settings.frame_shape = (128, 128, 0) # RBG with 128x128 size + # settings.frame_shape = (0, 0, 1) # Grayscale with original size + # settings.frame_shape = (0, 0, 0) # Deactivated (Original size RBG) + + # If to use discrete or multi_discrete action space + settings.action_space = SpaceTypes.MULTI_DISCRETE + + # --- Episode settings --- + + # Player role selection: P1 (left), P2 (right), None (50% P1, 50% P2) + settings.role = Roles.P1 # Game continue logic (0.0 by default): # - [0.0, 1.0]: probability of continuing game at game over # - int((-inf, -1.0]): number of continues at game over # before episode to be considered done - settings["continue_game"] = 0.0 + settings.continue_game = 0.0 # If to show game final when game is completed - settings["show_final"] = False - - # If to use hardcore mode in which observations are only made of game frame - settings["hardcore"] = False + settings.show_final = False # Game-specific options (see documentation for details) # Game difficulty level - settings["difficulty"] = 4 + settings.difficulty = 4 - # Character to be used, automatically extended with "Random" for games - # required to select more than one character (e.g. Tekken Tag Tournament) - settings["characters"] = "Random" + # Character to be used, automatically extended with None for games + # requiring to select more than one character (e.g. Tekken Tag Tournament) + settings.characters = "Kasumi" # Character outfit - settings["char_outfits"] = 2 + settings.outfits = 2 - # If to use discrete or multi_discrete action space - settings["action_space"] = "multi_discrete" - - # If to use attack buttons combinations actions - settings["attack_but_combination"] = True + env = diambra.arena.make("doapp", settings, render_mode="human") - env = diambra.arena.make("doapp", settings) - - observation = env.reset() + observation, info = env.reset(seed=42) env.show_obs(observation) while True: - actions = env.action_space.sample() print("Actions: {}".format(actions)) - observation, reward, done, info = env.step(actions) + observation, reward, terminated, truncated, info = env.step(actions) + done = terminated or truncated env.show_obs(observation) print("Reward: {}".format(reward)) print("Done: {}".format(done)) print("Info: {}".format(info)) if done: - observation = env.reset() + # Optionally, change episode settings here + options = {} + options["role"] = Roles.P2 + options["continue_game"] = 0.0 + options["difficulty"] = None + options["characters"] = None + options["outfits"] = 4 + + observation, info = env.reset(options=options) env.show_obs(observation) break env.close() + + # Return success + return 0 + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/wrappers_options.py b/examples/wrappers_options.py index d305129c..c055d4bb 100644 --- a/examples/wrappers_options.py +++ b/examples/wrappers_options.py @@ -1,79 +1,109 @@ import diambra.arena +from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings -if __name__ == '__main__': +def main(): + # Environment settings + settings = EnvironmentSettings() + settings.action_space = SpaceTypes.MULTI_DISCRETE # Gym wrappers settings - wrappers_settings = {} + wrappers_settings = WrappersSettings() + + ### 
Generic wrappers # Number of no-Op actions to be executed # at the beginning of the episode (0 by default) - wrappers_settings["no_op_max"] = 0 + wrappers_settings.no_op_max = 0 # Number of steps for which the same action should be sent (1 by default) - wrappers_settings["sticky_actions"] = 1 + wrappers_settings.repeat_action = 1 - # Frame resize operation spec (deactivated by default) - # WARNING: for speedup, avoid frame warping wrappers, - # use environment's native frame wrapping through - # "frame_shape" setting (see documentation for details). - wrappers_settings["hwc_obs_resize"] = (128, 128, 1) + ### Reward wrappers # Wrapper option for reward normalization # When activated, the reward normalization factor can be set (default = 0.5) # The normalization is performed by dividing the reward value # by the product of the factor times the value of the full health bar # reward = reward / (C * fullHealthBarValue) - wrappers_settings["reward_normalization"] = True - wrappers_settings["reward_normalization_factor"] = 0.5 + wrappers_settings.normalize_reward = True + wrappers_settings.normalization_factor = 0.5 # If to clip rewards (False by default) - wrappers_settings["clip_rewards"] = False + wrappers_settings.clip_reward = False + + ### Action space wrapper(s) + + # Limit the action space to single attack buttons, removing attack buttons combinations (False by default) + wrappers_settings.no_attack_buttons_combinations = False + + ### Observation space wrapper(s) + + # Frame resize operation spec (deactivated by default) + # WARNING: for speedup, avoid frame warping wrappers, + # use environment's native frame wrapping through + # "frame_shape" setting (see documentation for details). + wrappers_settings.frame_shape = (128, 128, 1) # Number of frames to be stacked together (1 by default) - wrappers_settings["frame_stack"] = 4 + wrappers_settings.stack_frames = 4 # Frames interval when stacking (1 by default) - wrappers_settings["dilation"] = 1 - - # How many past actions to stack together (1 by default) - wrappers_settings["actions_stack"] = 12 + wrappers_settings.dilation = 1 - # If to scale observation numerical values (deactivated by default) - # optionally exclude images from normalization (deactivated by default) - # and optionally perform one-hot encoding also on discrete binary variables (deactivated by default) - wrappers_settings["scale"] = True - wrappers_settings["exclude_image_scaling"] = True - wrappers_settings["process_discrete_binary"] = True + # Add last action to observation (False by default) + wrappers_settings.add_last_action = True - # Scaling interval (0 = [0.0, 1.0], 1 = [-1.0, 1.0]) - wrappers_settings["scale_mod"] = 0 + # How many past actions to stack together (1 by default) + # NOTE: needs "add_last_action_to_observation" wrapper to be active + wrappers_settings.stack_actions = 6 + + # If to scale observation numerical values (False by default) + # optionally exclude images from normalization (False by default) + # and optionally perform one-hot encoding also on discrete binary variables (False by default) + wrappers_settings.scale = True + wrappers_settings.exclude_image_scaling = True + wrappers_settings.process_discrete_binary = False + + # If to make the observation relative to the agent as a function of its role (P1 or P2) (deactivated by default) + # i.e.: + # - In 1P environments, if the agent is P1 then the observation "P1" nesting level becomes "own" and "P2" becomes "opp" + # - In 2P environments, if "agent_0" role is P1 and "agent_1" role is P2, then
the player specific nesting levels observation (P1 - P2) + # are grouped under "agent_0" and "agent_1", and: + # - Under "agent_0", "P1" nesting level becomes "own" and "P2" becomes "opp" + # - Under "agent_1", "P1" nesting level becomes "opp" and "P2" becomes "own" + wrappers_settings.role_relative = True # Flattening observation dictionary and filtering # a sub-set of the RAM states - wrappers_settings["flatten"] = True - wrappers_settings["filter_keys"] = ["stage", "P1_ownSide", "P1_oppSide", - "P1_ownHealth", "P1_oppChar", - "P1_actions_move", "P1_actions_attack"] + wrappers_settings.flatten = True + wrappers_settings.filter_keys = ["stage", "timer", "action", "own_side", "opp_side", + "own_health", "opp_health", "opp_character"] - env = diambra.arena.make("doapp", {}, wrappers_settings) + env = diambra.arena.make("doapp", settings, wrappers_settings, render_mode="human") - observation = env.reset() - env.show_obs(observation) + observation, info = env.reset(seed=42) + env.unwrapped.show_obs(observation) while True: - actions = env.action_space.sample() print("Actions: {}".format(actions)) - observation, reward, done, info = env.step(actions) - env.show_obs(observation) + observation, reward, terminated, truncated, info = env.step(actions) + done = terminated or truncated + env.unwrapped.show_obs(observation) print("Reward: {}".format(reward)) print("Done: {}".format(done)) print("Info: {}".format(info)) if done: - observation = env.reset() - env.show_obs(observation) + observation, info = env.reset() + env.unwrapped.show_obs(observation) break env.close() + + # Return success + return 0 + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/images/stable-baselines/Dockerfile b/images/stable-baselines/Dockerfile index 47aff8d4..67087026 100644 --- a/images/stable-baselines/Dockerfile +++ b/images/stable-baselines/Dockerfile @@ -10,8 +10,11 @@ ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 # Copy arena to tmp since bind-mount is read-only and pip doesn't support # out-of-tree builds. +# Need to pin the specific version of tensorflow 1.15.0 due to +# numpy incompatibilities with gymnasium of different versions +# see https://github.com/diambra/arena/issues/89 RUN --mount=target=/usr/src/arena,type=bind,source=. 
\ cp -r /usr/src/arena /tmp/arena && \ pip install /tmp/arena[stable-baselines] && \ - pip install tensorflow==1.15.5 && \ + pip install tensorflow==1.15.0 && \ rm -rf /tmp/arena diff --git a/setup.py b/setup.py index 460df9a8..de633c83 100644 --- a/setup.py +++ b/setup.py @@ -12,9 +12,9 @@ extras= { 'core': [], 'tests': ['pytest', 'pytest-mock', 'testresources'], - 'stable-baselines': ['stable-baselines==2.10.2', "protobuf==3.20.1", "pyyaml"], - 'stable-baselines3': ['stable-baselines3[extra]==1.6.1', "pyyaml"], - 'ray-rllib': ['ray[rllib]==2.0.0', 'tensorflow<=2.10.0', 'torch<=1.12.1', "pyyaml"], + 'stable-baselines': ['stable-baselines~=2.10.2', 'gym<=0.21.0', "protobuf==3.20.1", "pyyaml"], + 'stable-baselines3': ['stable-baselines3[extra]~=2.1.0', "pyyaml"], + 'ray-rllib': ['ray[rllib]~=2.7.0', 'tensorflow', 'torch', "pyyaml"], } # NOTE Package data is inside MANIFEST.In @@ -32,18 +32,17 @@ install_requires=[ 'pip>=21', 'importlib-metadata<=4.12.0; python_version <= "3.7"', # problem with gym for importlib-metadata==5.0.0 and python <=3.7 - 'setuptools<=66.0.0', # Required until we can upgrade to gym >= 0.22.0 + 'setuptools', 'distro>=1', - 'gym<=0.21.0', + 'gymnasium>=0.26.3', 'inputs', 'screeninfo', 'tk', 'opencv-python>=4.4.0.42', 'grpcio', - 'diambra-engine>=2.1.0rc7,<2.2.0', + 'diambra-engine~=2.2.0', 'dacite'], - packages=[package for package in setuptools.find_packages( - ) if package.startswith("diambra")], + packages=[package for package in setuptools.find_packages() if package.startswith("diambra")], include_package_data=True, extras_require=extras, classifiers=[ diff --git a/tests/env_exec_interface.py b/tests/env_exec_interface.py index 0876d27a..e083a02d 100755 --- a/tests/env_exec_interface.py +++ b/tests/env_exec_interface.py @@ -1,166 +1,141 @@ #!/usr/bin/env python3 import diambra.arena +from diambra.arena import SpaceTypes from diambra.arena.utils.gym_utils import env_spaces_summary, discrete_to_multi_discrete_action -import time +import random import numpy as np import warnings default_args = { - "interactive_viz": False, + "interactive": False, "n_episodes": 1, - "no_action": False + "no_action_probability": 0.0, + "render": False, + "log_output": False, } -def env_exec(settings, wrappers_settings, traj_rec_settings, args=default_args): - +def env_exec(settings, options_list, wrappers_settings, episode_recording_settings, args=default_args): try: - time_dep_seed = int((time.time() - int(time.time() - 0.5)) * 1000) - wait_key = 1 - if args["interactive_viz"] is True: + if args["interactive"] is True: wait_key = 0 - no_action = args["no_action"] - n_rounds = 2 - if settings["game_id"] == "kof98umh": + if settings.game_id == "kof98umh": n_rounds = 3 - env = diambra.arena.make(settings["game_id"], settings, wrappers_settings, traj_rec_settings, seed=time_dep_seed) + env = diambra.arena.make(settings.game_id, settings, wrappers_settings, episode_recording_settings) # Print environment obs and action spaces summary - env_spaces_summary(env) + if args["log_output"] is True: + env_spaces_summary(env) - observation = env.reset() + for options in options_list: + observation, info = env.reset(options=options) + if args["log_output"] is True: + env.show_obs(observation, wait_key, args["render"]) - cumulative_ep_rew = 0.0 - cumulative_ep_rew_all = [] + cumulative_ep_rew = 0.0 + cumulative_ep_rew_all = [] - max_num_ep = args["n_episodes"] - curr_num_ep = 0 + max_num_ep = args["n_episodes"] + curr_num_ep = 0 - while curr_num_ep < max_num_ep: + no_action = random.choices([True, 
False], [args["no_action_probability"], 1.0 - args["no_action_probability"]])[0] - actions = [None, None] - if settings["player"] != "P1P2": + while curr_num_ep < max_num_ep: actions = env.action_space.sample() - if no_action is True: - if settings["action_space"] == "multi_discrete": - for iel, _ in enumerate(actions): - actions[iel] = 0 - else: - actions = 0 - - if settings["action_space"] == "discrete": - move_action, att_action = discrete_to_multi_discrete_action( - actions, env.n_actions[0]) - else: - move_action, att_action = actions[0], actions[1] - - print("(P1) {} {}".format(env.print_actions_dict[0][move_action], - env.print_actions_dict[1][att_action])) - - else: - for idx in range(2): - actions[idx] = env.action_space["P{}".format(idx + 1)].sample() - - if no_action is True and idx == 0: - if settings["action_space"][idx] == "multi_discrete": - for iel, _ in enumerate(actions[idx]): - actions[idx][iel] = 0 - else: - actions[idx] = 0 - - if settings["action_space"][idx] == "discrete": - move_action, att_action = discrete_to_multi_discrete_action( - actions[idx], env.n_actions[idx][0]) - else: - move_action, att_action = actions[idx][0], actions[idx][1] - - print("(P{}) {} {}".format(idx + 1, env.print_actions_dict[0][move_action], - env.print_actions_dict[1][att_action])) - - if (settings["player"] == "P1P2" or settings["action_space"] != "discrete"): - actions = np.append(actions[0], actions[1]) + if settings.n_players == 1: + if no_action is True: + actions = env.get_no_op_action() - observation, reward, done, info = env.step(actions) - - cumulative_ep_rew += reward - print("action =", actions) - print("reward =", reward) - print("done =", done) - for k, v in info.items(): - print("info[\"{}\"] = {}".format(k, v)) - env.show_obs(observation, wait_key) - print("--") - print("Current Cumulative Reward =", cumulative_ep_rew) - - print("----------") - - if done: - print("Resetting Env") - curr_num_ep += 1 - print("Ep. # = ", curr_num_ep) - print("Ep. 
Cumulative Rew # = ", cumulative_ep_rew) - cumulative_ep_rew_all.append(cumulative_ep_rew) - cumulative_ep_rew = 0.0 - - observation = env.reset() - env.show_obs(observation, wait_key) - - if np.any([info["round_done"], info["stage_done"], info["game_done"], info["ep_done"]]): - - if settings["hardcore"] is False: - # Side check - if "P1_ownSide" in observation.keys(): - ram_state_values = [observation["P1_ownSide"], observation["P1_oppSide"]] + if settings.action_space == SpaceTypes.DISCRETE: + move_action, att_action = discrete_to_multi_discrete_action(actions, env.n_actions[0]) else: - ram_state_values = [observation["P1"]["ownSide"], observation["P1"]["oppSide"]] + move_action, att_action = actions[0], actions[1] - if env.player_side == "P2": - if (ram_state_values[0] != 1.0 or ram_state_values[1] != 0.0): - raise RuntimeError("Wrong starting sides:", ram_state_values[0], ram_state_values[1]) - else: - if (ram_state_values[0] != 0.0 or ram_state_values[1] != 1.0): - raise RuntimeError("Wrong starting sides:", ram_state_values[0], ram_state_values[1]) + if args["log_output"] is True: + print("(agent_0) {} {}".format(env.print_actions_dict[0][move_action], env.print_actions_dict[1][att_action])) - frame = observation["frame"] else: - frame = observation - - # Frames equality check - if ("hwc_obs_resize" in wrappers_settings.keys() and wrappers_settings["hwc_obs_resize"][2] == 1): - for frame_idx in range(frame.shape[2] - 1): - if np.any(frame[:, :, frame_idx] != frame[:, :, frame_idx + 1]): - raise RuntimeError("Frames inside observation after " - "round/stage/game/episode done are " - "not equal. Dones =", info["round_done"], info["stage_done"], - info["game_done"], info["ep_done"]) + if no_action is True: + actions["agent_0"] = env.get_no_op_action()["agent_0"] - print("Cumulative reward = ", cumulative_ep_rew_all) - print("Mean cumulative reward = ", np.mean(cumulative_ep_rew_all)) - print("Std cumulative reward = ", np.std(cumulative_ep_rew_all)) - - env.close() - - if len(cumulative_ep_rew_all) != max_num_ep: - raise RuntimeError("Not run all episodes") - - if settings["continue_game"] <= 0.0 and settings["player"] != "P1P2": - max_continue = int(-settings["continue_game"]) - else: - max_continue = 0 + for idx in range(settings.n_players): + if settings.action_space[idx] == SpaceTypes.DISCRETE: + move_action, att_action = discrete_to_multi_discrete_action(actions["agent_{}".format(idx)], env.n_actions[0]) + else: + move_action, att_action = actions["agent_{}".format(idx)][0], actions["agent_{}".format(idx)][1] + + if args["log_output"] is True: + print("(agent_{}) {} {}".format(idx, env.print_actions_dict[0][move_action], env.print_actions_dict[1][att_action])) + + observation, reward, terminated, truncated, info = env.step(actions) + + cumulative_ep_rew += reward + if args["log_output"] is True: + print("action =", actions) + print("reward =", reward) + print("done =", terminated or truncated) + for k, v in info.items(): + print("info[\"{}\"] = {}".format(k, v)) + env.show_obs(observation, wait_key, args["render"]) + print("--") + print("Current Cumulative Reward =", cumulative_ep_rew) + + print("----------") + + if terminated or truncated: + observation, info = env.reset() + if args["log_output"] is True: + env.show_obs(observation, wait_key, args["render"]) + print("Ep. # = ", curr_num_ep) + print("Ep. 
Cumulative Rew # = ", cumulative_ep_rew) + curr_num_ep += 1 + no_action = random.choices([True, False], [args["no_action_probability"], 1.0 - args["no_action_probability"]])[0] + cumulative_ep_rew_all.append(cumulative_ep_rew) + cumulative_ep_rew = 0.0 + + if info["round_done"]: + # Side check when no wrappers active: + if (wrappers_settings.role_relative is False and wrappers_settings.flatten is False): + if (observation["P1"]["side"] != 0.0 or observation["P2"]["side"] != 1.0): + raise RuntimeError("Wrong starting sides:", observation["P1"]["side"], observation["P2"]["side"]) + + elif (wrappers_settings.frame_shape is not None and wrappers_settings.frame_shape[2] == 1): + # Frames equality check + frame = observation["frame"] + + for frame_idx in range(frame.shape[2] - 1): + if np.any(frame[:, :, frame_idx] != frame[:, :, frame_idx + 1]): + raise RuntimeError("Frames inside observation after round/stage/game/episode done are " + "not equal. Dones =", info["round_done"], info["stage_done"], + info["game_done"], info["episode_done"]) + + if args["log_output"] is True: + print("Cumulative reward = ", cumulative_ep_rew_all) + print("Mean cumulative reward = ", np.mean(cumulative_ep_rew_all)) + print("Std cumulative reward = ", np.std(cumulative_ep_rew_all)) + + if len(cumulative_ep_rew_all) != max_num_ep: + raise RuntimeError("Not run all episodes") + + if env.env_settings.continue_game <= 0.0 and env.env_settings.n_players == 1: + max_continue = int(-env.env_settings.continue_game) + else: + max_continue = 0 - if settings["game_id"] == "tektagt": - max_continue = (max_continue + 1) * 0.7 - 1 + if env.env_settings.game_id == "tektagt": + max_continue = (max_continue + 1) * 0.7 - 1 - round_max_reward = env.max_delta_health / env.reward_normalization_value - if (no_action is True and (np.mean(cumulative_ep_rew_all) > -(max_continue + 1) * round_max_reward * n_rounds + 0.001)): + round_max_reward = env.max_delta_health / env.reward_normalization_value + if (no_action is True and (np.mean(cumulative_ep_rew_all) > -(max_continue + 1) * round_max_reward * n_rounds + 0.001)): + message = "NoAction policy and average reward different than {} ({})".format( + -(max_continue + 1) * round_max_reward * n_rounds, np.mean(cumulative_ep_rew_all)) + raise RuntimeError(message) - message = "NoAction policy and average reward different than {} ({})".format( - -(max_continue + 1) * round_max_reward * n_rounds, np.mean(cumulative_ep_rew_all)) - warnings.warn(UserWarning(message)) + env.close() print("COMPLETED SUCCESSFULLY!") return 0 diff --git a/tests/man_test_random.py b/tests/man_test_random.py index 8aa2a918..297f1827 100644 --- a/tests/man_test_random.py +++ b/tests/man_test_random.py @@ -1,96 +1,121 @@ #!/usr/bin/env python3 import argparse from env_exec_interface import env_exec -import time import os from os.path import expanduser +import random +from diambra.arena import SpaceTypes, Roles, EnvironmentSettings, EnvironmentSettingsMultiAgent, WrappersSettings, RecordingSettings if __name__ == "__main__": - parser = argparse.ArgumentParser() parser.add_argument("--gameId", type=str, default="doapp", help="Game ID [(doapp), sfiii3n, tektagt, umk3]") - parser.add_argument("--player", type=str, default="Random", help="Player (Random)") - parser.add_argument("--character1", type=str, default="Random", help="Character P1 (Random)") - parser.add_argument("--character2", type=str, default="Random", help="Character P2 (Random)") - parser.add_argument("--character1_2", type=str, default="Random", help="Character 
P1_2 (Random)") - parser.add_argument("--character2_2", type=str, default="Random", help="Character P2_2 (Random)") - parser.add_argument("--character1_3", type=str, default="Random", help="Character P1_3 (Random)") - parser.add_argument("--character2_3", type=str, default="Random", help="Character P2_3 (Random)") - parser.add_argument("--difficulty", type=int, default=3, help="Game difficulty") + parser.add_argument("--nPlayers", type=int, default=1, help="Number of Agents (1)") + parser.add_argument("--role0", type=str, default="Random", help="agent_0 role (Random)") + parser.add_argument("--role1", type=str, default="Random", help="agent_1 role (Random)") + parser.add_argument("--character0", type=str, default="Random", help="Character agent_0 (Random)") + parser.add_argument("--character1", type=str, default="Random", help="Character agent (Random)") + parser.add_argument("--character0_2", type=str, default="Random", help="Character P1_2 (Random)") + parser.add_argument("--character1_2", type=str, default="Random", help="Character P2_2 (Random)") + parser.add_argument("--character0_3", type=str, default="Random", help="Character P1_3 (Random)") + parser.add_argument("--character1_3", type=str, default="Random", help="Character P2_3 (Random)") + parser.add_argument("--difficulty", type=int, default=0, help="Game difficulty (0)") parser.add_argument("--stepRatio", type=int, default=3, help="Frame ratio") parser.add_argument("--nEpisodes", type=int, default=1, help="Number of episodes") parser.add_argument("--continueGame", type=float, default=-1.0, help="ContinueGame flag (-inf,+1.0]") parser.add_argument("--actionSpace", type=str, default="discrete", help="discrete/multi_discrete") - parser.add_argument("--attButComb", type=bool, default=False, help="Use attack button combinations (0=F)/1=T") parser.add_argument("--noAction", type=int, default=0, help="If to use no action policy (0=False)") - parser.add_argument("--recordTraj", type=bool, default=False, help="If to record trajectories") - parser.add_argument("--hardcore", type=bool, default=False, help="Hard core mode") - parser.add_argument("--interactiveViz", type=int, default=0, help="Interactive Visualization (0=False)") + parser.add_argument("--recordEpisode", type=int, default=0, help="If to record episode") + parser.add_argument("--interactive", type=int, default=0, help="Interactive Visualization (False)") + parser.add_argument("--render", type=int, default=1, help="Render frame (False)") + parser.add_argument("--wrappers", type=int, default=0, help="If to use wrappers") parser.add_argument("--envAddress", type=str, default="", help="diambraEngine Address") - parser.add_argument("--wrappers", type=bool, default=False, help="If to use wrappers") opt = parser.parse_args() print(opt) - time_dep_seed = int((time.time() - int(time.time() - 0.5)) * 1000) - # Settings - settings = {} - settings["game_id"] = opt.gameId + if (opt.nPlayers == 1): + settings = EnvironmentSettings() + else: + settings = EnvironmentSettingsMultiAgent() + settings.game_id = opt.gameId + settings.frame_shape = random.choice([(128, 128, 1), (256, 256, 0)]) if opt.envAddress != "": - settings["env_address"] = opt.envAddress - settings["player"] = opt.player - settings["difficulty"] = opt.difficulty - settings["characters"] = ((opt.character1, opt.character1_2, opt.character1_3), - (opt.character2, opt.character2_2, opt.character2_3)) - settings["step_ratio"] = opt.stepRatio - settings["continue_game"] = opt.continueGame - settings["action_space"] = 
(opt.actionSpace, opt.actionSpace) - settings["attack_but_combination"] = (opt.attButComb, opt.attButComb) - if settings["player"] != "P1P2": - settings["characters"] = settings["characters"][0] - settings["action_space"] = settings["action_space"][0] - settings["attack_but_combination"] = settings["attack_but_combination"][0] - settings["hardcore"] = opt.hardcore + settings.env_address = opt.envAddress + settings.role = (Roles.P1 if opt.role0 == "P1" else Roles.P2 if opt.role0 == "P2" else None, + Roles.P1 if opt.role1 == "P1" else Roles.P2 if opt.role1 == "P2" else None) + settings.difficulty = opt.difficulty if opt.difficulty != 0 else None + settings.characters = ((opt.character0, opt.character0_2, opt.character0_3), + (opt.character1, opt.character1_2, opt.character1_3)) + settings.characters = tuple([None if "Random" in settings["characters"][idx] else settings["characters"] for idx in range(2)]) + settings.step_ratio = opt.stepRatio + settings.continue_game = opt.continueGame + settings.action_space = (SpaceTypes.DISCRETE, SpaceTypes.DISCRETE) if opt.actionSpace == "discrete" else \ + (SpaceTypes.MULTI_DISCRETE, SpaceTypes.MULTI_DISCRETE) + if settings.n_players == 1: + settings.role = settings.role[0] + settings.characters = settings.characters[0] + settings.action_space = settings.action_space[0] # Env wrappers settings - wrappers_settings = {} + wrappers_settings = WrappersSettings() wrappers_settings["no_op_max"] = 0 wrappers_settings["sticky_actions"] = 1 - wrappers_settings["hwc_obs_resize"] = (128, 128, 1) + wrappers_settings["frame_shape"] = random.choice([(128, 128, 1), (256, 256, 0)]) wrappers_settings["reward_normalization"] = True wrappers_settings["clip_rewards"] = False wrappers_settings["frame_stack"] = 4 wrappers_settings["dilation"] = 1 wrappers_settings["actions_stack"] = 12 wrappers_settings["scale"] = True - wrappers_settings["scale_mod"] = 0 wrappers_settings["flatten"] = True + suffix = "" + if opt.nPlayers == 2: + suffix = "agent_0_" if opt.gameId != "tektagt": - wrappers_settings["filter_keys"] = ["stage", "P1_ownSide", "P1_oppSide", "P1_oppSide", - "P1_ownHealth", "P1_oppHealth", "P1_oppChar", - "P1_actions_move", "P1_actions_attack"] + wrappers_settings["filter_keys"] = ["stage", "timer", suffix+"own_side", suffix+"opp_side", + suffix+"own_health", suffix+"opp_health", + suffix+"action_move", suffix+"action_attack"] else: - wrappers_settings["filter_keys"] = ["stage", "P1_ownSide", "P1_oppSide", "P1_oppSide", - "P1_ownHealth1", "P1_oppHealth1", "P1_oppChar", - "P1_ownHealth2", "P1_oppHealth2", - "P1_actions_move", "P1_actions_attack"] - if opt.wrappers is False: - wrappers_settings = {} + wrappers_settings["filter_keys"] = ["stage", "timer", suffix+"own_side", suffix+"opp_side", + suffix+"own_health_1", suffix+"opp_health_1", + suffix+"own_health_2", suffix+"opp_health_2", + suffix+"action_move", suffix+"action_attack"] + + + # Env wrappers settings + wrappers_settings = WrappersSettings() + if bool(opt.wrappers) is True: + wrappers_settings.no_op_max = 0 + wrappers_settings.sticky_actions = 1 + wrappers_settings.frame_shape = random.choice([(128, 128, 1), (256, 256, 0)]) + wrappers_settings.reward_normalization = True + wrappers_settings.clip_rewards = False + wrappers_settings.frame_stack = 4 + wrappers_settings.dilation = 1 + wrappers_settings.add_last_action_to_observation = True + wrappers_settings.actions_stack = 12 + wrappers_settings.scale = True + wrappers_settings.role_relative_observation = True + wrappers_settings.flatten = True + suffix = "" + 
if settings.n_players == 2: + suffix = "agent_0_" + wrappers_settings.filter_keys = ["stage", "timer", suffix + "own_side", suffix + "opp_side", + suffix + "opp_character", suffix + "action"] # Recording settings - traj_rec_settings = {} - traj_rec_settings["user_name"] = "Alex" - traj_rec_settings["file_path"] = os.path.join(expanduser("~"), "DIAMBRA/trajRecordings", opt.gameId) - traj_rec_settings["ignore_p2"] = False - if opt.recordTraj is False: - traj_rec_settings = {} - else: - wrappers_settings["flatten"] = False + episode_recording_settings = RecordingSettings() + if bool(opt.recordEpisode) is True: + home_dir = expanduser("~") + episode_recording_settings["username"] = "alexpalms" + episode_recording_settings["dataset_path"] = os.path.join(home_dir, "DIAMBRA/episode_recording", opt.gameId) # Args args = {} - args["interactive_viz"] = bool(opt.interactiveViz) - args["no_action"] = True if opt.noAction == 1 else False + args["interactive"] = bool(opt.interactive) + args["no_action_probability"] = 1.0 if opt.noAction == 1 else 0.0 args["n_episodes"] = opt.nEpisodes + args["render"] = bool(opt.render) + args["log_output"] = True - env_exec(settings, wrappers_settings, traj_rec_settings, args) + env_exec(settings, [{}], wrappers_settings, episode_recording_settings, args) diff --git a/tests/pytest_utils.py b/tests/pytest_utils.py index 753f4b60..8067c282 100644 --- a/tests/pytest_utils.py +++ b/tests/pytest_utils.py @@ -1,24 +1,43 @@ #!/usr/bin/env python3 -def generate_pytest_decorator_input(var_order, test_parameters, outcome): - +def generate_pytest_decorator_input(var_order, ok_test_parameters, ko_test_parameters): test_vars = "" values_list = [] - number_of_tests = 0 - for k, v in test_parameters.items(): - number_of_tests = max(number_of_tests, len(v)) + # OK tests + number_of_ok_tests = 0 + for k, v in ok_test_parameters.items(): + number_of_ok_tests = max(number_of_ok_tests, len(v)) for var in var_order: test_vars += var + "," test_vars += "expected" - for idx in range(number_of_tests): + for idx in range(number_of_ok_tests): + test_value_tuple = tuple() + + for var in var_order: + test_value_tuple += (ok_test_parameters[var][idx % len(ok_test_parameters[var])],) + test_value_tuple += (0,) + + values_list.append(test_value_tuple) + + # KO tests + test_parameters_ko_list = [] + for k, v in ko_test_parameters.items(): + for value in v: + test_parameters_ko_list.append([k, value]) + number_of_ko_tests = len(test_parameters_ko_list) + for idx in range(number_of_ko_tests): test_value_tuple = tuple() for var in var_order: - test_value_tuple += (test_parameters[var][idx % len(test_parameters[var])],) - test_value_tuple += (outcome,) + if var == test_parameters_ko_list[idx][0]: + test_value_tuple += (test_parameters_ko_list[idx][1],) + else: + test_value_tuple += (ok_test_parameters[var][idx % len(ok_test_parameters[var])],) + + test_value_tuple += (1,) values_list.append(test_value_tuple) diff --git a/tests/run_engine_mock.py b/tests/run_engine_mock.py new file mode 100644 index 00000000..6a860521 --- /dev/null +++ b/tests/run_engine_mock.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +import argparse +from diambra.arena import Roles +from diambra.arena.env_settings import EnvironmentSettings, EnvironmentSettingsMultiAgent +from diambra.arena.utils.engine_mock import DiambraEngineMock +import random + +def print_response(response): + print("---") + print("Obs =", {key: response.observation.ram_states[key].val for key in sorted(response.observation.ram_states.keys())}) + print("Reward =", 
response.reward) + print("Info =", {key: response.info.game_states[key] for key in sorted(response.info.game_states.keys())}) + print("---") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--gameId", type=str, default="doapp", help="Game ID [(doapp), sfiii3n, tektagt, umk3]") + parser.add_argument("--nPlayers", type=int, default=1, help="Number of Agents (1)") + parser.add_argument("--role0", type=str, default="Random", help="agent_0 role (Random)") + parser.add_argument("--role1", type=str, default="Random", help="agent_1 role (Random)") + parser.add_argument("--character0", type=str, default="Random", help="Character agent_0 (Random)") + parser.add_argument("--character1", type=str, default="Random", help="Character agent_1 (Random)") + parser.add_argument("--character0_2", type=str, default="Random", help="Character P1_2 (Random)") + parser.add_argument("--character1_2", type=str, default="Random", help="Character P2_2 (Random)") + parser.add_argument("--character0_3", type=str, default="Random", help="Character P1_3 (Random)") + parser.add_argument("--character1_3", type=str, default="Random", help="Character P2_3 (Random)") + parser.add_argument("--difficulty", type=int, default=0, help="Game difficulty (0)") + parser.add_argument("--stepRatio", type=int, default=3, help="Frame ratio") + parser.add_argument("--continueGame", type=float, default=-1.0, help="ContinueGame flag (-inf,+1.0]") + parser.add_argument("--noAction", type=int, default=0, help="If to use no action policy (0=False)") + parser.add_argument("--interactive", type=int, default=0, help="Interactive Visualization (False)") + parser.add_argument("--render", type=int, default=1, help="Render frame (False)") + opt = parser.parse_args() + print(opt) + + # Settings + if opt.nPlayers == 1: + settings = EnvironmentSettings() + else: + settings = EnvironmentSettingsMultiAgent() + + settings.game_id = opt.gameId + settings.role = (Roles.P1 if opt.role0 == "P1" else Roles.P2 if opt.role0 == "P2" else None, + Roles.P1 if opt.role1 == "P1" else Roles.P2 if opt.role1 == "P2" else None) + settings.difficulty = opt.difficulty + settings.characters = ((opt.character0, opt.character0_2, opt.character0_3), + (opt.character1, opt.character1_2, opt.character1_3)) + settings.step_ratio = opt.stepRatio + settings.continue_game = opt.continueGame + if settings.n_players == 1: + settings.role = settings.role[0] + settings.characters = settings.characters[0] + + settings.sanity_check() + + engine_mock = DiambraEngineMock() + env_info = engine_mock.mock_env_init(settings.get_pb_request()) + print("Env info =", env_info) + + reset_response = engine_mock.mock_reset() + print_response(reset_response) + + action = [[0, 0],[0, 0]] + cumulative_reward = 0.0 + while True: + for idx in range(settings.n_players): + action[idx] = [random.randint(0, 8), random.randint(0, 4)] + if bool(opt.noAction) is True: + action[idx] = [0, 0] + print("Action =", action) + + step_response = engine_mock.mock_step(action) + cumulative_reward += step_response.reward + print_response(step_response) + print("Cumulative reward =", cumulative_reward) + + done = step_response.info.game_states["episode_done"] if settings.n_players == 1 else step_response.info.game_states["game_done"] + if done: + print("Total cumulative reward =", cumulative_reward) + reset_response = engine_mock.mock_reset() + print_response(reset_response) + break diff --git a/tests/test_episode_data_loader.py new file
mode 100644 index 00000000..f5cd7b86 --- /dev/null +++ b/tests/test_episode_data_loader.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +from diambra.arena.utils.diambra_data_loader import DiambraDataLoader +import os +from os.path import expanduser + +def func(): + try: + home_dir = expanduser("~") + dataset_path = os.path.join(home_dir, "DIAMBRA/episode_recording/mock") + + data_loader = DiambraDataLoader(dataset_path) + + n_loops = data_loader.reset() + + while n_loops == 0: + observation, action, reward, terminated, truncated, info = data_loader.step() + print("Observation: {}".format(observation)) + print("Action: {}".format(action)) + print("Reward: {}".format(reward)) + print("Terminated: {}".format(terminated)) + print("Truncated: {}".format(truncated)) + print("Info: {}".format(info)) + data_loader.render() + + if terminated or truncated: + n_loops = data_loader.reset() + + return 0 + except Exception as e: + print(e) + return 1 + +def test_episode_data_loader(): + assert func() == 0 + diff --git a/tests/test_examples.py b/tests/test_examples.py new file mode 100644 index 00000000..fdc11127 --- /dev/null +++ b/tests/test_examples.py @@ -0,0 +1,33 @@ +import pytest +import sys +from os.path import expanduser +import os +from diambra.arena.utils.engine_mock import load_mocker + +# Add the scripts directory to sys.path +root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "examples")) +sys.path.append(root_dir) + +import diambra_arena_gist, episode_data_loader, episode_recording, multi_player_env, single_player_env, wrappers_options + +def func(script, mocker, *args): + + load_mocker(mocker) + + try: + return script.main(*args) + except Exception as e: + print(e) + return 1 + +home_dir = expanduser("~") +dataset_path = os.path.join(home_dir, "DIAMBRA/episode_recording/mock") +use_controller = False +#[episode_data_loader, (dataset_path,)] # Removing episode data loader from tests because of unavailability of trajectories +scripts = [[diambra_arena_gist, ()], [single_player_env, ()], [multi_player_env, ()], [wrappers_options, ()], + [episode_recording, (use_controller,)]] + +@pytest.mark.parametrize("script", scripts) +def test_example_scripts(script, mocker): + + assert func(script[0], mocker, *script[1]) == 0 \ No newline at end of file diff --git a/tests/test_gym_settings.py b/tests/test_gym_settings.py index 66c3258e..ccea9ab6 100644 --- a/tests/test_gym_settings.py +++ b/tests/test_gym_settings.py @@ -2,9 +2,10 @@ import pytest import random import diambra.arena -from diambra.arena.utils.engine_mock import DiambraEngineMock +from diambra.arena import SpaceTypes, Roles, EnvironmentSettings, EnvironmentSettingsMultiAgent from diambra.arena.utils.gym_utils import available_games from pytest_utils import generate_pytest_decorator_input +from diambra.arena.utils.engine_mock import load_mocker # Example Usage: # pytest @@ -13,12 +14,12 @@ # -s (show output) # -k "expression" (filter tests using case-insensitive with parts of the test name and/or parameters values combined with boolean operators, e.g. 
"wrappers and doapp") -def env_exec(settings, wrappers_settings, traj_rec_settings): +def func(settings, mocker): + load_mocker(mocker) try: - env = diambra.arena.make(settings["game_id"], settings, wrappers_settings, traj_rec_settings) - + env = diambra.arena.make(settings.game_id, settings) + env.reset() env.close() - print("COMPLETED SUCCESSFULLY!") return 0 except Exception as e: @@ -26,100 +27,78 @@ def env_exec(settings, wrappers_settings, traj_rec_settings): print("ERROR, ABORTED.") return 1 -def func(settings, wrappers_settings, traj_rec_settings, mocker): - - diambra_engine_mock = DiambraEngineMock() - - mocker.patch("diambra.arena.engine.interface.DiambraEngine.__init__", diambra_engine_mock._mock__init__) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._env_init", diambra_engine_mock._mock_env_init) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._reset", diambra_engine_mock._mock_reset) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_1p", diambra_engine_mock._mock_step_1p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_2p", diambra_engine_mock._mock_step_2p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine.close", diambra_engine_mock._mock_close) - - try: - return env_exec(settings, wrappers_settings, traj_rec_settings) - except Exception as e: - print(e) - return 1 - games_dict = available_games(False) -gym_settings_var_order = ["player", "step_ratio", "frame_shape", "tower", "super_art", - "fighting_style", "ultimate_style", "continue_game", "action_space", - "attack_buttons_combination"] +gym_settings_var_order = ["frame_shape", "step_ratio", "action_space", "difficulty", "continue_game", + "tower", "role", "characters", "super_art", "fighting_style", "ultimate_style"] ok_test_parameters = { - "continue_game": [-1.0, 0.0, 0.3], - "action_space": ["discrete", "multi_discrete"], - "attack_buttons_combination": [False, True], - "player": ["P1", "P2", "Random", "P1P2"], - "step_ratio": [1, 3, 6], "frame_shape": [(0, 0, 0), (0, 0, 1), (82, 82, 0), (82, 82, 1)], + "step_ratio": [1, 3, 6], + "action_space": [SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE], + "difficulty": [None, 1, 3], + "continue_game": [-1.0, 0.0, 0.3], "tower": [1, 3, 4], - "super_art": [0, 1, 3], - "fighting_style": [0, 1, 3], - "ultimate_style": [(0, 0, 0), (1, 2, 0), (2, 2, 2)], + "role": [[Roles.P1, Roles.P2], [Roles.P2, Roles.P1], [None, None]], + "characters": [[None, None], [None, "TBD"], ["TBD", "TBD"]], + "super_art": [None, 1, 3], + "fighting_style": [None, 1, 3], + "ultimate_style": [None, (2, 2, 2)], } ko_test_parameters = { - "continue_game": [1.3, "string"], - "action_space": ["random", 12], - "attack_buttons_combination": [1], - "player": [4, "P2P1"], - "step_ratio": [8], "frame_shape": [(0, 82, 0), (0, 0, 4), (-100, -100, 3)], + "step_ratio": [8], + "difficulty": [True, 0, "Random"], + "action_space": ["Random", 12, "discrete", SpaceTypes.BOX], + "continue_game": [1.3, "string"], "tower": [5], - "super_art": ["value", 4], - "fighting_style": [False, 6], - "ultimate_style": [(10, 0, 0), "string"], + "role": [["P1", "P2"], [5, 4], ["P1P2", "Random"], ["Random", "Random"]], + "characters": [["Random", "TBD"], ["NoName", None]], + "super_art": ["Random", 4, 0], + "fighting_style": [False, 6, 0, "Random"], + "ultimate_style": [(10, 0, 0), "string", (None, None, None), ("Random", "Random", "Random")], } def pytest_generate_tests(metafunc): - test_vars, values_list_ok = generate_pytest_decorator_input(gym_settings_var_order, 
ok_test_parameters, 0) - test_vars, values_list_ko = generate_pytest_decorator_input(gym_settings_var_order, ko_test_parameters, 1) - values_list = values_list_ok + values_list_ko + test_vars, values_list = generate_pytest_decorator_input(gym_settings_var_order, ok_test_parameters, ko_test_parameters) metafunc.parametrize(test_vars, values_list) # Gym @pytest.mark.parametrize("game_id", list(games_dict.keys())) -def test_settings_gym(game_id, player, step_ratio, frame_shape, tower, super_art, - fighting_style, ultimate_style, continue_game, action_space, - attack_buttons_combination, expected, mocker): +@pytest.mark.parametrize("n_players", [1, 2]) +def test_gym_settings(game_id, n_players, frame_shape, step_ratio, action_space, difficulty, continue_game, + tower, role, characters, super_art, fighting_style, ultimate_style, expected, mocker): game_data = games_dict[game_id] - difficulty_range = range(game_data["difficulty"][0], game_data["difficulty"][1] + 1) - characters_list = ["Random"] + game_data["char_list"] + outfits_range = range(game_data["outfits"][0], game_data["outfits"][1] + 1) - difficulty = random.choice(difficulty_range) - characters = random.choice(characters_list) - char_outfits = random.choice(outfits_range) + characters = [random.choice(game_data["char_list"]) if characters[idx] == "TBD" else characters[idx] for idx in range(2)] + outfits = random.choice(outfits_range) # Env settings - settings = {} - settings["game_id"] = game_id - settings["player"] = player - settings["step_ratio"] = step_ratio - settings["continue_game"] = continue_game - settings["difficulty"] = difficulty - settings["frame_shape"] = frame_shape - - settings["tower"] = tower - - settings["characters"] = (characters, characters) - settings["char_outfits"] = (char_outfits, char_outfits) - settings["action_space"] = (action_space, action_space) - settings["attack_but_combination"] = (attack_buttons_combination, attack_buttons_combination) - - settings["super_art"] = (super_art, super_art) - settings["fighting_style"] = (fighting_style, fighting_style) - settings["ultimate_style"] = (ultimate_style, ultimate_style) - - if settings["player"] != "P1P2": - for key in ["characters" , "char_outfits", "action_space", "attack_but_combination", + if (n_players == 1): + settings = EnvironmentSettings() + else: + settings = EnvironmentSettingsMultiAgent() + settings.game_id = game_id + settings.frame_shape = frame_shape + settings.step_ratio = step_ratio + settings.action_space = (action_space, action_space) + + settings.difficulty = difficulty + settings.continue_game = continue_game + settings.tower = tower + + settings.role = (role[0], role[1]) + settings.characters = (characters[0], characters[1]) + settings.outfits = (outfits, outfits) + settings.super_art = (super_art, super_art) + settings.fighting_style = (fighting_style, fighting_style) + settings.ultimate_style = (ultimate_style, ultimate_style) + + if n_players != 2: + for key in ["action_space", "role", "characters" , "outfits", "super_art", "fighting_style", "ultimate_style"]: - settings[key] = settings[key][0] - - wrappers_settings = {} - traj_rec_settings = {} + setattr(settings, key, getattr(settings, key)[0]) - assert func(settings, wrappers_settings, traj_rec_settings, mocker) == expected + assert func(settings, mocker) == expected diff --git a/tests/test_imitation_learning.py b/tests/test_imitation_learning.py deleted file mode 100644 index d0728dac..00000000 --- a/tests/test_imitation_learning.py +++ /dev/null @@ -1,127 +0,0 @@ 
-#!/usr/bin/env python3 -import pytest -import diambra.arena -import os -from os import listdir -import numpy as np - -def func(path, hardcore): - try: - nProc = 1 - - # Show files in folder - traj_rec_folder = path - trajectories_files = [os.path.join(traj_rec_folder, f) for f in listdir( - traj_rec_folder) if os.path.isfile(os.path.join(traj_rec_folder, f))] - print(trajectories_files) - - diambra_il_settings = {} - diambra_il_settings["traj_files_list"] = trajectories_files - diambra_il_settings["total_cpus"] = nProc - - if hardcore is False: - env = diambra.arena.ImitationLearning(**diambra_il_settings) - else: - env = diambra.arena.ImitationLearningHardcore(**diambra_il_settings) - - observation = env.reset() - env.render(mode="human") - - env.traj_summary() - - env.show_obs(observation) - - cumulative_ep_rew = 0.0 - cumulative_ep_rew_all = [] - - max_num_ep = 10 - curr_num_ep = 0 - - while curr_num_ep < max_num_ep: - - dummy_actions = 0 - observation, reward, done, info = env.step(dummy_actions) - env.render(mode="human") - - action = info["action"] - - print("Action:", action) - print("reward:", reward) - print("done = ", done) - for k, v in info.items(): - print("info[\"{}\"] = {}".format(k, v)) - env.show_obs(observation) - - print("----------") - - # if done: - # observation = info[procIdx]["terminal_observation"] - - cumulative_ep_rew += reward - - if (np.any([info["round_done"], info["stage_done"], info["game_done"]]) and not done): - # Frames equality check - if hardcore is False: - for frame_idx in range(observation["frame"].shape[2] - 1): - if np.any(observation["frame"][:, :, frame_idx] != observation["frame"][:, :, frame_idx + 1]): - raise RuntimeError("Frames inside observation after " - "round/stage/game/episode done are " - "not equal. Dones =", - info["round_done"], - info["stage_done"], - info["game_done"], - info["ep_done"]) - else: - for frame_idx in range(observation.shape[2] - 1): - if np.any(observation[:, :, frame_idx] != observation[:, :, frame_idx + 1]): - raise RuntimeError("Frames inside observation after " - "round/stage/game/episode done are " - "not equal. Dones =", - info["round_done"], - info["stage_done"], - info["game_done"], - info["ep_done"]) - - if np.any(env.exhausted): - break - - if done: - curr_num_ep += 1 - print("Ep. # = ", curr_num_ep) - print("Ep. Cumulative Rew # = ", cumulative_ep_rew) - - cumulative_ep_rew_all.append(cumulative_ep_rew) - cumulative_ep_rew = 0.0 - - observation = env.reset() - env.render(mode="human") - env.show_obs(observation) - - if diambra_il_settings["total_cpus"] == 1: - print("All ep. 
rewards =", cumulative_ep_rew_all) - print("Mean cumulative reward =", np.mean(cumulative_ep_rew_all)) - print("Std cumulative reward =", np.std(cumulative_ep_rew_all)) - - env.close() - - return 0 - except Exception as e: - print(e) - return 1 - - -base_path = os.path.dirname(__file__) -normal_discrete = os.path.join(base_path, "data/Discrete/Normal") -hardcore_discrete = os.path.join(base_path, "data/Discrete/HC") - -normal_multi_discrete = os.path.join(base_path, "data/MultiDiscrete/Normal") -hardcore_multi_discrete = os.path.join(base_path, "data/MultiDiscrete/HC") - -@pytest.mark.parametrize("path", [normal_discrete, normal_multi_discrete]) -def test_imitation_normal_mode(path): - assert func(path, False) == 0 - -@pytest.mark.parametrize("path", [hardcore_discrete, hardcore_multi_discrete]) -def test_imitation_hardcore_mode(path): - assert func(path, True) == 0 - diff --git a/tests/test_integration.py b/tests/test_integration.py deleted file mode 100644 index b42bcf5a..00000000 --- a/tests/test_integration.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python3 -import pytest -from env_exec_interface import env_exec -import random -from os.path import expanduser -import os - -# Example Usage: -# pytest -# (optional) -# module.py (Run specific module) -# -s (show output) -# -k "expression" (filter tests using case-insensitive with parts of the test name and/or parameters values combined with boolean operators, e.g. "wrappers and doapp") - -def func(game_id, player, continue_game, action_space, attack_buttons_combination, - wrappers_settings, traj_rec_settings, hardcore_prob, no_action_prob): - - # Args - args = {} - args["interactive_viz"] = False - args["n_episodes"] = 1 - - args["no_action"] = random.choices([True, False], [no_action_probability, 1.0 - no_action_probability])[0] - - try: - # Settings - settings = {} - settings["game_id"] = game_id - settings["player"] = player - settings["continue_game"] = continue_game - settings["action_space"] = (action_space, action_space) - settings["attack_but_combination"] = (attack_buttons_combination, attack_buttons_combination) - if settings["player"] != "P1P2": - settings["action_space"] = settings["action_space"][0] - settings["attack_but_combination"] = settings["attack_but_combination"][0] - settings["hardcore"] = random.choices([True, False], [hardcore_probability, 1.0 - hardcore_probability])[0] - - return env_exec(settings, wrappers_settings, traj_rec_settings, args) - except Exception as e: - print(e) - return 1 - -game_ids = ["doapp", "sfiii3n", "tektagt", "umk3", "samsh5sp", "kof98umh"] -players = ["Random", "P1P2"] -action_spaces = ["multi_discrete"] -attack_buttons_combinations = [True] -continue_games = [-1.0, 0.3] -hardcore_probability = 0.4 -no_action_probability = 0.5 -rec_traj_probability = 0.5 - -@pytest.mark.parametrize("game_id", game_ids) -@pytest.mark.parametrize("player", players) -@pytest.mark.parametrize("continue_game", continue_games) -@pytest.mark.parametrize("action_space", action_spaces) -@pytest.mark.parametrize("attack_buttons_combination", attack_buttons_combinations) -def test_integration_gym(game_id, player, continue_game, action_space, attack_buttons_combination): - wrappers_settings = {} - traj_rec_settings = {} - assert func(game_id, player, continue_game, action_space, attack_buttons_combination, - wrappers_settings, traj_rec_settings, hardcore_probability, no_action_probability) == 0 - -@pytest.mark.parametrize("game_id", game_ids) -@pytest.mark.parametrize("player", players) 
-@pytest.mark.parametrize("continue_game", continue_games) -@pytest.mark.parametrize("action_space", action_spaces) -@pytest.mark.parametrize("attack_buttons_combination", attack_buttons_combinations) -def test_integration_wrappers(game_id, player, continue_game, action_space, attack_buttons_combination): - - # Env wrappers settings - wrappers_settings = {} - wrappers_settings["no_op_max"] = 0 - wrappers_settings["sticky_actions"] = 1 - wrappers_settings["hwc_obs_resize"] = (128, 128, 1) - wrappers_settings["reward_normalization"] = True - wrappers_settings["clip_rewards"] = False - wrappers_settings["frame_stack"] = 4 - wrappers_settings["dilation"] = 1 - wrappers_settings["actions_stack"] = 12 - wrappers_settings["scale"] = True - wrappers_settings["scale_mod"] = 0 - wrappers_settings["flatten"] = True - if game_id != "tektagt": - wrappers_settings["filter_keys"] = ["stage", "P1_ownSide", "P1_oppSide", "P1_oppSide", - "P1_ownHealth", "P1_oppHealth", "P1_oppChar", - "P1_actions_move", "P1_actions_attack"] - else: - wrappers_settings["filter_keys"] = ["stage", "P1_ownSide", "P1_oppSide", "P1_oppSide", - "P1_ownHealth1", "P1_oppHealth1", "P1_oppChar", - "P1_ownHealth2", "P1_oppHealth2", - "P1_actions_move", "P1_actions_attack"] - - # Recording settings - home_dir = expanduser("~") - traj_rec_settings = {} - traj_rec_settings["username"] = "Alex" - traj_rec_settings["file_path"] = os.path.join(expanduser("~"), "DIAMBRA/trajRecordings", game_id) - traj_rec_settings["ignore_p2"] = False - - if (random.choices([True, False], [rec_traj_probability, 1.0 - rec_traj_probability])[0] is False): - traj_rec_settings = {} - else: - wrappers_settings["flatten"] = False - - assert func(game_id, player, continue_game, action_space, attack_buttons_combination, - wrappers_settings, traj_rec_settings, hardcore_probability, no_action_probability) == 0 diff --git a/tests/test_random.py b/tests/test_random.py index 4d31eb26..d27af33d 100755 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -2,10 +2,8 @@ import pytest from env_exec_interface import env_exec import random -from os.path import expanduser -import os -import diambra.arena -from diambra.arena.utils.engine_mock import DiambraEngineMock +from diambra.arena.utils.engine_mock import load_mocker +from diambra.arena import SpaceTypes, Roles, EnvironmentSettings, EnvironmentSettingsMultiAgent, WrappersSettings, RecordingSettings # Example Usage: # pytest @@ -14,97 +12,128 @@ # -s (show output) # -k "expression" (filter tests using case-insensitive with parts of the test name and/or parameters values combined with boolean operators, e.g. 
"wrappers and doapp") -def func(player, continue_game, action_space, attack_buttons_combination, frame_shape, - wrappers_settings, traj_rec_settings, hardcore_prob, no_action_prob, mocker): +def func(game_id, n_players, action_space, frame_shape, wrappers_settings, + no_action_probability, continue_games, use_mock_env, mocker): # Args args = {} - args["interactive_viz"] = False + args["interactive"] = False args["n_episodes"] = 1 - args["no_action"] = random.choices([True, False], [no_action_prob, 1.0 - no_action_prob])[0] + args["no_action_probability"] = no_action_probability + args["render"] = False + args["log_output"] = False - diambra_engine_mock = DiambraEngineMock() - - mocker.patch("diambra.arena.engine.interface.DiambraEngine.__init__", diambra_engine_mock._mock__init__) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._env_init", diambra_engine_mock._mock_env_init) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._reset", diambra_engine_mock._mock_reset) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_1p", diambra_engine_mock._mock_step_1p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_2p", diambra_engine_mock._mock_step_2p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine.close", diambra_engine_mock._mock_close) + if use_mock_env is True: + override_perfect_probability = None + if no_action_probability == 1.0: + override_perfect_probability = 0.0 + load_mocker(mocker, override_perfect_probability=override_perfect_probability) try: # Settings - settings = {} - settings["game_id"] = random.choice(list(diambra.arena.available_games(print_out=False).keys())) - settings["player"] = player - settings["frame_shape"] = frame_shape - settings["continue_game"] = continue_game - settings["action_space"] = (action_space, action_space) - settings["attack_but_combination"] = (attack_buttons_combination, attack_buttons_combination) - if settings["player"] != "P1P2": - settings["action_space"] = settings["action_space"][0] - settings["attack_but_combination"] = settings["attack_but_combination"][0] - settings["hardcore"] = random.choices([True, False], [hardcore_prob, 1.0 - hardcore_prob])[0] + if (n_players == 1): + settings = EnvironmentSettings() + else: + settings = EnvironmentSettingsMultiAgent() + settings.game_id = game_id + settings.frame_shape = frame_shape + settings.action_space = (action_space, action_space) + if settings.n_players == 1: + settings.action_space = settings.action_space[0] + + # Options (settings to change at reset) + options_list = [] + roles = [[Roles.P1, Roles.P2], [Roles.P2, Roles.P1]] + for role in roles: + role_value = (role[0], role[1]) + if settings.n_players == 1: + role_value = role_value[0] + for continue_val in continue_games: + options_list.append({"role": role_value, "continue_game": continue_val}) + else: + options_list.append({"role": role_value}) - return env_exec(settings, wrappers_settings, traj_rec_settings, args) + return env_exec(settings, options_list, wrappers_settings, RecordingSettings(), args) except Exception as e: print(e) return 1 -players = ["Random", "P1P2"] -continue_games = [-1.0, 0.0, 0.3] -action_spaces = ["discrete", "multi_discrete"] -attack_buttons_combinations = [False, True] -hardcore_probability = 0.4 -no_action_probability = 0.5 -rec_traj_probability = 0.5 +game_ids = ["doapp", "sfiii3n", "tektagt", "umk3", "samsh5sp", "kof98umh"] +n_players = [1, 2] +action_spaces = [SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE] +no_action_probabilities = [0.0, 
1.0] -@pytest.mark.parametrize("player", players) -@pytest.mark.parametrize("continue_game", continue_games) +@pytest.mark.parametrize("game_id", game_ids) +@pytest.mark.parametrize("n_players", n_players) @pytest.mark.parametrize("action_space", action_spaces) -@pytest.mark.parametrize("attack_buttons_combination", attack_buttons_combinations) -def test_random_gym(player, continue_game, action_space, attack_buttons_combination, mocker): +@pytest.mark.parametrize("no_action_probability", no_action_probabilities) +def test_random_gym_mock(game_id, n_players, action_space, no_action_probability, mocker): frame_shape = random.choice([(128, 128, 1), (256, 256, 0)]) - wrappers_settings = {} - traj_rec_settings = {} - assert func(player, continue_game, action_space, attack_buttons_combination, frame_shape, - wrappers_settings, traj_rec_settings, hardcore_probability, no_action_probability, mocker) == 0 + continue_games = [0.0] + use_mock_env = True + assert func(game_id, n_players, action_space, frame_shape, WrappersSettings(), + no_action_probability, continue_games, use_mock_env, mocker) == 0 -@pytest.mark.parametrize("player", players) -@pytest.mark.parametrize("continue_game", continue_games) +@pytest.mark.parametrize("game_id", game_ids) +@pytest.mark.parametrize("n_players", n_players) @pytest.mark.parametrize("action_space", action_spaces) -@pytest.mark.parametrize("attack_buttons_combination", attack_buttons_combinations) -def test_random_wrappers(player, continue_game, action_space, attack_buttons_combination, mocker): - - frame_shape = (256, 256, 0) +@pytest.mark.parametrize("no_action_probability", no_action_probabilities) +def test_random_wrappers_mock(game_id, n_players, action_space, no_action_probability, mocker): + frame_shape = random.choice([(128, 128, 1), (256, 256, 0)]) + continue_games = [0.0] + use_mock_env = True # Env wrappers settings - wrappers_settings = {} - wrappers_settings["no_op_max"] = 0 - wrappers_settings["sticky_actions"] = 1 - wrappers_settings["hwc_obs_resize"] = (128, 128, 1) - wrappers_settings["reward_normalization"] = True - wrappers_settings["clip_rewards"] = False - wrappers_settings["frame_stack"] = 4 - wrappers_settings["dilation"] = 1 - wrappers_settings["actions_stack"] = 12 - wrappers_settings["scale"] = True - wrappers_settings["scale_mod"] = 0 - wrappers_settings["flatten"] = True - wrappers_settings["filter_keys"] = ["stage", "P1_ownSide", "P1_oppSide", "P1_oppSide", - "P1_oppChar", "P1_actions_move", "P1_actions_attack"] + wrappers_settings = WrappersSettings() + wrappers_settings.no_op_max = 0 + wrappers_settings.repeat_action = 1 + wrappers_settings.frame_shape = random.choice([(128, 128, 1), (256, 256, 0)]) + wrappers_settings.normalize_reward = True + wrappers_settings.clip_reward = False + wrappers_settings.stack_frames = 4 + wrappers_settings.dilation = 1 + wrappers_settings.add_last_action = True + wrappers_settings.stack_actions = 12 + wrappers_settings.scale = True + wrappers_settings.role_relative = True + wrappers_settings.flatten = True + suffix = "" + if n_players == 2: + suffix = "agent_0_" + wrappers_settings.filter_keys = ["stage", "timer", suffix + "own_side", suffix + "opp_side", + suffix + "opp_character", suffix + "action"] - # Recording settings - home_dir = expanduser("~") - traj_rec_settings = {} - traj_rec_settings["user_name"] = "Alex" - traj_rec_settings["file_path"] = os.path.join(home_dir, "DIAMBRA/trajRecordings/mock") - traj_rec_settings["ignore_p2"] = False + assert func(game_id, n_players, action_space, 
frame_shape, wrappers_settings,
+                no_action_probability, continue_games, use_mock_env, mocker) == 0
 
-    if (random.choices([True, False], [rec_traj_probability, 1.0 - rec_traj_probability])[0] is False):
-        traj_rec_settings = {}
-    else:
-        wrappers_settings["flatten"] = False
+
+@pytest.mark.parametrize("game_id", game_ids)
+@pytest.mark.parametrize("n_players", n_players)
+@pytest.mark.parametrize("action_space", action_spaces)
+@pytest.mark.parametrize("no_action_probability", [0.0])
+def test_random_integration(game_id, n_players, action_space, no_action_probability, mocker):
+    frame_shape = random.choice([(128, 128, 1), (256, 256, 0)])
+    continue_games = [-1.0, 0.0, 0.3]
+    use_mock_env = False
+
+    # Env wrappers settings
+    wrappers_settings = WrappersSettings()
+    wrappers_settings.no_op_max = 0
+    wrappers_settings.repeat_action = 1
+    wrappers_settings.frame_shape = (128, 128, 1)
+    wrappers_settings.normalize_reward = True
+    wrappers_settings.clip_reward = False
+    wrappers_settings.stack_frames = 4
+    wrappers_settings.dilation = 1
+    wrappers_settings.add_last_action = True
+    wrappers_settings.stack_actions = 12
+    wrappers_settings.scale = True
+    wrappers_settings.role_relative = True
+    wrappers_settings.flatten = True
+    suffix = ""
+    if n_players == 2:
+        suffix = "agent_0_"
+    wrappers_settings.filter_keys = ["stage", "timer", suffix + "own_side", suffix + "opp_side",
+                                     suffix + "opp_character", suffix + "action"]
 
-    assert func(player, continue_game, action_space, attack_buttons_combination, frame_shape,
-                wrappers_settings, traj_rec_settings, hardcore_probability, no_action_probability, mocker) == 0
+    assert func(game_id, n_players, action_space, frame_shape, wrappers_settings,
+                no_action_probability, continue_games, use_mock_env, mocker) == 0
\ No newline at end of file
diff --git a/tests/test_recording_settings.py b/tests/test_recording_settings.py
index ffa6532c..285d0aba 100644
--- a/tests/test_recording_settings.py
+++ b/tests/test_recording_settings.py
@@ -2,9 +2,10 @@
 import pytest
 from os.path import expanduser
 import os
-from diambra.arena.utils.engine_mock import DiambraEngineMock
 import diambra.arena
+from diambra.arena import SpaceTypes, EnvironmentSettings, EnvironmentSettingsMultiAgent, WrappersSettings, RecordingSettings
 from pytest_utils import generate_pytest_decorator_input
+from diambra.arena.utils.engine_mock import load_mocker
 from diambra.arena.utils.gym_utils import available_games
 
 # Example Usage:
@@ -14,9 +15,10 @@
 #    -s (show output)
 #    -k "expression" (filter tests using case-insensitive with parts of the test name and/or parameters values combined with boolean operators, e.g.
"wrappers and doapp") -def env_exec(settings, wrappers_settings, traj_rec_settings): +def func(settings, wrappers_settings, episode_recording_settings, mocker): + load_mocker(mocker) try: - env = diambra.arena.make(settings["game_id"], settings, wrappers_settings, traj_rec_settings) + env = diambra.arena.make(settings.game_id, settings, wrappers_settings, episode_recording_settings) env.close() print("COMPLETED SUCCESSFULLY!") @@ -26,78 +28,52 @@ def env_exec(settings, wrappers_settings, traj_rec_settings): print("ERROR, ABORTED.") return 1 -def func(settings, wrappers_settings, traj_rec_settings, mocker): - - diambra_engine_mock = DiambraEngineMock() - - mocker.patch("diambra.arena.engine.interface.DiambraEngine.__init__", diambra_engine_mock._mock__init__) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._env_init", diambra_engine_mock._mock_env_init) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._reset", diambra_engine_mock._mock_reset) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_1p", diambra_engine_mock._mock_step_1p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_2p", diambra_engine_mock._mock_step_2p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine.close", diambra_engine_mock._mock_close) - - try: - return env_exec(settings, wrappers_settings, traj_rec_settings) - except Exception as e: - print(e) - return 1 - -wrappers_settings_var_order = ["username", "file_path", "ignore_p2"] +episode_recording_settings_var_order = ["username", "dataset_path"] games_dict = available_games(False) home_dir = expanduser("~") ok_test_parameters = { "username": ["alexpalms", "test"], - "file_path": [os.path.join(home_dir, "DIAMBRA")], - "ignore_p2": [False, True], + "dataset_path": [os.path.join(home_dir, "DIAMBRA")], } ko_test_parameters = { "username": [123], - "file_path": [True], - "ignore_p2": [1], + "dataset_path": [True], } def pytest_generate_tests(metafunc): - test_vars, values_list_ok = generate_pytest_decorator_input(wrappers_settings_var_order, ok_test_parameters, 0) - test_vars, values_list_ko = generate_pytest_decorator_input(wrappers_settings_var_order, ko_test_parameters, 1) - values_list = values_list_ok + values_list_ko + test_vars, values_list = generate_pytest_decorator_input(episode_recording_settings_var_order, ok_test_parameters, ko_test_parameters) metafunc.parametrize(test_vars, values_list) # Recording @pytest.mark.parametrize("game_id", list(games_dict.keys())) -@pytest.mark.parametrize("player", ["Random", "P1P2"]) -@pytest.mark.parametrize("hardcore", [False, True]) -@pytest.mark.parametrize("action_space", ["discrete", "multi_discrete"]) -@pytest.mark.parametrize("attack_buttons_combination", [False, True]) -def test_settings_recording(game_id ,username, file_path, ignore_p2, - player, action_space, attack_buttons_combination, hardcore, expected, mocker): - +@pytest.mark.parametrize("n_players", [1, 2]) +@pytest.mark.parametrize("action_space", [SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE]) +def test_settings_recording(game_id ,username, dataset_path, n_players, action_space, expected, mocker): # Env settings - settings = {} - settings["game_id"] = game_id - settings["player"] = player - settings["hardcore"] = hardcore - settings["action_space"] = action_space - settings["attack_buttons_combination"] = attack_buttons_combination - if player == "P1P2": - settings["action_space"] = (action_space, action_space) - settings["attack_buttons_combination"] = (attack_buttons_combination, 
attack_buttons_combination) + if (n_players == 1): + settings = EnvironmentSettings() + else: + settings = EnvironmentSettingsMultiAgent() + settings.game_id = game_id + settings.action_space = action_space + if n_players == 2: + settings.action_space = (action_space, action_space) # Env wrappers settings - wrappers_settings = {} - wrappers_settings["hwc_obs_resize"] = (128, 128, 1) - wrappers_settings["reward_normalization"] = True - wrappers_settings["frame_stack"] = 4 - wrappers_settings["actions_stack"] = 12 - wrappers_settings["scale"] = True + wrappers_settings = WrappersSettings() + wrappers_settings.frame_shape = (128, 128, 1) + wrappers_settings.normalize_reward = True + wrappers_settings.stack_frames = 4 + wrappers_settings.add_last_action = True + wrappers_settings.stack_actions = 12 + wrappers_settings.scale = True # Recording settings - traj_rec_settings = {} - traj_rec_settings["username"] = username - traj_rec_settings["file_path"] = file_path - traj_rec_settings["ignore_p2"] = ignore_p2 + episode_recording_settings = RecordingSettings() + episode_recording_settings.username = username + episode_recording_settings.dataset_path = dataset_path - assert func(settings, wrappers_settings, traj_rec_settings, mocker) == expected + assert func(settings, wrappers_settings, episode_recording_settings, mocker) == expected diff --git a/tests/test_speed.py b/tests/test_speed.py index f86d837a..0ac53db2 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 import pytest -from env_exec_interface import env_exec import time -from os.path import expanduser import diambra.arena -from diambra.arena.utils.engine_mock import DiambraEngineMock +from diambra.arena import EnvironmentSettings, EnvironmentSettingsMultiAgent, WrappersSettings +from diambra.arena.utils.engine_mock import load_mocker import numpy as np import warnings @@ -14,55 +13,32 @@ def reject_outliers(data): filtered = [e for e in data if (u - 2 * s < e < u + 2 * s)] return filtered -def func(player, wrappers_settings, target_speed, mocker): - - diambra_engine_mock = DiambraEngineMock(fps=500) - - mocker.patch("diambra.arena.engine.interface.DiambraEngine.__init__", diambra_engine_mock._mock__init__) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._env_init", diambra_engine_mock._mock_env_init) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._reset", diambra_engine_mock._mock_reset) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_1p", diambra_engine_mock._mock_step_1p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_2p", diambra_engine_mock._mock_step_2p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine.close", diambra_engine_mock._mock_close) - +def func(n_players, wrappers_settings, target_speed, mocker): + load_mocker(mocker) try: # Settings - settings = {} - settings["player"] = player - settings["action_space"] = "discrete" - settings["attack_but_combination"] = False - if player == "P1P2": - settings["action_space"] = ("discrete", "discrete") - settings["attack_but_combination"] = (False, False) + if (n_players == 1): + settings = EnvironmentSettings() + else: + settings = EnvironmentSettingsMultiAgent() env = diambra.arena.make("doapp", settings, wrappers_settings) + observation, info = env.reset() - observation = env.reset() n_step = 0 - fps_val = [] - while n_step < 1000: - n_step += 1 - actions = [None, None] - if settings["player"] != "P1P2": - actions = env.action_space.sample() - else: - for idx in 
range(2): - actions[idx] = env.action_space["P{}".format(idx + 1)].sample() - - if (settings["player"] == "P1P2" or settings["action_space"] != "discrete"): - actions = np.append(actions[0], actions[1]) + actions = env.action_space.sample() tic = time.time() - observation, reward, done, info = env.step(actions) + observation, reward, terminated, truncated, info = env.step(actions) toc = time.time() fps = 1 / (toc - tic) fps_val.append(fps) - if done: - observation = env.reset() + if terminated or truncated: + observation, info = env.reset() break env.close() @@ -84,32 +60,34 @@ def func(player, wrappers_settings, target_speed, mocker): print(e) return 1 -players = ["Random", "P1P2"] - +n_players = [1, 2] target_speeds = [400, 300] -@pytest.mark.parametrize("player", players) -def test_speed_gym(player, mocker): - wrappers_settings = {} - assert func(player, wrappers_settings, target_speeds[0], mocker) == 0 +@pytest.mark.parametrize("n_players", n_players) +def test_speed_gym(n_players, mocker): + assert func(n_players, WrappersSettings(), target_speeds[0], mocker) == 0 -@pytest.mark.parametrize("player", players) -def test_speed_wrappers(player, mocker): +@pytest.mark.parametrize("n_players", n_players) +def test_speed_wrappers(n_players, mocker): # Env wrappers settings - wrappers_settings = {} - wrappers_settings["no_op_max"] = 0 - wrappers_settings["sticky_actions"] = 1 - wrappers_settings["reward_normalization"] = True - wrappers_settings["clip_rewards"] = False - wrappers_settings["frame_stack"] = 4 - wrappers_settings["dilation"] = 1 - wrappers_settings["actions_stack"] = 12 - wrappers_settings["scale"] = True - wrappers_settings["scale_mod"] = 0 - wrappers_settings["flatten"] = True - wrappers_settings["filter_keys"] = ["stage", "P1_ownSide", "P1_oppSide", "P1_oppSide", - "P1_ownHealth", "P1_oppHealth", "P1_oppChar", - "P1_actions_move", "P1_actions_attack"] - - assert func(player, wrappers_settings, target_speeds[1], mocker) == 0 + wrappers_settings = WrappersSettings() + wrappers_settings.no_op_max = 0 + wrappers_settings.repeat_action = 1 + wrappers_settings.normalize_reward = True + wrappers_settings.clip_reward = False + wrappers_settings.stack_frames = 4 + wrappers_settings.dilation = 1 + wrappers_settings.add_last_action = True + wrappers_settings.stack_actions = 12 + wrappers_settings.scale = True + wrappers_settings.role_relative = True + wrappers_settings.flatten = True + + suffix = "" + if n_players == 2: + suffix = "agent_0_" + wrappers_settings.filter_keys = ["stage", "timer", suffix + "own_side", suffix + "opp_side", + suffix + "opp_character", suffix + "action"] + + assert func(n_players, wrappers_settings, target_speeds[1], mocker) == 0 diff --git a/tests/test_wrappers_settings.py b/tests/test_wrappers_settings.py index 11164147..6f8b913d 100644 --- a/tests/test_wrappers_settings.py +++ b/tests/test_wrappers_settings.py @@ -1,9 +1,9 @@ #!/usr/bin/env python import pytest -from os.path import expanduser import diambra.arena -from diambra.arena.utils.engine_mock import DiambraEngineMock from pytest_utils import generate_pytest_decorator_input +from diambra.arena import SpaceTypes, EnvironmentSettings, EnvironmentSettingsMultiAgent, WrappersSettings +from diambra.arena.utils.engine_mock import load_mocker from diambra.arena.utils.gym_utils import available_games # Example Usage: @@ -13,9 +13,10 @@ # -s (show output) # -k "expression" (filter tests using case-insensitive with parts of the test name and/or parameters values combined with boolean operators, e.g. 
"wrappers and doapp") -def env_exec(settings, wrappers_settings, traj_rec_settings): +def func(settings, wrappers_settings, mocker): + load_mocker(mocker) try: - env = diambra.arena.make(settings["game_id"], settings, wrappers_settings, traj_rec_settings) + env = diambra.arena.make(settings.game_id, settings, wrappers_settings) env.close() print("COMPLETED SUCCESSFULLY!") @@ -25,107 +26,94 @@ def env_exec(settings, wrappers_settings, traj_rec_settings): print("ERROR, ABORTED.") return 1 -def func(settings, wrappers_settings, traj_rec_settings, mocker): - - diambra_engine_mock = DiambraEngineMock() - - mocker.patch("diambra.arena.engine.interface.DiambraEngine.__init__", diambra_engine_mock._mock__init__) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._env_init", diambra_engine_mock._mock_env_init) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._reset", diambra_engine_mock._mock_reset) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_1p", diambra_engine_mock._mock_step_1p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine._step_2p", diambra_engine_mock._mock_step_2p) - mocker.patch("diambra.arena.engine.interface.DiambraEngine.close", diambra_engine_mock._mock_close) - - try: - return env_exec(settings, wrappers_settings, traj_rec_settings) - except Exception as e: - print(e) - return 1 - -wrappers_settings_var_order = ["no_op_max", "sticky_actions", "hwc_obs_resize", "reward_normalization", - "reward_normalization_factor", "clip_rewards", "frame_stack", "dilation", - "actions_stack", "scale", "scale_mod", "flatten", "filter_keys"] +wrappers_settings_var_order = ["no_op_max", "repeat_action", "normalize_reward", "normalization_factor", + "clip_reward", "no_attack_buttons_combinations", "frame_shape", "stack_frames", "dilation", + "add_last_action", "stack_actions", "scale", "role_relative", + "flatten", "filter_keys", "wrappers"] games_dict = available_games(False) ok_test_parameters = { "no_op_max": [0, 2], - "sticky_actions": [1, 4], - "hwc_obs_resize": [(0, 0, 0), (84, 84, 1), (84, 84, 3), (84, 84, 0)], - "reward_normalization": [True, False], - "reward_normalization_factor": [0.2, 0.5], - "clip_rewards": [True, False], - "frame_stack": [1, 5], + "repeat_action": [1, 4], + "normalize_reward": [True, False], + "normalization_factor": [0.2, 0.5], + "clip_reward": [True, False], + "no_attack_buttons_combinations": [True, False], + "frame_shape": [(0, 0, 0), (84, 84, 1), (84, 84, 3), (84, 84, 0)], + "stack_frames": [1, 5], "dilation": [1, 3], - "actions_stack": [1, 6], + "add_last_action": [True, False], + "stack_actions": [1, 6], "scale": [True, False], - "scale_mod": [0], + "role_relative": [True, False], "flatten": [True, False], - "filter_keys": [[], ["stage", "P1_ownSide"]] + "filter_keys": [[], ["stage", "own_side"]], + "wrappers": [[]], } ko_test_parameters = { "no_op_max": [-1], - "sticky_actions": [True], - "hwc_obs_resize": [(0, 84, 3), (0, 0, 1)], - "reward_normalization": ["True"], - "reward_normalization_factor": [-10], - "clip_rewards": [0.5], - "frame_stack": [0], + "repeat_action": [True], + "normalize_reward": ["True"], + "normalization_factor": [-10], + "clip_reward": [0.5], + "no_attack_buttons_combinations": [-1], + "frame_shape": [(0, 84, 3), (128, 0, 1)], + "stack_frames": [0], "dilation": [0], - "actions_stack": [-2], + "add_last_action": [10], + "stack_actions": [-2], "scale": [10], - "scale_mod": [2], + "role_relative": [24], "flatten": [None], - "filter_keys": [12] + "filter_keys": [12], + "wrappers": 
["test"], } def pytest_generate_tests(metafunc): - test_vars, values_list_ok = generate_pytest_decorator_input(wrappers_settings_var_order, ok_test_parameters, 0) - test_vars, values_list_ko = generate_pytest_decorator_input(wrappers_settings_var_order, ko_test_parameters, 1) - values_list = values_list_ok + values_list_ko + test_vars, values_list = generate_pytest_decorator_input(wrappers_settings_var_order, ok_test_parameters, ko_test_parameters) metafunc.parametrize(test_vars, values_list) # Wrappers @pytest.mark.parametrize("game_id", list(games_dict.keys())) @pytest.mark.parametrize("step_ratio", [1]) -@pytest.mark.parametrize("player", ["Random", "P1P2"]) -@pytest.mark.parametrize("hardcore", [False, True]) -@pytest.mark.parametrize("action_space", ["discrete", "multi_discrete"]) -@pytest.mark.parametrize("attack_buttons_combination", [False, True]) -def test_settings_wrappers(game_id, step_ratio, player, action_space, attack_buttons_combination, hardcore, - no_op_max, sticky_actions, hwc_obs_resize, reward_normalization, - reward_normalization_factor, clip_rewards, frame_stack, dilation, - actions_stack, scale, scale_mod, flatten, filter_keys, expected, mocker): +@pytest.mark.parametrize("n_players", [1, 2]) +@pytest.mark.parametrize("action_space", [SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE]) +def test_wrappers_settings(game_id, step_ratio, n_players, action_space, no_op_max, repeat_action, + normalize_reward, normalization_factor, clip_reward, + no_attack_buttons_combinations, frame_shape, stack_frames, dilation, + add_last_action, stack_actions, scale, role_relative, + flatten, filter_keys, wrappers, expected, mocker): # Env settings - settings = {} - settings["game_id"] = game_id - settings["step_ratio"] = step_ratio - settings["player"] = player - settings["hardcore"] = hardcore - settings["action_space"] = action_space - settings["attack_buttons_combination"] = attack_buttons_combination - if player == "P1P2": - settings["action_space"] = (action_space, action_space) - settings["attack_buttons_combination"] = (attack_buttons_combination, attack_buttons_combination) + if (n_players == 1): + settings = EnvironmentSettings() + else: + settings = EnvironmentSettingsMultiAgent() + settings.game_id = game_id + settings.step_ratio = step_ratio + settings.action_space = action_space + if n_players == 2: + settings.action_space = (action_space, action_space) # Env wrappers settings - wrappers_settings = {} - wrappers_settings["no_op_max"] = no_op_max - wrappers_settings["sticky_actions"] = sticky_actions - wrappers_settings["hwc_obs_resize"] = hwc_obs_resize - wrappers_settings["reward_normalization"] = reward_normalization - wrappers_settings["clip_rewards"] = clip_rewards - wrappers_settings["frame_stack"] = frame_stack - wrappers_settings["dilation"] = dilation - wrappers_settings["actions_stack"] = actions_stack - wrappers_settings["scale"] = scale - wrappers_settings["scale_mod"] = scale_mod - wrappers_settings["flatten"] = flatten - wrappers_settings["filter_keys"] = filter_keys - - # Recording settings - traj_rec_settings = {} - - assert func(settings, wrappers_settings, traj_rec_settings, mocker) == expected + wrappers_settings = WrappersSettings() + wrappers_settings.no_op_max = no_op_max + wrappers_settings.repeat_action = repeat_action + wrappers_settings.normalize_reward = normalize_reward + wrappers_settings.normalization_factor = normalization_factor + wrappers_settings.clip_reward = clip_reward + wrappers_settings.no_attack_buttons_combinations = 
no_attack_buttons_combinations
+    wrappers_settings.frame_shape = frame_shape
+    wrappers_settings.stack_frames = stack_frames
+    wrappers_settings.dilation = dilation
+    wrappers_settings.add_last_action = add_last_action
+    wrappers_settings.stack_actions = 1 if add_last_action is False and expected == 0 else stack_actions
+    wrappers_settings.scale = scale
+    wrappers_settings.role_relative = role_relative
+    wrappers_settings.flatten = flatten
+    wrappers_settings.filter_keys = filter_keys
+    wrappers_settings.wrappers = wrappers
+
+    assert func(settings, wrappers_settings, mocker) == expected
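
All of the mocked tests introduced above share one pattern: build a settings object, patch the engine with load_mocker, create the environment, and drive it through the Gymnasium-style reset/step API. A minimal sketch of that pattern, assembled only from pieces visible in this patch (load_mocker, EnvironmentSettings, WrappersSettings, SpaceTypes, and the five-value step return used in tests/test_speed.py) and assuming the pytest-mock mocker fixture:

# Minimal sketch of the mocked smoke-test pattern used by the tests in this patch.
import diambra.arena
from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings
from diambra.arena.utils.engine_mock import load_mocker

def test_doapp_smoke(mocker):
    load_mocker(mocker)  # replace the real DiambraEngine interface with the mock

    settings = EnvironmentSettings()  # single-agent settings object
    settings.game_id = "doapp"
    settings.action_space = SpaceTypes.DISCRETE

    wrappers_settings = WrappersSettings()
    wrappers_settings.frame_shape = (128, 128, 1)

    env = diambra.arena.make(settings.game_id, settings, wrappers_settings)
    observation, info = env.reset()
    for _ in range(10):
        observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
        if terminated or truncated:
            observation, info = env.reset()
    env.close()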
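
The OK/KO parameter tables in test_gym_settings.py, test_wrappers_settings.py and test_recording_settings.py are all expanded by the single generate_pytest_decorator_input(var_order, ok_params, ko_params) call inside pytest_generate_tests. That helper lives in tests/pytest_utils.py and is not part of this patch; the sketch below is only a hypothetical reconstruction of the behaviour the tests rely on: each case varies one setting at a time, valid values carry expected == 0 and invalid ones expected == 1, which is what the assert func(...) == expected lines compare against.

# Hypothetical sketch of tests/pytest_utils.generate_pytest_decorator_input (not shown in this patch).
def generate_pytest_decorator_input(var_order, ok_params, ko_params):
    test_vars = ",".join(var_order) + ",expected"        # argnames string for metafunc.parametrize
    defaults = [ok_params[var][0] for var in var_order]  # baseline: first valid value of every setting

    values_list = []
    for expected, params in ((0, ok_params), (1, ko_params)):
        for idx, var in enumerate(var_order):
            for value in params[var]:
                case = list(defaults)
                case[idx] = value                        # vary exactly one setting per generated case
                values_list.append(tuple(case) + (expected,))
    return test_vars, values_list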
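
tests/test_random.py now feeds per-episode settings ("role", plus "continue_game" for single-agent runs) through the options_list dictionaries it hands to env_exec, instead of baking them into the environment settings. env_exec itself is not part of this patch; assuming it forwards each dictionary to the Gymnasium-style reset, the interaction looks roughly like this sketch:

# Sketch only: assumes env_exec passes each options dict to env.reset(options=...).
import diambra.arena
from diambra.arena import Roles, EnvironmentSettingsMultiAgent
from diambra.arena.utils.engine_mock import load_mocker

def test_reset_options_sketch(mocker):
    load_mocker(mocker)
    settings = EnvironmentSettingsMultiAgent()
    settings.game_id = "doapp"
    env = diambra.arena.make(settings.game_id, settings)
    # Same role permutations the tests build into options_list
    for options in [{"role": (Roles.P1, Roles.P2)}, {"role": (Roles.P2, Roles.P1)}]:
        observation, info = env.reset(options=options)
    env.close()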