Skip to content

Commit

Permalink
Rework settings exposing classes directly instead of dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpalms committed Sep 22, 2023
1 parent 4a3b27a commit 32120db
Show file tree
Hide file tree
Showing 15 changed files with 263 additions and 249 deletions.
1 change: 1 addition & 0 deletions diambra/arena/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from diambra.engine import SpaceTypes, Roles
from diambra.engine import model
from .env_settings import EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings, RecordingSettings, load_settings_flat_dict
from .make_env import make
from .utils.gym_utils import available_games, game_sha_256, check_game_sha_256, get_num_envs
73 changes: 50 additions & 23 deletions diambra/arena/env_settings.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Union, List, Tuple, Any, Dict
from diambra.arena.utils.gym_utils import available_games
from diambra.arena import SpaceTypes, Roles
import numpy as np
import random
from diambra.engine import model
import time
from dacite import from_dict, Config

def load_settings_flat_dict(target_class, settings: dict):
    """
    Build a settings object of the given class from a flat dictionary.

    :param target_class: dataclass type to instantiate
    :param settings: (dict) flat mapping of field names to values
    :return: instance of target_class created by dacite's from_dict
    """
    # Strict mode makes dacite reject dictionary keys that do not match
    # a field of target_class instead of silently ignoring them
    strict_config = Config(strict=True)
    return from_dict(target_class, settings, config=strict_config)

MAX_VAL = float("inf")
MIN_VAL = float("-inf")
Expand All @@ -22,6 +26,13 @@ def check_val_in_list(key, value, valid_list):
assert (value in valid_list), error_message
assert (type(value)==type(valid_list[valid_list.index(value)])), error_message

def check_type(key, value, expected_type, admit_none=True):
    """
    Validate that a setting value is an instance of the expected type.

    The error is raised explicitly rather than through an assert statement,
    so the validation is still enforced when Python runs with -O.

    :param key: (str) name of the setting, used only in the error message
    :param value: value to validate
    :param expected_type: type (or tuple of types) forwarded to isinstance
    :param admit_none: (bool) whether None is accepted instead of a typed value
    :raises AssertionError: if value is None and admit_none is false, or if
        value is not None and is not an instance of expected_type
    """
    if value is None:
        # None skips the type test entirely and is governed only by admit_none
        if not admit_none:
            raise AssertionError("ERROR: \"{}\" ({}) cannot be NoneType".format(key, value))
        return

    if not isinstance(value, expected_type):
        raise AssertionError(
            "ERROR: \"{}\" ({}) is of type {}, admissible types are {}".format(
                key, value, type(value), expected_type))

def check_space_type(key, value, valid_list):
error_message = "ERROR: \"{}\" ({}) admissible values are {}".format(key, SpaceTypes.Name(value), [SpaceTypes.Name(elem) for elem in valid_list])
assert (value in valid_list), error_message
Expand All @@ -37,15 +48,14 @@ class EnvironmentSettings:
games_dict = None

# Environment settings
game_id: str
game_id: str = "doapp"
frame_shape: Tuple[int, int, int] = (0, 0, 0)
step_ratio: int = 6
n_players: int = 1
disable_keyboard: bool = True
disable_joystick: bool = True
render_mode: Union[None, str] = None
rank: int = 0
env_address: str = "localhost:50051"
env_address: str = None
grpc_timeout: int = 600

# Episode settings
Expand All @@ -55,7 +65,7 @@ class EnvironmentSettings:
show_final: bool = False
tower: Union[None, int] = 3 # UMK3 Specific

# Bookeeping variables
# Bookkeeping variables
_last_seed: Union[None, int] = None
pb_model: model = None

Expand Down Expand Up @@ -161,28 +171,30 @@ def _sample_characters(self, n_characters=3):

def _sanity_check(self):
if self.env_info is None or self.games_dict is None:
raise("EnvironmentSettings class not correctly initialized")
raise Exception("EnvironmentSettings class not correctly initialized")

