diff --git a/diambra/arena/__init__.py b/diambra/arena/__init__.py index 6110e39f..882e1151 100644 --- a/diambra/arena/__init__.py +++ b/diambra/arena/__init__.py @@ -1,4 +1,5 @@ from diambra.engine import SpaceTypes, Roles from diambra.engine import model +from .env_settings import EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings, RecordingSettings, load_settings_flat_dict from .make_env import make from .utils.gym_utils import available_games, game_sha_256, check_game_sha_256, get_num_envs \ No newline at end of file diff --git a/diambra/arena/env_settings.py b/diambra/arena/env_settings.py index a3182500..71598928 100644 --- a/diambra/arena/env_settings.py +++ b/diambra/arena/env_settings.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Union, List, Tuple, Any, Dict from diambra.arena.utils.gym_utils import available_games from diambra.arena import SpaceTypes, Roles @@ -6,6 +6,10 @@ import random from diambra.engine import model import time +from dacite import from_dict, Config + +def load_settings_flat_dict(target_class, settings: dict): + return from_dict(target_class, settings, config=Config(strict=True)) MAX_VAL = float("inf") MIN_VAL = float("-inf") @@ -22,6 +26,13 @@ def check_val_in_list(key, value, valid_list): assert (value in valid_list), error_message assert (type(value)==type(valid_list[valid_list.index(value)])), error_message +def check_type(key, value, expected_type, admit_none=True): + error_message = "ERROR: \"{}\" ({}) is of type {}, admissible types are {}".format(key, value, type(value), expected_type) + if value is not None: + assert isinstance(value, expected_type), error_message + else: + assert admit_none==True, "ERROR: \"{}\" ({}) cannot be NoneType".format(key, value) + def check_space_type(key, value, valid_list): error_message = "ERROR: \"{}\" ({}) admissible values are {}".format(key, SpaceTypes.Name(value), [SpaceTypes.Name(elem) for elem in valid_list]) assert (value in valid_list), error_message @@ -37,15 +48,14 @@ class EnvironmentSettings: games_dict = None # Environment settings - game_id: str + game_id: str = "doapp" frame_shape: Tuple[int, int, int] = (0, 0, 0) step_ratio: int = 6 - n_players: int = 1 disable_keyboard: bool = True disable_joystick: bool = True render_mode: Union[None, str] = None rank: int = 0 - env_address: str = "localhost:50051" + env_address: str = None grpc_timeout: int = 600 # Episode settings @@ -55,7 +65,7 @@ class EnvironmentSettings: show_final: bool = False tower: Union[None, int] = 3 # UMK3 Specific - # Bookeeping variables + # Bookkeeping variables _last_seed: Union[None, int] = None pb_model: model = None @@ -161,21 +171,22 @@ def _sample_characters(self, n_characters=3): def _sanity_check(self): if self.env_info is None or self.games_dict is None: - raise("EnvironmentSettings class not correctly initialized") + raise Exception("EnvironmentSettings class not correctly initialized") # TODO: consider using typing.Literal to specify lists of admissible values: NOTE: It requires Python 3.8+ check_val_in_list("game_id", self.game_id, list(self.games_dict.keys())) - if self.render_mode is not None: - check_val_in_list("render_mode", self.render_mode, ["human", "rgb_array"]) + check_num_in_range("step_ratio", self.step_ratio, [1, 6]) check_num_in_range("frame_shape[0]", self.frame_shape[0], [0, MAX_FRAME_RES]) check_num_in_range("frame_shape[1]", self.frame_shape[1], [0, MAX_FRAME_RES]) if (min(self.frame_shape[0], self.frame_shape[1]) == 0 and max(self.frame_shape[0], self.frame_shape[1]) != 0): raise Exception("\"frame_shape[0] and frame_shape[1]\" must be either both different from or equal to 0") check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1]) - check_num_in_range("step_ratio", self.step_ratio, [1, 6]) - check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600]) + if self.render_mode is not None: + check_val_in_list("render_mode", self.render_mode, ["human", "rgb_array"]) check_num_in_range("rank", self.rank, [0, MAX_VAL]) + check_type("env_address", self.env_address, str) + check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600]) if self.seed is not None: check_num_in_range("seed", self.seed, [-1, MAX_VAL]) @@ -183,6 +194,7 @@ def _sanity_check(self): difficulty_admissible_values.append(None) check_val_in_list("difficulty", self.difficulty, difficulty_admissible_values) check_num_in_range("continue_game", self.continue_game, [MIN_VAL, 1.0]) + check_type("show_final", self.show_final, bool) check_val_in_list("tower", self.tower, [None, 1, 2, 3, 4]) def _process_random_values(self): @@ -194,8 +206,8 @@ def _process_random_values(self): @dataclass class EnvironmentSettings1P(EnvironmentSettings): """Single Agent Environment Settings Class""" - # Env settings + n_players: int = 1 action_space: int = SpaceTypes.MULTI_DISCRETE # Episode settings @@ -270,8 +282,9 @@ def _get_player_specific_values(self): @dataclass class EnvironmentSettings2P(EnvironmentSettings): - """Single Agent Environment Settings Class""" + """Multi Agent Environment Settings Class""" # Env Settings + n_players: int = 2 action_space: Tuple[int, int] = (SpaceTypes.MULTI_DISCRETE, SpaceTypes.MULTI_DISCRETE) # Episode Settings @@ -382,28 +395,42 @@ class WrappersSettings: process_discrete_binary: bool = False role_relative_observation: bool = False flatten: bool = False - filter_keys: List[str] = None - wrappers: List[List[Any]] = None + filter_keys: List[str] = field(default_factory=list) + wrappers: List[List[Any]] = field(default_factory=list) def sanity_check(self): check_num_in_range("no_op_max", self.no_op_max, [0, 12]) check_num_in_range("sticky_actions", self.sticky_actions, [1, 12]) - check_num_in_range("frame_stack", self.frame_stack, [1, MAX_STACK_VALUE]) - check_num_in_range("dilation", self.dilation, [1, MAX_STACK_VALUE]) - actions_stack_bounds = [1, 1] - if self.add_last_action_to_observation is True: - actions_stack_bounds = [1, MAX_STACK_VALUE] - check_num_in_range("actions_stack", self.actions_stack, actions_stack_bounds) + check_type("reward_normalization", self.reward_normalization, bool, admit_none=False) check_num_in_range("reward_normalization_factor", self.reward_normalization_factor, [0.0, 1000000]) - + check_type("clip_rewards", self.clip_rewards, bool, admit_none=False) + check_type("no_attack_buttons_combinations", self.no_attack_buttons_combinations, bool, admit_none=False) check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1, 3]) check_num_in_range("frame_shape[0]", self.frame_shape[0], [0, MAX_FRAME_RES]) check_num_in_range("frame_shape[1]", self.frame_shape[1], [0, MAX_FRAME_RES]) if (min(self.frame_shape[0], self.frame_shape[1]) == 0 and max(self.frame_shape[0], self.frame_shape[1]) != 0): raise Exception("\"frame_shape[0] and frame_shape[1]\" must be both different from 0") + check_num_in_range("frame_stack", self.frame_stack, [1, MAX_STACK_VALUE]) + check_num_in_range("dilation", self.dilation, [1, MAX_STACK_VALUE]) + check_type("add_last_action_to_observation", self.add_last_action_to_observation, bool, admit_none=False) + actions_stack_bounds = [1, 1] + if self.add_last_action_to_observation is True: + actions_stack_bounds = [1, MAX_STACK_VALUE] + check_num_in_range("actions_stack", self.actions_stack, actions_stack_bounds) + check_type("scale", self.scale, bool, admit_none=False) + check_type("exclude_image_scaling", self.exclude_image_scaling, bool, admit_none=False) + check_type("process_discrete_binary", self.process_discrete_binary, bool, admit_none=False) + check_type("role_relative_observation", self.role_relative_observation, bool, admit_none=False) + check_type("flatten", self.flatten, bool, admit_none=False) + check_type("filter_keys", self.filter_keys, list, admit_none=False) + check_type("wrappers", self.wrappers, list, admit_none=False) @dataclass class RecordingSettings: - dataset_path: str = "./" - username: str = "username" + dataset_path: Union[None, str] = None + username: Union[None, str] = None + + def sanity_check(self): + check_type("dataset_path", self.dataset_path, str) + check_type("username", self.username, str) diff --git a/diambra/arena/make_env.py b/diambra/arena/make_env.py index 5212b2d0..de598365 100644 --- a/diambra/arena/make_env.py +++ b/diambra/arena/make_env.py @@ -1,13 +1,14 @@ import os import logging -from dacite import from_dict, Config from diambra.arena.arena_gym import DiambraGym1P, DiambraGym2P from diambra.arena.wrappers.arena_wrappers import env_wrapping -from diambra.arena.env_settings import EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings, RecordingSettings +from diambra.arena import EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings, RecordingSettings from diambra.arena.wrappers.episode_recording import EpisodeRecorder +from typing import Union -def make(game_id, env_settings:dict={}, wrappers_settings:dict={}, episode_recording_settings:dict={}, - render_mode=None, rank=0, log_level=logging.INFO): +def make(game_id, env_settings: Union[EnvironmentSettings1P, EnvironmentSettings2P]=EnvironmentSettings1P(), + wrappers_settings: WrappersSettings=WrappersSettings(), episode_recording_settings: RecordingSettings=RecordingSettings(), + render_mode: str=None, rank: int=0, log_level=logging.INFO): """ Create a wrapped environment. :param seed: (int) the initial seed for RNG @@ -19,8 +20,8 @@ def make(game_id, env_settings:dict={}, wrappers_settings:dict={}, episode_recor logger = logging.getLogger(__name__) # Include game_id and render_mode in env_settings - env_settings["game_id"] = game_id - env_settings["render_mode"] = render_mode + env_settings.game_id = game_id + env_settings.render_mode = render_mode # Check if DIAMBRA_ENVS var present env_addresses = os.getenv("DIAMBRA_ENVS", "").split() @@ -31,20 +32,13 @@ def make(game_id, env_settings:dict={}, wrappers_settings:dict={}, episode_recor "# of env servers: {}".format(len(env_addresses)), "# rank of client: {} (0-based index)".format(rank)) else: # If not present, set default value - if "env_address" not in env_settings: + if env_settings.env_address is None: env_addresses = ["localhost:50051"] else: - env_addresses = [env_settings["env_address"]] + env_addresses = [env_settings.env_address] - env_settings["env_address"] = env_addresses[rank] - env_settings["rank"] = rank - - # Checking settings and setting up default ones - if "n_players" in env_settings.keys() and env_settings["n_players"] == 2: - env_settings = from_dict(EnvironmentSettings2P, env_settings, config=Config(strict=True)) - else: - env_settings["n_players"] = 1 - env_settings = from_dict(EnvironmentSettings1P, env_settings, config=Config(strict=True)) + env_settings.env_address = env_addresses[rank] + env_settings.rank = rank # Make environment if env_settings.n_players == 1: # 1P Mode @@ -53,12 +47,11 @@ def make(game_id, env_settings:dict={}, wrappers_settings:dict={}, episode_recor env = DiambraGym2P(env_settings) # Apply episode recorder wrapper - if len(episode_recording_settings) != 0: - episode_recording_settings = from_dict(RecordingSettings, episode_recording_settings, config=Config(strict=True)) + if episode_recording_settings.dataset_path is not None: + episode_recording_settings.sanity_check() env = EpisodeRecorder(env, episode_recording_settings) # Apply environment wrappers - wrappers_settings = from_dict(WrappersSettings, wrappers_settings, config=Config(strict=True)) wrappers_settings.sanity_check() env = env_wrapping(env, wrappers_settings) diff --git a/diambra/arena/wrappers/arena_wrappers.py b/diambra/arena/wrappers/arena_wrappers.py index 9d4f2678..7e3f4eb7 100644 --- a/diambra/arena/wrappers/arena_wrappers.py +++ b/diambra/arena/wrappers/arena_wrappers.py @@ -190,9 +190,8 @@ def env_wrapping(env, wrappers_settings: WrappersSettings): if wrappers_settings.flatten is True: env = FlattenFilterDictObs(env, wrappers_settings.filter_keys) - # Apply all additional wrappers in sequence - if wrappers_settings.wrappers is not None: - for wrapper in wrappers_settings.wrappers: - env = wrapper[0](env, **wrapper[1]) + # Apply all additional wrappers in sequence: + for wrapper in wrappers_settings.wrappers: + env = wrapper[0](env, **wrapper[1]) return env diff --git a/diambra/arena/wrappers/observation.py b/diambra/arena/wrappers/observation.py index 056cee74..c3e85d24 100644 --- a/diambra/arena/wrappers/observation.py +++ b/diambra/arena/wrappers/observation.py @@ -333,15 +333,15 @@ def __init__(self, env, filter_keys): gym.ObservationWrapper.__init__(self, env) self.filter_keys = filter_keys - if (filter_keys is not None): + if len(filter_keys) != 0: self.filter_keys = list(set(filter_keys)) if "frame" not in filter_keys: self.filter_keys += ["frame"] - original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, None)).keys() + original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, [])).keys() self.observation_space = gym.spaces.Dict(flatten_filter_obs_space_func(self.observation_space, self.filter_keys)) - if filter_keys is not None: + if len(filter_keys) != 0: if (sorted(self.observation_space.spaces.keys()) != sorted(self.filter_keys)): raise Exception("Specified observation key(s) not found:", " Available key(s):", sorted(original_obs_space_keys), @@ -371,7 +371,7 @@ def visit(subdict, flattened_dict, partial_key, check_method): if check_method(new_key): flattened_dict[new_key] = v - if filter_keys is not None: + if len(filter_keys) != 0: visit(input_dictionary, flattened_dict, _FLAG_FIRST, check_filter) else: visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check) @@ -397,7 +397,7 @@ def visit(subdict, flattened_dict, partial_key, check_method): if check_method(new_key): flattened_dict[new_key] = v - if filter_keys is not None: + if len(filter_keys) != 0: visit(input_dictionary, flattened_dict, _FLAG_FIRST, check_filter) else: visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check) diff --git a/examples/episode_recording.py b/examples/episode_recording.py index b66feb23..abee4105 100644 --- a/examples/episode_recording.py +++ b/examples/episode_recording.py @@ -1,25 +1,23 @@ import os from os.path import expanduser import diambra.arena -from diambra.arena import SpaceTypes +from diambra.arena import SpaceTypes, EnvironmentSettings1P, RecordingSettings from diambra.arena.utils.controller import get_diambra_controller import argparse def main(use_controller): # Environment Settings - settings = {} - settings["n_players"] = 1 - settings["role"] = None - settings["step_ratio"] = 1 - settings["frame_shape"] = (256, 256, 1) - settings["action_space"] = SpaceTypes.MULTI_DISCRETE + settings = EnvironmentSettings1P() + settings.step_ratio = 1 + settings.frame_shape = (256, 256, 1) + settings.action_space = SpaceTypes.MULTI_DISCRETE # Recording settings home_dir = expanduser("~") game_id = "doapp" - recording_settings = {} - recording_settings["dataset_path"] = os.path.join(home_dir, "DIAMBRA/episode_recording", game_id if use_controller else "mock") - recording_settings["username"] = "alexpalms" + recording_settings = RecordingSettings() + recording_settings.dataset_path = os.path.join(home_dir, "DIAMBRA/episode_recording", game_id if use_controller else "mock") + recording_settings.username = "alexpalms" env = diambra.arena.make(game_id, settings, episode_recording_settings=recording_settings, render_mode="human") diff --git a/examples/multi_player_env.py b/examples/multi_player_env.py index faac138c..d8c23e0e 100644 --- a/examples/multi_player_env.py +++ b/examples/multi_player_env.py @@ -1,27 +1,24 @@ #!/usr/bin/env python3 import diambra.arena -from diambra.arena import SpaceTypes +from diambra.arena import SpaceTypes, EnvironmentSettings2P def main(): # Environment Settings - settings = {} + settings = EnvironmentSettings2P() # Multi Agents environment # --- Environment settings --- - # 2 Players game - settings["n_players"] = 2 - # If to use discrete or multi_discrete action space - settings["action_space"] = (SpaceTypes.DISCRETE, SpaceTypes.DISCRETE) + settings.action_space = (SpaceTypes.DISCRETE, SpaceTypes.DISCRETE) # --- Episode settings --- # Characters to be used, automatically extended with None for games # requiring to select more than one character (e.g. Tekken Tag Tournament) - settings["characters"] = ("Ryu", "Ken") + settings.characters = ("Ryu", "Ken") # Characters outfit - settings["outfits"] = (2, 2) + settings.outfits = (2, 2) env = diambra.arena.make("sfiii3n", settings, render_mode="human") diff --git a/examples/single_player_env.py b/examples/single_player_env.py index 0e288faf..f027111c 100644 --- a/examples/single_player_env.py +++ b/examples/single_player_env.py @@ -1,52 +1,49 @@ #!/usr/bin/env python3 import diambra.arena -from diambra.arena import SpaceTypes, Roles +from diambra.arena import SpaceTypes, Roles, EnvironmentSettings1P def main(): # Settings - settings = {} + settings = EnvironmentSettings1P() # Single agent environment # --- Environment settings --- - # Number players to use - settings["n_players"] = 1 # Single player env, "Standard RL" - # Number of steps performed by the game # for every environment step, bounds: [1, 6] - settings["step_ratio"] = 6 + settings.step_ratio = 6 # Native frame resize operation - settings["frame_shape"] = (128, 128, 0) # RBG with 128x128 size - # settings["frame_shape"] = (0, 0, 1) # Grayscale with original size - # settings["frame_shape"] = (0, 0, 0) # Deactivated (Original size RBG) + settings.frame_shape = (128, 128, 0) # RBG with 128x128 size + # settings.frame_shape = (0, 0, 1) # Grayscale with original size + # settings.frame_shape = (0, 0, 0) # Deactivated (Original size RBG) # If to use discrete or multi_discrete action space - settings["action_space"] = SpaceTypes.MULTI_DISCRETE + settings.action_space = SpaceTypes.MULTI_DISCRETE # --- Episode settings --- # Player role selection: P1 (left), P2 (right), None (50% P1, 50% P2) - settings["role"] = Roles.P1 + settings.role = Roles.P1 # Game continue logic (0.0 by default): # - [0.0, 1.0]: probability of continuing game at game over # - int((-inf, -1.0]): number of continues at game over # before episode to be considered done - settings["continue_game"] = 0.0 + settings.continue_game = 0.0 # If to show game final when game is completed - settings["show_final"] = False + settings.show_final = False # Game-specific options (see documentation for details) # Game difficulty level - settings["difficulty"] = 4 + settings.difficulty = 4 # Character to be used, automatically extended with None for games # requiring to select more than one character (e.g. Tekken Tag Tournament) - settings["characters"] = "Kasumi" + settings.characters = "Kasumi" # Character outfit - settings["outfits"] = 2 + settings.outfits = 2 env = diambra.arena.make("doapp", settings, render_mode="human") diff --git a/examples/wrappers_options.py b/examples/wrappers_options.py index ea8f46fb..037c6eeb 100644 --- a/examples/wrappers_options.py +++ b/examples/wrappers_options.py @@ -1,21 +1,22 @@ import diambra.arena -from diambra.arena import SpaceTypes +from diambra.arena import SpaceTypes, EnvironmentSettings1P, WrappersSettings def main(): # Environment settings - settings = {"n_players": 1, "action_space": SpaceTypes.MULTI_DISCRETE} + settings = EnvironmentSettings1P() + settings.action_space = SpaceTypes.MULTI_DISCRETE # Gym wrappers settings - wrappers_settings = {} + wrappers_settings = WrappersSettings() ### Generic wrappers # Number of no-Op actions to be executed # at the beginning of the episode (0 by default) - wrappers_settings["no_op_max"] = 0 + wrappers_settings.no_op_max = 0 # Number of steps for which the same action should be sent (1 by default) - wrappers_settings["sticky_actions"] = 1 + wrappers_settings.sticky_actions = 1 ### Reward wrappers @@ -24,16 +25,16 @@ def main(): # The normalization is performed by dividing the reward value # by the product of the factor times the value of the full health bar # reward = reward / (C * fullHealthBarValue) - wrappers_settings["reward_normalization"] = True - wrappers_settings["reward_normalization_factor"] = 0.5 + wrappers_settings.reward_normalization = True + wrappers_settings.reward_normalization_factor = 0.5 # If to clip rewards (False by default) - wrappers_settings["clip_rewards"] = False + wrappers_settings.clip_rewards = False ### Action space wrapper(s) # Limit the action space to single attack buttons, removing attack buttons combinations (False by default) - wrappers_settings["no_attack_buttons_combinations"] = False + wrappers_settings.no_attack_buttons_combinations = False ### Observation space wrapper(s) @@ -41,27 +42,27 @@ def main(): # WARNING: for speedup, avoid frame warping wrappers, # use environment's native frame wrapping through # "frame_shape" setting (see documentation for details). - wrappers_settings["frame_shape"] = (128, 128, 1) + wrappers_settings.frame_shape = (128, 128, 1) # Number of frames to be stacked together (1 by default) - wrappers_settings["frame_stack"] = 4 + wrappers_settings.frame_stack = 4 # Frames interval when stacking (1 by default) - wrappers_settings["dilation"] = 1 + wrappers_settings.dilation = 1 # Add last action to observation (False by default) - wrappers_settings["add_last_action_to_observation"] = True + wrappers_settings.add_last_action_to_observation = True # How many past actions to stack together (1 by default) # NOTE: needs "add_last_action_to_observation" wrapper to be active - wrappers_settings["actions_stack"] = 6 + wrappers_settings.actions_stack = 6 # If to scale observation numerical values (False by default) # optionally exclude images from normalization (False by default) # and optionally perform one-hot encoding also on discrete binary variables (False by default) - wrappers_settings["scale"] = True - wrappers_settings["exclude_image_scaling"] = True - wrappers_settings["process_discrete_binary"] = False + wrappers_settings.scale = True + wrappers_settings.exclude_image_scaling = True + wrappers_settings.process_discrete_binary = False # If to make the observation relative to the agent as a function to its role (P1 or P2) (deactivate by default) # i.e.: @@ -70,13 +71,13 @@ def main(): # are grouped under "agent_0" and "agent_1", and: # - Under "agent_0", "P1" nesting level becomes "own" and "P2" becomes "opp" # - Under "agent_1", "P1" nesting level becomes "opp" and "P2" becomes "own" - wrappers_settings["role_relative_observation"] = True + wrappers_settings.role_relative_observation = True # Flattening observation dictionary and filtering # a sub-set of the RAM states - wrappers_settings["flatten"] = True - wrappers_settings["filter_keys"] = ["stage", "timer", "action", "own_side", "opp_side", - "own_health", "opp_health", "opp_character"] + wrappers_settings.flatten = True + wrappers_settings.filter_keys = ["stage", "timer", "action", "own_side", "opp_side", + "own_health", "opp_health", "opp_character"] env = diambra.arena.make("doapp", settings, wrappers_settings, render_mode="human") diff --git a/tests/env_exec_interface.py b/tests/env_exec_interface.py index 0156315a..69bfcbd3 100755 --- a/tests/env_exec_interface.py +++ b/tests/env_exec_interface.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 import diambra.arena -from diambra.arena import SpaceTypes, Roles +from diambra.arena import SpaceTypes from diambra.arena.utils.gym_utils import env_spaces_summary, discrete_to_multi_discrete_action import random import numpy as np @@ -21,10 +21,10 @@ def env_exec(settings, options_list, wrappers_settings, episode_recording_settin wait_key = 0 n_rounds = 2 - if settings["game_id"] == "kof98umh": + if settings.game_id == "kof98umh": n_rounds = 3 - env = diambra.arena.make(settings["game_id"], settings, wrappers_settings, episode_recording_settings) + env = diambra.arena.make(settings.game_id, settings, wrappers_settings, episode_recording_settings) # Print environment obs and action spaces summary if args["log_output"] is True: @@ -46,11 +46,11 @@ def env_exec(settings, options_list, wrappers_settings, episode_recording_settin while curr_num_ep < max_num_ep: actions = env.action_space.sample() - if env.env_settings.n_players == 1: + if settings.n_players == 1: if no_action is True: actions = env.get_no_op_action() - if env.env_settings.action_space == SpaceTypes.DISCRETE: + if settings.action_space == SpaceTypes.DISCRETE: move_action, att_action = discrete_to_multi_discrete_action(actions, env.n_actions[0]) else: move_action, att_action = actions[0], actions[1] @@ -62,8 +62,8 @@ def env_exec(settings, options_list, wrappers_settings, episode_recording_settin if no_action is True: actions["agent_0"] = env.get_no_op_action()["agent_0"] - for idx in range(env.env_settings.n_players): - if env.env_settings.action_space[idx] == SpaceTypes.DISCRETE: + for idx in range(settings.n_players): + if settings.action_space[idx] == SpaceTypes.DISCRETE: move_action, att_action = discrete_to_multi_discrete_action(actions["agent_{}".format(idx)], env.n_actions[0]) else: move_action, att_action = actions["agent_{}".format(idx)][0], actions["agent_{}".format(idx)][1] @@ -99,11 +99,11 @@ def env_exec(settings, options_list, wrappers_settings, episode_recording_settin if info["round_done"]: # Side check when no wrappers active: - if len(wrappers_settings) == 0: + if (wrappers_settings.role_relative_observation is False and wrappers_settings.flatten is False): if (observation["P1"]["side"] != 0.0 or observation["P2"]["side"] != 1.0): raise RuntimeError("Wrong starting sides:", observation["P1"]["side"], observation["P2"]["side"]) - elif ("frame_shape" in wrappers_settings.keys() and wrappers_settings["frame_shape"][2] == 1): + elif (wrappers_settings.frame_shape is not None and wrappers_settings.frame_shape[2] == 1): # Frames equality check frame = observation["frame"] diff --git a/tests/test_gym_settings.py b/tests/test_gym_settings.py index dc173fba..8e6d6a57 100644 --- a/tests/test_gym_settings.py +++ b/tests/test_gym_settings.py @@ -2,7 +2,7 @@ import pytest import random import diambra.arena -from diambra.arena import SpaceTypes, Roles +from diambra.arena import SpaceTypes, Roles, EnvironmentSettings1P, EnvironmentSettings2P from diambra.arena.utils.gym_utils import available_games from pytest_utils import generate_pytest_decorator_input from diambra.arena.utils.engine_mock import load_mocker @@ -14,10 +14,10 @@ # -s (show output) # -k "expression" (filter tests using case-insensitive with parts of the test name and/or parameters values combined with boolean operators, e.g. "wrappers and doapp") -def func(settings, wrappers_settings, traj_rec_settings, mocker): +def func(settings, mocker): load_mocker(mocker) try: - env = diambra.arena.make(settings["game_id"], settings, wrappers_settings, traj_rec_settings) + env = diambra.arena.make(settings.game_id, settings) env.reset() env.close() print("COMPLETED SUCCESSFULLY!") @@ -76,30 +76,29 @@ def test_gym_settings(game_id, n_players, frame_shape, step_ratio, action_space, outfits = random.choice(outfits_range) # Env settings - settings = {} - settings["game_id"] = game_id - settings["frame_shape"] = frame_shape - settings["step_ratio"] = step_ratio - settings["n_players"] = n_players - settings["action_space"] = (action_space, action_space) + if (n_players == 1): + settings = EnvironmentSettings1P() + else: + settings = EnvironmentSettings2P() + settings.game_id = game_id + settings.frame_shape = frame_shape + settings.step_ratio = step_ratio + settings.action_space = (action_space, action_space) - settings["difficulty"] = difficulty - settings["continue_game"] = continue_game - settings["tower"] = tower + settings.difficulty = difficulty + settings.continue_game = continue_game + settings.tower = tower - settings["role"] = (role[0], role[1]) - settings["characters"] = (characters[0], characters[1]) - settings["outfits"] = (outfits, outfits) - settings["super_art"] = (super_art, super_art) - settings["fighting_style"] = (fighting_style, fighting_style) - settings["ultimate_style"] = (ultimate_style, ultimate_style) + settings.role = (role[0], role[1]) + settings.characters = (characters[0], characters[1]) + settings.outfits = (outfits, outfits) + settings.super_art = (super_art, super_art) + settings.fighting_style = (fighting_style, fighting_style) + settings.ultimate_style = (ultimate_style, ultimate_style) - if settings["n_players"] != 2: + if n_players != 2: for key in ["action_space", "role", "characters" , "outfits", "super_art", "fighting_style", "ultimate_style"]: - settings[key] = settings[key][0] + setattr(settings, key, getattr(settings, key)[0]) - wrappers_settings = {} - traj_rec_settings = {} - - assert func(settings, wrappers_settings, traj_rec_settings, mocker) == expected + assert func(settings, mocker) == expected diff --git a/tests/test_random.py b/tests/test_random.py index f8dbbc42..5e7b6550 100755 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -3,8 +3,7 @@ from env_exec_interface import env_exec import random from diambra.arena.utils.engine_mock import load_mocker -import diambra.arena -from diambra.arena import SpaceTypes, Roles +from diambra.arena import SpaceTypes, Roles, EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings, RecordingSettings # Example Usage: # pytest @@ -29,13 +28,15 @@ def func(game_id, n_players, action_space, frame_shape, wrappers_settings, try: # Settings - settings = {} - settings["game_id"] = game_id - settings["frame_shape"] = frame_shape - settings["n_players"] = n_players - settings["action_space"] = (action_space, action_space) - if settings["n_players"] == 1: - settings["action_space"] = settings["action_space"][0] + if (n_players == 1): + settings = EnvironmentSettings1P() + else: + settings = EnvironmentSettings2P() + settings.game_id = game_id + settings.frame_shape = frame_shape + settings.action_space = (action_space, action_space) + if settings.n_players == 1: + settings.action_space = settings.action_space[0] # Options (settings to change at reset) options_list = [] @@ -43,14 +44,14 @@ def func(game_id, n_players, action_space, frame_shape, wrappers_settings, continue_games = [-1.0, 0.0, 0.3] for role in roles: role_value = (role[0], role[1]) - if settings["n_players"] == 1: + if settings.n_players == 1: role_value = role_value[0] for continue_val in continue_games: options_list.append({"role": role_value, "continue_game": continue_val}) else: options_list.append({"role": role_value}) - return env_exec(settings, options_list, wrappers_settings, {}, args) + return env_exec(settings, options_list, wrappers_settings, RecordingSettings(), args) except Exception as e: print(e) return 1 @@ -66,8 +67,7 @@ def func(game_id, n_players, action_space, frame_shape, wrappers_settings, def test_random_gym_mock(game_id, n_players, action_space, mocker): frame_shape = random.choice([(128, 128, 1), (256, 256, 0)]) use_mock_env = True - wrappers_settings = {} - assert func(game_id, n_players, action_space, frame_shape, wrappers_settings, + assert func(game_id, n_players, action_space, frame_shape, WrappersSettings(), no_action_probability, use_mock_env, mocker) == 0 @pytest.mark.parametrize("game_id", game_ids) @@ -78,24 +78,24 @@ def test_random_wrappers_mock(game_id, n_players, action_space, mocker): use_mock_env = True # Env wrappers settings - wrappers_settings = {} - wrappers_settings["no_op_max"] = 0 - wrappers_settings["sticky_actions"] = 1 - wrappers_settings["frame_shape"] = random.choice([(128, 128, 1), (256, 256, 0)]) - wrappers_settings["reward_normalization"] = True - wrappers_settings["clip_rewards"] = False - wrappers_settings["frame_stack"] = 4 - wrappers_settings["dilation"] = 1 - wrappers_settings["add_last_action_to_observation"] = True - wrappers_settings["actions_stack"] = 12 - wrappers_settings["scale"] = True - wrappers_settings["role_relative_observation"] = True - wrappers_settings["flatten"] = True + wrappers_settings = WrappersSettings() + wrappers_settings.no_op_max = 0 + wrappers_settings.sticky_actions = 1 + wrappers_settings.frame_shape = random.choice([(128, 128, 1), (256, 256, 0)]) + wrappers_settings.reward_normalization = True + wrappers_settings.clip_rewards = False + wrappers_settings.frame_stack = 4 + wrappers_settings.dilation = 1 + wrappers_settings.add_last_action_to_observation = True + wrappers_settings.actions_stack = 12 + wrappers_settings.scale = True + wrappers_settings.role_relative_observation = True + wrappers_settings.flatten = True suffix = "" if n_players == 2: suffix = "agent_0_" - wrappers_settings["filter_keys"] = ["stage", "timer", suffix + "own_side", suffix + "opp_side", - suffix + "opp_character", suffix + "action"] + wrappers_settings.filter_keys = ["stage", "timer", suffix + "own_side", suffix + "opp_side", + suffix + "opp_character", suffix + "action"] assert func(game_id, n_players, action_space, frame_shape, wrappers_settings, no_action_probability, use_mock_env, mocker) == 0 @@ -108,24 +108,24 @@ def test_random_integration(game_id, n_players, action_space, mocker): use_mock_env = False # Env wrappers settings - wrappers_settings = {} - wrappers_settings["no_op_max"] = 0 - wrappers_settings["sticky_actions"] = 1 - wrappers_settings["frame_shape"] = (128, 128, 1) - wrappers_settings["reward_normalization"] = True - wrappers_settings["clip_rewards"] = False - wrappers_settings["frame_stack"] = 4 - wrappers_settings["dilation"] = 1 - wrappers_settings["add_last_action_to_observation"] = True - wrappers_settings["actions_stack"] = 12 - wrappers_settings["scale"] = True - wrappers_settings["role_relative_observation"] = True - wrappers_settings["flatten"] = True + wrappers_settings = WrappersSettings() + wrappers_settings.no_op_max = 0 + wrappers_settings.sticky_actions = 1 + wrappers_settings.frame_shape = (128, 128, 1) + wrappers_settings.reward_normalization = True + wrappers_settings.clip_rewards = False + wrappers_settings.frame_stack = 4 + wrappers_settings.dilation = 1 + wrappers_settings.add_last_action_to_observation = True + wrappers_settings.actions_stack = 12 + wrappers_settings.scale = True + wrappers_settings.role_relative_observation = True + wrappers_settings.flatten = True suffix = "" if n_players == 2: suffix = "agent_0_" - wrappers_settings["filter_keys"] = ["stage", "timer", suffix + "own_side", suffix + "opp_side", - suffix + "opp_character", suffix + "action"] + wrappers_settings.filter_keys = ["stage", "timer", suffix + "own_side", suffix + "opp_side", + suffix + "opp_character", suffix + "action"] assert func(game_id, n_players, action_space, frame_shape, wrappers_settings, no_action_probability, use_mock_env, mocker) == 0 \ No newline at end of file diff --git a/tests/test_recording_settings.py b/tests/test_recording_settings.py index 83171942..dadc39c1 100644 --- a/tests/test_recording_settings.py +++ b/tests/test_recording_settings.py @@ -3,7 +3,7 @@ from os.path import expanduser import os import diambra.arena -from diambra.arena import SpaceTypes +from diambra.arena import SpaceTypes, EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings, RecordingSettings from pytest_utils import generate_pytest_decorator_input from diambra.arena.utils.engine_mock import load_mocker from diambra.arena.utils.gym_utils import available_games @@ -18,7 +18,7 @@ def func(settings, wrappers_settings, episode_recording_settings, mocker): load_mocker(mocker) try: - env = diambra.arena.make(settings["game_id"], settings, wrappers_settings, episode_recording_settings) + env = diambra.arena.make(settings.game_id, settings, wrappers_settings, episode_recording_settings) env.close() print("COMPLETED SUCCESSFULLY!") @@ -52,27 +52,28 @@ def pytest_generate_tests(metafunc): @pytest.mark.parametrize("n_players", [1, 2]) @pytest.mark.parametrize("action_space", [SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE]) def test_settings_recording(game_id ,username, dataset_path, n_players, action_space, expected, mocker): - # Env settings - settings = {} - settings["game_id"] = game_id - settings["n_players"] = n_players - settings["action_space"] = action_space + if (n_players == 1): + settings = EnvironmentSettings1P() + else: + settings = EnvironmentSettings2P() + settings.game_id = game_id + settings.action_space = action_space if n_players == 2: - settings["action_space"] = (action_space, action_space) + settings.action_space = (action_space, action_space) # Env wrappers settings - wrappers_settings = {} - wrappers_settings["frame_shape"] = (128, 128, 1) - wrappers_settings["reward_normalization"] = True - wrappers_settings["frame_stack"] = 4 - wrappers_settings["add_last_action_to_observation"] = True - wrappers_settings["actions_stack"] = 12 - wrappers_settings["scale"] = True + wrappers_settings = WrappersSettings() + wrappers_settings.frame_shape = (128, 128, 1) + wrappers_settings.reward_normalization = True + wrappers_settings.frame_stack = 4 + wrappers_settings.add_last_action_to_observation = True + wrappers_settings.actions_stack = 12 + wrappers_settings.scale = True # Recording settings - episode_recording_settings = {} - episode_recording_settings["username"] = username - episode_recording_settings["dataset_path"] = dataset_path + episode_recording_settings = RecordingSettings() + episode_recording_settings.username = username + episode_recording_settings.dataset_path = dataset_path assert func(settings, wrappers_settings, episode_recording_settings, mocker) == expected diff --git a/tests/test_speed.py b/tests/test_speed.py index 41fb0a92..5494e3f9 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -2,6 +2,7 @@ import pytest import time import diambra.arena +from diambra.arena import EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings from diambra.arena.utils.engine_mock import load_mocker import numpy as np import warnings @@ -16,8 +17,10 @@ def func(n_players, wrappers_settings, target_speed, mocker): load_mocker(mocker) try: # Settings - settings = {} - settings["n_players"] = n_players + if (n_players == 1): + settings = EnvironmentSettings1P() + else: + settings = EnvironmentSettings2P() env = diambra.arena.make("doapp", settings, wrappers_settings) observation, info = env.reset() @@ -62,30 +65,29 @@ def func(n_players, wrappers_settings, target_speed, mocker): @pytest.mark.parametrize("n_players", n_players) def test_speed_gym(n_players, mocker): - wrappers_settings = {} - assert func(n_players, wrappers_settings, target_speeds[0], mocker) == 0 + assert func(n_players, WrappersSettings(), target_speeds[0], mocker) == 0 @pytest.mark.parametrize("n_players", n_players) def test_speed_wrappers(n_players, mocker): # Env wrappers settings - wrappers_settings = {} - wrappers_settings["no_op_max"] = 0 - wrappers_settings["sticky_actions"] = 1 - wrappers_settings["reward_normalization"] = True - wrappers_settings["clip_rewards"] = False - wrappers_settings["frame_stack"] = 4 - wrappers_settings["dilation"] = 1 - wrappers_settings["add_last_action_to_observation"] = True - wrappers_settings["actions_stack"] = 12 - wrappers_settings["scale"] = True - wrappers_settings["role_relative_observation"] = True - wrappers_settings["flatten"] = True + wrappers_settings = WrappersSettings() + wrappers_settings.no_op_max = 0 + wrappers_settings.sticky_actions = 1 + wrappers_settings.reward_normalization = True + wrappers_settings.clip_rewards = False + wrappers_settings.frame_stack = 4 + wrappers_settings.dilation = 1 + wrappers_settings.add_last_action_to_observation = True + wrappers_settings.actions_stack = 12 + wrappers_settings.scale = True + wrappers_settings.role_relative_observation = True + wrappers_settings.flatten = True suffix = "" if n_players == 2: suffix = "agent_0_" - wrappers_settings["filter_keys"] = ["stage", "timer", suffix + "own_side", suffix + "opp_side", - suffix + "opp_character", suffix + "action"] + wrappers_settings.filter_keys = ["stage", "timer", suffix + "own_side", suffix + "opp_side", + suffix + "opp_character", suffix + "action"] assert func(n_players, wrappers_settings, target_speeds[1], mocker) == 0 diff --git a/tests/test_wrappers_settings.py b/tests/test_wrappers_settings.py index 9a19ad3f..93f73b08 100644 --- a/tests/test_wrappers_settings.py +++ b/tests/test_wrappers_settings.py @@ -2,7 +2,7 @@ import pytest import diambra.arena from pytest_utils import generate_pytest_decorator_input -from diambra.arena import SpaceTypes +from diambra.arena import SpaceTypes, EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings from diambra.arena.utils.engine_mock import load_mocker from diambra.arena.utils.gym_utils import available_games @@ -13,10 +13,10 @@ # -s (show output) # -k "expression" (filter tests using case-insensitive with parts of the test name and/or parameters values combined with boolean operators, e.g. "wrappers and doapp") -def func(settings, wrappers_settings, episode_recording_settings, mocker): +def func(settings, wrappers_settings, mocker): load_mocker(mocker) try: - env = diambra.arena.make(settings["game_id"], settings, wrappers_settings, episode_recording_settings) + env = diambra.arena.make(settings.game_id, settings, wrappers_settings) env.close() print("COMPLETED SUCCESSFULLY!") @@ -87,34 +87,33 @@ def test_wrappers_settings(game_id, step_ratio, n_players, action_space, no_op_m flatten, filter_keys, wrappers, expected, mocker): # Env settings - settings = {} - settings["game_id"] = game_id - settings["step_ratio"] = step_ratio - settings["n_players"] = n_players - settings["action_space"] = action_space + if (n_players == 1): + settings = EnvironmentSettings1P() + else: + settings = EnvironmentSettings2P() + settings.game_id = game_id + settings.step_ratio = step_ratio + settings.action_space = action_space if n_players == 2: - settings["action_space"] = (action_space, action_space) + settings.action_space = (action_space, action_space) # Env wrappers settings - wrappers_settings = {} - wrappers_settings["no_op_max"] = no_op_max - wrappers_settings["sticky_actions"] = sticky_actions - wrappers_settings["reward_normalization"] = reward_normalization - wrappers_settings["reward_normalization_factor"] = reward_normalization_factor - wrappers_settings["clip_rewards"] = clip_rewards - wrappers_settings["no_attack_buttons_combinations"] = no_attack_buttons_combinations - wrappers_settings["frame_shape"] = frame_shape - wrappers_settings["frame_stack"] = frame_stack - wrappers_settings["dilation"] = dilation - wrappers_settings["add_last_action_to_observation"] = add_last_action_to_observation - wrappers_settings["actions_stack"] = 1 if add_last_action_to_observation is False and expected == 0 else actions_stack - wrappers_settings["scale"] = scale - wrappers_settings["role_relative_observation"] = role_relative_observation - wrappers_settings["flatten"] = flatten - wrappers_settings["filter_keys"] = filter_keys - wrappers_settings["wrappers"] = wrappers + wrappers_settings = WrappersSettings() + wrappers_settings.no_op_max = no_op_max + wrappers_settings.sticky_actions = sticky_actions + wrappers_settings.reward_normalization = reward_normalization + wrappers_settings.reward_normalization_factor = reward_normalization_factor + wrappers_settings.clip_rewards = clip_rewards + wrappers_settings.no_attack_buttons_combinations = no_attack_buttons_combinations + wrappers_settings.frame_shape = frame_shape + wrappers_settings.frame_stack = frame_stack + wrappers_settings.dilation = dilation + wrappers_settings.add_last_action_to_observation = add_last_action_to_observation + wrappers_settings.actions_stack = 1 if add_last_action_to_observation is False and expected == 0 else actions_stack + wrappers_settings.scale = scale + wrappers_settings.role_relative_observation = role_relative_observation + wrappers_settings.flatten = flatten + wrappers_settings.filter_keys = filter_keys + wrappers_settings.wrappers = wrappers - # Recording settings - episode_recording_settings = {} - - assert func(settings, wrappers_settings, episode_recording_settings, mocker) == expected + assert func(settings, wrappers_settings, mocker) == expected