Relax requirements, add logging in controller, refactor enumerating
alexpalms committed Sep 13, 2023
1 parent b182ccc commit fba8bed
Showing 4 changed files with 223 additions and 230 deletions.
217 changes: 107 additions & 110 deletions diambra/arena/arena_gym.py
@@ -9,9 +9,8 @@
 from diambra.arena.env_settings import EnvironmentSettings1P, EnvironmentSettings2P
 from typing import Union, Any, Dict, List

-# DIAMBRA Env Gym
 class DiambraGymBase(gym.Env):
-    """Diambra Environment gym interface"""
+    """Diambra Environment gymnasium base interface"""
     metadata = {"render_modes": ["human", "rgb_array"]}
     _frame = None
     _last_action = None
@@ -60,76 +59,6 @@ def __init__(self, env_settings: Union[EnvironmentSettings1P, EnvironmentSetting
                 self.max_delta_health = self.env_info.ram_states[k].max - self.env_info.ram_states[k].min
                 break

-    def _get_ram_states_obs_dict(self):
-        player_spec_dict = {}
-        generic_dict = {}
-        # Adding env additional observations (side-specific)
-        for k, v in self.env_info.ram_states.items():
-            if k[-2:] == "P1":
-                target_dict = player_spec_dict
-                knew = "own_" + k[:-2]
-            elif k[-2:] == "P2":
-                target_dict = player_spec_dict
-                knew = "opp_" + k[:-2]
-            else:
-                target_dict = generic_dict
-                knew = k
-
-            # Discrete spaces (binary / categorical)
-            if v.type == 0 or v.type == 2:
-                target_dict[knew] = gym.spaces.Discrete(v.max + 1)
-            elif v.type == 1:  # Box spaces
-                target_dict[knew] = gym.spaces.Box(low=v.min, high=v.max, shape=(1,), dtype=np.int32)
-            else:
-                raise RuntimeError("Only Discrete (Binary/Categorical) | Box Spaces allowed")
-
-        player_spec_dict["action_move"] = gym.spaces.Discrete(self.n_actions[0])
-        player_spec_dict["action_attack"] = gym.spaces.Discrete(self.n_actions[1])
-
-        return generic_dict, player_spec_dict
-
-    # Get frame
-    def _get_frame(self, response):
-        self._frame = np.frombuffer(response.observation.frame, dtype='uint8').reshape(self.env_info.frame_shape.h, \
-                                                                                       self.env_info.frame_shape.w, \
-                                                                                       self.env_info.frame_shape.c)
-        return self._frame
-
-    # Get info
-    def _get_info(self, response):
-        info = dict(response.info.game_states)
-        info["settings"] = self.env_settings.pb_model
-        return info
-
-    # Integrate player specific RAM states into observation
-    def _player_specific_ram_states_integration(self, response, idx):
-        player_spec_dict = {}
-        generic_dict = {}
-
-        # Adding env additional observations (side-specific)
-        player_role = self.env_settings.pb_model.variable_env_settings.player_env_settings[idx].role
-        for k, v in self.env_info.ram_states.items():
-            if ("P1" in k or "P2" in k):
-                target_dict = player_spec_dict
-                if k[-2:] == player_role:
-                    knew = "own_" + k[:-2]
-                else:
-                    knew = "opp_" + k[:-2]
-            else:
-                target_dict = generic_dict
-                knew = k
-
-            # Box spaces
-            if v.type == 1:
-                target_dict[knew] = np.array([response.observation.ram_states[k]], dtype=np.int32)
-            else:  # Discrete spaces (binary / categorical)
-                target_dict[knew] = response.observation.ram_states[k]
-
-        player_spec_dict["action_move"] = self._last_action[idx][0]
-        player_spec_dict["action_attack"] = self._last_action[idx][1]
-
-        return generic_dict, player_spec_dict
-
     # Return env action list
     def get_actions_tuples(self):
         return self.actions_tuples
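
The `_get_frame` helper removed here (it reappears lower in the class, see the next hunk) decodes the engine's flat byte buffer into an H x W x C `uint8` image via `np.frombuffer` plus `reshape`. A minimal, runnable sketch of that decoding step, with made-up dimensions standing in for `env_info.frame_shape`:

```python
import numpy as np

# Hypothetical frame geometry; the real values come from env_info.frame_shape.
h, w, c = 128, 128, 3

# The engine delivers the frame as a flat byte string (faked here with random bytes).
raw = np.random.randint(0, 256, size=h * w * c, dtype=np.uint8).tobytes()

# Same decoding as _get_frame: reinterpret the buffer as uint8, then reshape to (h, w, c).
frame = np.frombuffer(raw, dtype=np.uint8).reshape(h, w, c)
assert frame.shape == (128, 128, 3) and frame.dtype == np.uint8
```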
@@ -217,8 +146,78 @@ def close(self):
         cv2.destroyAllWindows()
         self.arena_engine.close()

-# DIAMBRA Gym 1P class
+    def _get_ram_states_obs_dict(self):
+        player_spec_dict = {}
+        generic_dict = {}
+        # Adding env additional observations (side-specific)
+        for k, v in self.env_info.ram_states.items():
+            if k.endswith("P1"):
+                target_dict = player_spec_dict
+                knew = "own_" + k[:-2]
+            elif k.endswith("P2"):
+                target_dict = player_spec_dict
+                knew = "opp_" + k[:-2]
+            else:
+                target_dict = generic_dict
+                knew = k
+
+            # Discrete spaces (binary / categorical)
+            if v.type == 0 or v.type == 2:
+                target_dict[knew] = gym.spaces.Discrete(v.max + 1)
+            elif v.type == 1:  # Box spaces
+                target_dict[knew] = gym.spaces.Box(low=v.min, high=v.max, shape=(1,), dtype=np.int32)
+            else:
+                raise RuntimeError("Only Discrete (Binary/Categorical) | Box Spaces allowed")
+
+        player_spec_dict["action_move"] = gym.spaces.Discrete(self.n_actions[0])
+        player_spec_dict["action_attack"] = gym.spaces.Discrete(self.n_actions[1])
+
+        return generic_dict, player_spec_dict
+
+    # Get frame
+    def _get_frame(self, response):
+        self._frame = np.frombuffer(response.observation.frame, dtype='uint8').reshape(self.env_info.frame_shape.h, \
+                                                                                       self.env_info.frame_shape.w, \
+                                                                                       self.env_info.frame_shape.c)
+        return self._frame
+
+    # Get info
+    def _get_info(self, response):
+        info = dict(response.info.game_states)
+        info["settings"] = self.env_settings.pb_model
+        return info
+
+    # Integrate player specific RAM states into observation
+    def _player_specific_ram_states_integration(self, response, idx):
+        player_spec_dict = {}
+        generic_dict = {}
+
+        # Adding env additional observations (side-specific)
+        player_role = self.env_settings.pb_model.variable_env_settings.player_env_settings[idx].role
+        for k, v in self.env_info.ram_states.items():
+            if k.endswith("P1") or k.endswith("P2"):
+                target_dict = player_spec_dict
+                if k[-2:] == player_role:
+                    knew = "own_" + k[:-2]
+                else:
+                    knew = "opp_" + k[:-2]
+            else:
+                target_dict = generic_dict
+                knew = k
+
+            # Box spaces
+            if v.type == 1:
+                target_dict[knew] = np.array([response.observation.ram_states[k]], dtype=np.int32)
+            else:  # Discrete spaces (binary / categorical)
+                target_dict[knew] = response.observation.ram_states[k]
+
+        player_spec_dict["action_move"] = self._last_action[idx][0]
+        player_spec_dict["action_attack"] = self._last_action[idx][1]
+
+        return generic_dict, player_spec_dict
+
 class DiambraGym1P(DiambraGymBase):
+    """Diambra Environment gymnasium single agent interface"""
     def __init__(self, env_settings):
         super().__init__(env_settings)

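The relocated helpers route each RAM state by its key suffix: `...P1`/`...P2` entries become side-specific `own_`/`opp_` observations, everything else stays generic. A standalone sketch of that routing rule, using invented RAM-state names:

```python
# Invented RAM-state keys for illustration; real keys come from env_info.ram_states.
ram_state_keys = ["healthP1", "healthP2", "stage", "timer"]

player_spec_keys = {}
generic_keys = []
for k in ram_state_keys:
    if k.endswith("P1"):
        player_spec_keys["own_" + k[:-2]] = k   # this agent's side, suffix stripped
    elif k.endswith("P2"):
        player_spec_keys["opp_" + k[:-2]] = k   # opponent's side
    else:
        generic_keys.append(k)                  # side-agnostic state

print(player_spec_keys)  # {'own_health': 'healthP1', 'opp_health': 'healthP2'}
print(generic_keys)      # ['stage', 'timer']
```

Compared with the old `k[-2:] == "P1"` slice comparison, `str.endswith` states the intent directly and behaves safely on keys shorter than two characters.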
@@ -248,15 +247,6 @@ def __init__(self, env_settings):
             self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
             self.logger.debug("Using Discrete action space")

-    def _get_obs(self, response):
-        observation = {}
-        observation["frame"] = self._get_frame(response)
-        generic_obs_dict, player_obs_dict = self._player_specific_ram_states_integration(response, 0)
-        observation.update(generic_obs_dict)
-        observation.update(player_obs_dict)
-
-        return observation
-
     # Return the no-op action
     def get_no_op_action(self):
         if isinstance(self.action_space, gym.spaces.MultiDiscrete):
@@ -276,8 +266,17 @@ def step(self, action: Union[int, List[int]]):

         return observation, response.reward, response.info.game_states["episode_done"], False, self._get_info(response)

-# DIAMBRA Gym 2P Class
+    def _get_obs(self, response):
+        observation = {}
+        observation["frame"] = self._get_frame(response)
+        generic_obs_dict, player_obs_dict = self._player_specific_ram_states_integration(response, 0)
+        observation.update(generic_obs_dict)
+        observation.update(player_obs_dict)
+
+        return observation
+
 class DiambraGym2P(DiambraGymBase):
+    """Diambra Environment gymnasium multi-agent interface"""
     def __init__(self, env_settings):
         super().__init__(env_settings)

@@ -298,36 +297,17 @@ def __init__(self, env_settings):

         # Action space
         # Dictionary
-        action_space_dict = {}
-        for idx in range(2):
-            if env_settings.action_space[idx] == "multi_discrete":
-                action_space_dict["agent_{}".format(idx)] = gym.spaces.MultiDiscrete(self.n_actions)
-            elif env_settings.action_space[idx] == "discrete":
-                action_space_dict["agent_{}".format(idx)] = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
-            self.logger.debug("Using {} action space for agent_{}".format(env_settings.action_space[idx], idx))
-
+        action_spaces_values = {"multi_discrete": gym.spaces.MultiDiscrete(self.n_actions),
+                                "discrete": gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)}
+        action_space_dict = self._update_dict(action_spaces_values)
+        self.logger.debug("Using the following action spaces: {}".format(action_space_dict))
         self.action_space = gym.spaces.Dict(action_space_dict)

-    def _get_obs(self, response):
-        observation = {}
-        observation["frame"] = self._get_frame(response)
-        for idx in range(2):
-            generic_obs_dict, player_obs_dict = self._player_specific_ram_states_integration(response, idx)
-            observation["agent_{}".format(idx)] = player_obs_dict
-            observation.update(generic_obs_dict)
-
-        return observation
-
     # Return the no-op action
     def get_no_op_action(self):
-        no_op_action = {}
-        for idx in range(2):
-            if self.env_settings.action_space[idx] == "multi_discrete":
-                no_op_action["agent_{}".format(idx)] = [0, 0]
-            elif self.env_settings.action_space[idx] == "discrete":
-                no_op_action["agent_{}".format(idx)] = 0
-
-        return no_op_action
+        no_op_values = {"multi_discrete": [0, 0],
+                        "discrete": 0}
+        return self._update_dict(no_op_values)

     # Step the environment
     def step(self, actions: Dict[str, Union[int, List[int]]]):
@@ -342,4 +322,21 @@ def step(self, actions: Dict[str, Union[int, List[int]]]):
         response = self.arena_engine.step(self._last_action)
         observation = self._get_obs(response)

-        return observation, response.reward, response.info.game_states["game_done"], False, self._get_info(response)
\ No newline at end of file
+        return observation, response.reward, response.info.game_states["game_done"], False, self._get_info(response)
+
+    def _update_dict(self, values_dict):
+        out_dict = {}
+        for idx, action_space in enumerate(self.env_settings.action_space):
+            out_dict["agent_{}".format(idx)] = values_dict[action_space]
+
+        return out_dict
+
+    def _get_obs(self, response):
+        observation = {}
+        observation["frame"] = self._get_frame(response)
+        for idx in range(self.env_settings.n_players):
+            generic_obs_dict, player_obs_dict = self._player_specific_ram_states_integration(response, idx)
+            observation["agent_{}".format(idx)] = player_obs_dict
+            observation.update(generic_obs_dict)
+
+        return observation
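
The new `_update_dict` helper factors out the per-agent loop that `__init__` and `get_no_op_action` previously duplicated: given a table keyed by action-space name, it builds an `agent_<idx>` dictionary by looking up each agent's configured space. A self-contained sketch of the pattern, with a hypothetical two-agent configuration:

```python
# Hypothetical per-agent configuration; the real one is env_settings.action_space.
configured_action_spaces = ("multi_discrete", "discrete")

def update_dict(values_dict):
    """Map each agent's configured action-space name to a value, as _update_dict does."""
    return {"agent_{}".format(idx): values_dict[space]
            for idx, space in enumerate(configured_action_spaces)}

# Same table the refactored get_no_op_action passes in.
no_op_values = {"multi_discrete": [0, 0], "discrete": 0}
print(update_dict(no_op_values))  # {'agent_0': [0, 0], 'agent_1': 0}
```

One lookup table per call site and a single shared loop: supporting a new action-space type becomes a one-line addition to each table.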
17 changes: 9 additions & 8 deletions diambra/arena/env_settings.py
@@ -23,6 +23,7 @@ def check_val_in_list(key, value, valid_list):

 @dataclass
 class EnvironmentSettings:
+    """Generic Environment Settings Class"""
     env_info = None
     games_dict = None

@@ -183,7 +184,7 @@ def _process_random_values(self):

 @dataclass
 class EnvironmentSettings1P(EnvironmentSettings):
-    # Player level
+    """Single Agent Environment Settings Class"""
     role: str = "Random"
     characters: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]] = ("Random", "Random", "Random")
     outfits: int = 1


@dataclass
class EnvironmentSettings2P(EnvironmentSettings):
# Player level
"""Single Agent Environment Settings Class"""
role: Tuple[str, str] = ("Random", "Random")
characters: Union[Tuple[str, str], Tuple[Tuple[str], Tuple[str]],
Tuple[Tuple[str, str], Tuple[str, str]],
Expand Down Expand Up @@ -299,14 +300,14 @@ def _process_random_values(self):
super()._process_random_values()

characters_tmp = [[],[]]
for idx in range(2):
sampled_characters = self._sample_characters()

for idx, characters in enumerate(self.characters):
sampled_characters = self._sample_characters()
for jdx in range(3):
if self.characters[idx][jdx] == "Random":
if characters[jdx] == "Random":
characters_tmp[idx].append(sampled_characters[jdx])
else:
characters_tmp[idx].append(self.characters[idx][jdx])
characters_tmp[idx].append(characters[jdx])

self.characters = (tuple(characters_tmp[0]), tuple(characters_tmp[1]))

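The rewritten loop iterates `enumerate(self.characters)` rather than indexing with `range(2)`, so the body can refer to each agent's character tuple directly. A runnable sketch of the "Random" resolution it performs, with a stand-in roster and sampler:

```python
import random

CHARACTER_POOL = ["Ryu", "Ken", "Chun-Li", "Blanka"]  # hypothetical roster

def sample_characters():
    # Stand-in for self._sample_characters(): one random pick per slot.
    return tuple(random.choice(CHARACTER_POOL) for _ in range(3))

# Two agents, three character slots each; "Random" slots get resolved.
characters = (("Ryu", "Random", "Random"), ("Random", "Ken", "Random"))

characters_tmp = [[], []]
for idx, agent_characters in enumerate(characters):
    sampled = sample_characters()
    for jdx in range(3):
        if agent_characters[jdx] == "Random":
            characters_tmp[idx].append(sampled[jdx])
        else:
            characters_tmp[idx].append(agent_characters[jdx])

characters = (tuple(characters_tmp[0]), tuple(characters_tmp[1]))
print(characters)  # fixed picks kept, "Random" slots filled from the pool
```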
@@ -324,8 +325,8 @@ def _process_random_values(self):
         self.ultimate_style = tuple([[random.choice(list(range(1, 3))) if self.ultimate_style[idx][jdx] == "Random" else self.ultimate_style[idx][jdx] for jdx in range(3)] for idx in range(2)])

     def _get_action_spaces(self):
-        action_spaces = [model.EnvSettings.ActionSpace.ACTION_SPACE_DISCRETE if self.action_space[idx] == "discrete" else \
-                         model.EnvSettings.ActionSpace.ACTION_SPACE_MULTI_DISCRETE for idx in range(2)]
+        action_spaces = [model.EnvSettings.ActionSpace.ACTION_SPACE_DISCRETE if action_space == "discrete" else \
+                         model.EnvSettings.ActionSpace.ACTION_SPACE_MULTI_DISCRETE for action_space in self.action_space]

         return action_spaces

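`_get_action_spaces` now maps directly over the configured strings instead of re-indexing with `range(2)`, turning each one into the matching protobuf enum value. The same comprehension shape, with a plain `Enum` standing in for the generated `model.EnvSettings.ActionSpace`:

```python
from enum import Enum

class ActionSpace(Enum):
    # Stand-ins for the generated model.EnvSettings.ActionSpace constants.
    ACTION_SPACE_DISCRETE = 0
    ACTION_SPACE_MULTI_DISCRETE = 1

configured = ("discrete", "multi_discrete")  # hypothetical per-agent settings

action_spaces = [ActionSpace.ACTION_SPACE_DISCRETE if s == "discrete"
                 else ActionSpace.ACTION_SPACE_MULTI_DISCRETE
                 for s in configured]
print(action_spaces)
# [<ActionSpace.ACTION_SPACE_DISCRETE: 0>, <ActionSpace.ACTION_SPACE_MULTI_DISCRETE: 1>]
```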