# TODO: consider using typing.Literal to specify lists of admissible values: NOTE: It requires Python 3.8+
check_val_in_list("game_id", self.game_id, list(self.games_dict.keys()))
if self.render_mode is not None:
check_val_in_list("render_mode", self.render_mode, ["human", "rgb_array"])
check_num_in_range("step_ratio", self.step_ratio, [1, 6])
check_num_in_range("frame_shape[0]", self.frame_shape[0], [0, MAX_FRAME_RES])
check_num_in_range("frame_shape[1]", self.frame_shape[1], [0, MAX_FRAME_RES])
if (min(self.frame_shape[0], self.frame_shape[1]) == 0 and
max(self.frame_shape[0], self.frame_shape[1]) != 0):
raise Exception("\"frame_shape[0] and frame_shape[1]\" must be either both different from or equal to 0")
check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1])
check_num_in_range("step_ratio", self.step_ratio, [1, 6])
check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600])
if self.render_mode is not None:
check_val_in_list("render_mode", self.render_mode, ["human", "rgb_array"])
check_num_in_range("rank", self.rank, [0, MAX_VAL])
check_type("env_address", self.env_address, str)
check_num_in_range("grpc_timeout", self.grpc_timeout, [0, 3600])

if self.seed is not None:
check_num_in_range("seed", self.seed, [-1, MAX_VAL])
difficulty_admissible_values = list(range(self.env_info.difficulty_bounds.min, self.env_info.difficulty_bounds.max + 1))
difficulty_admissible_values.append(None)
check_val_in_list("difficulty", self.difficulty, difficulty_admissible_values)
check_num_in_range("continue_game", self.continue_game, [MIN_VAL, 1.0])
check_type("show_final", self.show_final, bool)
check_val_in_list("tower", self.tower, [None, 1, 2, 3, 4])

def _process_random_values(self):
Expand All @@ -194,8 +206,8 @@ def _process_random_values(self):
@dataclass
class EnvironmentSettings1P(EnvironmentSettings):
"""Single Agent Environment Settings Class"""

# Env settings
n_players: int = 1
action_space: int = SpaceTypes.MULTI_DISCRETE

# Episode settings
Expand Down Expand Up @@ -270,8 +282,9 @@ def _get_player_specific_values(self):

@dataclass
class EnvironmentSettings2P(EnvironmentSettings):
"""Single Agent Environment Settings Class"""
"""Multi Agent Environment Settings Class"""
# Env Settings
n_players: int = 2
action_space: Tuple[int, int] = (SpaceTypes.MULTI_DISCRETE, SpaceTypes.MULTI_DISCRETE)

# Episode Settings
Expand Down Expand Up @@ -382,28 +395,42 @@ class WrappersSettings:
process_discrete_binary: bool = False
role_relative_observation: bool = False
flatten: bool = False
filter_keys: List[str] = None
wrappers: List[List[Any]] = None
filter_keys: List[str] = field(default_factory=list)
wrappers: List[List[Any]] = field(default_factory=list)

def sanity_check(self):
check_num_in_range("no_op_max", self.no_op_max, [0, 12])
check_num_in_range("sticky_actions", self.sticky_actions, [1, 12])
check_num_in_range("frame_stack", self.frame_stack, [1, MAX_STACK_VALUE])
check_num_in_range("dilation", self.dilation, [1, MAX_STACK_VALUE])
actions_stack_bounds = [1, 1]
if self.add_last_action_to_observation is True:
actions_stack_bounds = [1, MAX_STACK_VALUE]
check_num_in_range("actions_stack", self.actions_stack, actions_stack_bounds)
check_type("reward_normalization", self.reward_normalization, bool, admit_none=False)
check_num_in_range("reward_normalization_factor", self.reward_normalization_factor, [0.0, 1000000])

check_type("clip_rewards", self.clip_rewards, bool, admit_none=False)
check_type("no_attack_buttons_combinations", self.no_attack_buttons_combinations, bool, admit_none=False)
check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1, 3])
check_num_in_range("frame_shape[0]", self.frame_shape[0], [0, MAX_FRAME_RES])
check_num_in_range("frame_shape[1]", self.frame_shape[1], [0, MAX_FRAME_RES])
if (min(self.frame_shape[0], self.frame_shape[1]) == 0 and
max(self.frame_shape[0], self.frame_shape[1]) != 0):
raise Exception("\"frame_shape[0] and frame_shape[1]\" must be both different from 0")
check_num_in_range("frame_stack", self.frame_stack, [1, MAX_STACK_VALUE])
check_num_in_range("dilation", self.dilation, [1, MAX_STACK_VALUE])
check_type("add_last_action_to_observation", self.add_last_action_to_observation, bool, admit_none=False)
actions_stack_bounds = [1, 1]
if self.add_last_action_to_observation is True:
actions_stack_bounds = [1, MAX_STACK_VALUE]
check_num_in_range("actions_stack", self.actions_stack, actions_stack_bounds)
check_type("scale", self.scale, bool, admit_none=False)
check_type("exclude_image_scaling", self.exclude_image_scaling, bool, admit_none=False)
check_type("process_discrete_binary", self.process_discrete_binary, bool, admit_none=False)
check_type("role_relative_observation", self.role_relative_observation, bool, admit_none=False)
check_type("flatten", self.flatten, bool, admit_none=False)
check_type("filter_keys", self.filter_keys, list, admit_none=False)
check_type("wrappers", self.wrappers, list, admit_none=False)

@dataclass
class RecordingSettings:
dataset_path: str = "./"
username: str = "username"
dataset_path: Union[None, str] = None
username: Union[None, str] = None

def sanity_check(self):
check_type("dataset_path", self.dataset_path, str)
check_type("username", self.username, str)
33 changes: 13 additions & 20 deletions diambra/arena/make_env.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
import logging
from dacite import from_dict, Config
from diambra.arena.arena_gym import DiambraGym1P, DiambraGym2P
from diambra.arena.wrappers.arena_wrappers import env_wrapping
from diambra.arena.env_settings import EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings, RecordingSettings
from diambra.arena import EnvironmentSettings1P, EnvironmentSettings2P, WrappersSettings, RecordingSettings
from diambra.arena.wrappers.episode_recording import EpisodeRecorder
from typing import Union

def make(game_id, env_settings:dict={}, wrappers_settings:dict={}, episode_recording_settings:dict={},
render_mode=None, rank=0, log_level=logging.INFO):
def make(game_id, env_settings: Union[EnvironmentSettings1P, EnvironmentSettings2P]=EnvironmentSettings1P(),
wrappers_settings: WrappersSettings=WrappersSettings(), episode_recording_settings: RecordingSettings=RecordingSettings(),
render_mode: str=None, rank: int=0, log_level=logging.INFO):
"""
Create a wrapped environment.
:param seed: (int) the initial seed for RNG
Expand All @@ -19,8 +20,8 @@ def make(game_id, env_settings:dict={}, wrappers_settings:dict={}, episode_recor
logger = logging.getLogger(__name__)

# Include game_id and render_mode in env_settings
env_settings["game_id"] = game_id
env_settings["render_mode"] = render_mode
env_settings.game_id = game_id
env_settings.render_mode = render_mode

# Check if DIAMBRA_ENVS var present
env_addresses = os.getenv("DIAMBRA_ENVS", "").split()
Expand All @@ -31,20 +32,13 @@ def make(game_id, env_settings:dict={}, wrappers_settings:dict={}, episode_recor
"# of env servers: {}".format(len(env_addresses)),
"# rank of client: {} (0-based index)".format(rank))
else: # If not present, set default value
if "env_address" not in env_settings:
if env_settings.env_address is None:
env_addresses = ["localhost:50051"]
else:
env_addresses = [env_settings["env_address"]]
env_addresses = [env_settings.env_address]

env_settings["env_address"] = env_addresses[rank]
env_settings["rank"] = rank

# Checking settings and setting up default ones
if "n_players" in env_settings.keys() and env_settings["n_players"] == 2:
env_settings = from_dict(EnvironmentSettings2P, env_settings, config=Config(strict=True))
else:
env_settings["n_players"] = 1
env_settings = from_dict(EnvironmentSettings1P, env_settings, config=Config(strict=True))
env_settings.env_address = env_addresses[rank]
env_settings.rank = rank

# Make environment
if env_settings.n_players == 1: # 1P Mode
Expand All @@ -53,12 +47,11 @@ def make(game_id, env_settings:dict={}, wrappers_settings:dict={}, episode_recor
env = DiambraGym2P(env_settings)

# Apply episode recorder wrapper
if len(episode_recording_settings) != 0:
episode_recording_settings = from_dict(RecordingSettings, episode_recording_settings, config=Config(strict=True))
if episode_recording_settings.dataset_path is not None:
episode_recording_settings.sanity_check()
env = EpisodeRecorder(env, episode_recording_settings)

# Apply environment wrappers
wrappers_settings = from_dict(WrappersSettings, wrappers_settings, config=Config(strict=True))
wrappers_settings.sanity_check()
env = env_wrapping(env, wrappers_settings)

Expand Down
7 changes: 3 additions & 4 deletions diambra/arena/wrappers/arena_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,8 @@ def env_wrapping(env, wrappers_settings: WrappersSettings):
if wrappers_settings.flatten is True:
env = FlattenFilterDictObs(env, wrappers_settings.filter_keys)

# Apply all additional wrappers in sequence
if wrappers_settings.wrappers is not None:
for wrapper in wrappers_settings.wrappers:
env = wrapper[0](env, **wrapper[1])
# Apply all additional wrappers in sequence:
for wrapper in wrappers_settings.wrappers:
env = wrapper[0](env, **wrapper[1])

return env
10 changes: 5 additions & 5 deletions diambra/arena/wrappers/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,15 +333,15 @@ def __init__(self, env, filter_keys):
gym.ObservationWrapper.__init__(self, env)

self.filter_keys = filter_keys
if (filter_keys is not None):
if len(filter_keys) != 0:
self.filter_keys = list(set(filter_keys))
if "frame" not in filter_keys:
self.filter_keys += ["frame"]

original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, None)).keys()
original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, [])).keys()
self.observation_space = gym.spaces.Dict(flatten_filter_obs_space_func(self.observation_space, self.filter_keys))

if filter_keys is not None:
if len(filter_keys) != 0:
if (sorted(self.observation_space.spaces.keys()) != sorted(self.filter_keys)):
raise Exception("Specified observation key(s) not found:",
" Available key(s):", sorted(original_obs_space_keys),
Expand Down Expand Up @@ -371,7 +371,7 @@ def visit(subdict, flattened_dict, partial_key, check_method):
if check_method(new_key):
flattened_dict[new_key] = v

if filter_keys is not None:
if len(filter_keys) != 0:
visit(input_dictionary, flattened_dict, _FLAG_FIRST, check_filter)
else:
visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check)
Expand All @@ -397,7 +397,7 @@ def visit(subdict, flattened_dict, partial_key, check_method):
if check_method(new_key):
flattened_dict[new_key] = v

if filter_keys is not None:
if len(filter_keys) != 0:
visit(input_dictionary, flattened_dict, _FLAG_FIRST, check_filter)
else:
visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check)
Expand Down
18 changes: 8 additions & 10 deletions examples/episode_recording.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
import os
from os.path import expanduser
import diambra.arena
from diambra.arena import SpaceTypes
from diambra.arena import SpaceTypes, EnvironmentSettings1P, RecordingSettings
from diambra.arena.utils.controller import get_diambra_controller
import argparse

def main(use_controller):
# Environment Settings
settings = {}
settings["n_players"] = 1
settings["role"] = None
settings["step_ratio"] = 1
settings["frame_shape"] = (256, 256, 1)
settings["action_space"] = SpaceTypes.MULTI_DISCRETE
settings = EnvironmentSettings1P()
settings.step_ratio = 1
settings.frame_shape = (256, 256, 1)
settings.action_space = SpaceTypes.MULTI_DISCRETE

# Recording settings
home_dir = expanduser("~")
game_id = "doapp"
recording_settings = {}
recording_settings["dataset_path"] = os.path.join(home_dir, "DIAMBRA/episode_recording", game_id if use_controller else "mock")
recording_settings["username"] = "alexpalms"
recording_settings = RecordingSettings()
recording_settings.dataset_path = os.path.join(home_dir, "DIAMBRA/episode_recording", game_id if use_controller else "mock")
recording_settings.username = "alexpalms"

env = diambra.arena.make(game_id, settings, episode_recording_settings=recording_settings, render_mode="human")

Expand Down
13 changes: 5 additions & 8 deletions examples/multi_player_env.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
#!/usr/bin/env python3
import diambra.arena
from diambra.arena import SpaceTypes
from diambra.arena import SpaceTypes, EnvironmentSettings2P

def main():
# Environment Settings
settings = {}
settings = EnvironmentSettings2P() # Multi Agents environment

# --- Environment settings ---

# 2 Players game
settings["n_players"] = 2

# If to use discrete or multi_discrete action space
settings["action_space"] = (SpaceTypes.DISCRETE, SpaceTypes.DISCRETE)
settings.action_space = (SpaceTypes.DISCRETE, SpaceTypes.DISCRETE)

# --- Episode settings ---

# Characters to be used, automatically extended with None for games
# requiring to select more than one character (e.g. Tekken Tag Tournament)
settings["characters"] = ("Ryu", "Ken")
settings.characters = ("Ryu", "Ken")

# Characters outfit
settings["outfits"] = (2, 2)
settings.outfits = (2, 2)

env = diambra.arena.make("sfiii3n", settings, render_mode="human")

Expand Down
Loading

0 comments on commit 32120db

Please sign in to comment.