WIP - Acquire engine RAM states refactoring
alexpalms committed Sep 17, 2023
1 parent f46a6b2 commit ef1f465
Showing 9 changed files with 218 additions and 328 deletions.
203 changes: 74 additions & 129 deletions diambra/arena/arena_gym.py
@@ -14,7 +14,6 @@ class DiambraGymBase(gym.Env):
"""Diambra Environment gymnasium base interface"""
metadata = {"render_modes": ["human", "rgb_array"]}
_frame = None
_last_action = None
reward_normalization_value = 1.0
render_gui_started = False

@@ -55,11 +54,44 @@ def __init__(self, env_settings: Union[EnvironmentSettings1P, EnvironmentSetting
self.print_actions_dict = [move_dict, attack_dict]

# Maximum difference in players health
for k in sorted(self.env_info.ram_states.keys()):
if "health" in k:
self.max_delta_health = self.env_info.ram_states[k].max - self.env_info.ram_states[k].min
category_key_enum = model.RamStatesCategories.Value("P1")
for k in sorted(self.env_info.ram_states_categories[category_key_enum].ram_states.keys()):
key_enum_name = model.RamStates.Name(k)
if "health" in key_enum_name:
self.max_delta_health = self.env_info.ram_states_categories[category_key_enum].ram_states[k].max - \
self.env_info.ram_states_categories[category_key_enum].ram_states[k].min
break

# Observation space
# Dictionary
observation_space_dict = {}
observation_space_dict['frame'] = gym.spaces.Box(low=0, high=255, shape=(self.env_info.frame_shape.h,
self.env_info.frame_shape.w,
self.env_info.frame_shape.c),
dtype=np.uint8)

# Adding RAM states observations
for k, v in self.env_info.ram_states_categories.items():
print("Processing {}, {}".format(model.RamStatesCategories.Name(k), v))
if k == model.RamStatesCategories.common:
target_dict = observation_space_dict
else:
observation_space_dict[model.RamStatesCategories.Name(k)] = {}
target_dict = observation_space_dict[model.RamStatesCategories.Name(k)]

for k2, v2 in v.ram_states.items():
if v2.type == SpaceType.BINARY or v2.type == SpaceType.DISCRETE:
target_dict[model.RamStates.Name(k2)] = gym.spaces.Discrete(v2.max + 1)
elif v2.type == SpaceType.BOX:
target_dict[model.RamStates.Name(k2)] = gym.spaces.Box(low=v2.min, high=v2.max, shape=(1,), dtype=np.int16)
else:
raise RuntimeError("Only Discrete (Binary/Categorical) | Box Spaces allowed")

for space_key in [model.RamStatesCategories.P1, model.RamStatesCategories.P2]:
observation_space_dict[model.RamStatesCategories.Name(space_key)] = gym.spaces.Dict(observation_space_dict[model.RamStatesCategories.Name(space_key)])

self.observation_space = gym.spaces.Dict(observation_space_dict)

# Return env action list
def get_actions_tuples(self):
return self.actions_tuples
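Note (not part of the commit): a minimal sketch of the nested observation space the category-based loop above produces. The RAM state names ("stage", "health", "side") and their bounds are hypothetical placeholders; real names come from the engine's RamStates enum.

    import numpy as np
    import gymnasium as gym

    def player_space():
        # One sub-dictionary per player category ("P1", "P2")
        return gym.spaces.Dict({
            "health": gym.spaces.Box(low=0, high=208, shape=(1,), dtype=np.int16),  # BOX state
            "side": gym.spaces.Discrete(2),                                         # BINARY state
        })

    observation_space = gym.spaces.Dict({
        "frame": gym.spaces.Box(low=0, high=255, shape=(224, 384, 3), dtype=np.uint8),
        "stage": gym.spaces.Discrete(4),  # "common" category states stay at the top level
        "P1": player_space(),
        "P2": player_space(),
    })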
@@ -81,7 +113,6 @@ def get_cumulative_reward_bounds(self):

# Reset the environment
def reset(self, seed: int = None, options: Dict[str, Any] = None):
self._last_action = [[0, 0], [0, 0]]
if options is None:
options = {}
options["seed"] = seed
@@ -109,21 +140,13 @@ def render(self, wait_key=1):
return self._frame

# Print observation details to the console
def show_obs(self, observation, wait_key=1, viz=True, string="observation", key=None):
def show_obs(self, observation, wait_key=1, viz=True, string="observation", key=None, outermost=True):
if type(observation) == dict:
for k, v in sorted(observation.items()):
self.show_obs(v, wait_key=wait_key, viz=viz, string=string + "[\"{}\"]".format(k), key=k)
self.show_obs(v, wait_key=wait_key, viz=viz, string=string + "[\"{}\"]".format(k), key=k, outermost=False)
else:
if key != "frame":
if "action" in key:
out_value = observation
additional_string = ": "
if isinstance(observation, (int, np.integer)) is False:
n_actions_stack = int(observation.size / (self.n_actions[0] if "move" in key else self.n_actions[1]))
out_value = np.reshape(observation, [n_actions_stack, -1])
additional_string = " (reshaped for visualization):\n"
print(string + "{}{}".format(additional_string, out_value))
elif "own_char" in key or "opp_char" in key:
if key.startswith("character"):
char_idx = observation if type(observation) == int else np.where(observation == 1)[0][0]
print(string + ": {} / {}".format(observation, self.env_info.characters_info.char_list[char_idx]))
else:
@@ -136,44 +159,21 @@ def show_obs(self, observation, wait_key=1, viz=True, string="observation", key=
norm_factor = 255 if np.amax(observation) > 1.0 else 1.0
for idx in range(observation.shape[2]):
cv2.imshow("[{}] Frame channel {}".format(os.getpid(), idx), observation[:, :, idx] / norm_factor)

cv2.waitKey(wait_key)
except:
pass

if outermost is True and viz is True and (sys.platform.startswith('linux') is False or 'DISPLAY' in os.environ):
try:
cv2.waitKey(wait_key)
except:
pass

# Closing the environment
def close(self):
# Close DIAMBRA Arena
cv2.destroyAllWindows()
self.arena_engine.close()

def _get_ram_states_obs_dict(self):
player_spec_dict = {}
generic_dict = {}
# Adding env additional observations (side-specific)
for k, v in self.env_info.ram_states.items():
if k.endswith("P1"):
target_dict = player_spec_dict
knew = "own_" + k[:-3]
elif k.endswith("P2"):
target_dict = player_spec_dict
knew = "opp_" + k[:-3]
else:
target_dict = generic_dict
knew = k

if v.type == SpaceType.BINARY or v.type == SpaceType.DISCRETE:
target_dict[knew] = gym.spaces.Discrete(v.max + 1)
elif v.type == SpaceType.BOX:
target_dict[knew] = gym.spaces.Box(low=v.min, high=v.max, shape=(1,), dtype=np.int32)
else:
raise RuntimeError("Only Discrete (Binary/Categorical) | Box Spaces allowed")

player_spec_dict["action_move"] = gym.spaces.Discrete(self.n_actions[0])
player_spec_dict["action_attack"] = gym.spaces.Discrete(self.n_actions[1])

return generic_dict, player_spec_dict

# Get frame
def _get_frame(self, response):
self._frame = np.frombuffer(response.observation.frame, dtype='uint8').reshape(self.env_info.frame_shape.h, \
@@ -183,56 +183,38 @@ def _get_frame(self, response):

# Get info
def _get_info(self, response):
info = dict(response.info.game_states)
info = {model.GameStates.Name(k): v for k, v in response.info.game_states.items()}
info["settings"] = self.env_settings.pb_model
return info

# Integrate player specific RAM states into observation
def _player_specific_ram_states_integration(self, response, idx):
player_spec_dict = {}
generic_dict = {}

# Adding env additional observations (side-specific)
player_role = self.env_settings.pb_model.episode_settings.player_settings[idx].role
for k, v in self.env_info.ram_states.items():
if (k.endswith("P1") or k.endswith("P2")):
target_dict = player_spec_dict
if k[-2:] == player_role:
knew = "own_" + k[:-3]
else:
knew = "opp_" + k[:-3]
def _get_obs(self, response):
observation = {}
observation["frame"] = self._get_frame(response)

# Adding RAM states observations
for k, v in self.env_info.ram_states_categories.items():
if k == model.RamStatesCategories.common:
target_dict = observation
else:
target_dict = generic_dict
knew = k
observation[model.RamStatesCategories.Name(k)] = {}
target_dict = observation[model.RamStatesCategories.Name(k)]

# Box spaces
if v.type == SpaceType.BOX:
target_dict[knew] = np.array([response.observation.ram_states[k]], dtype=np.int32)
else: # Discrete spaces (binary / categorical)
target_dict[knew] = response.observation.ram_states[k]
category_ram_states = response.observation.ram_states_categories[k]

player_spec_dict["action_move"] = self._last_action[idx][0]
player_spec_dict["action_attack"] = self._last_action[idx][1]
for k2, v2 in v.ram_states.items():
# Box spaces
if v2.type == SpaceType.BOX:
target_dict[model.RamStates.Name(k2)] = np.array([category_ram_states.ram_states[k2]])
else: # Discrete spaces (binary / categorical)
target_dict[model.RamStates.Name(k2)] = category_ram_states.ram_states[k2]

return generic_dict, player_spec_dict
return observation

class DiambraGym1P(DiambraGymBase):
"""Diambra Environment gymnasium single agent interface"""
def __init__(self, env_settings):
super().__init__(env_settings)

# Observation space
# Dictionary
observation_space_dict = {}
observation_space_dict['frame'] = gym.spaces.Box(low=0, high=255, shape=(self.env_info.frame_shape.h,
self.env_info.frame_shape.w,
self.env_info.frame_shape.c),
dtype=np.uint8)
generic_obs_dict, player_obs_dict = self._get_ram_states_obs_dict()
observation_space_dict.update(generic_obs_dict)
observation_space_dict.update(player_obs_dict)
self.observation_space = gym.spaces.Dict(observation_space_dict)

# Action space
# MultiDiscrete actions:
# - Arrows -> One discrete set
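Note (not part of the commit): for illustration, the observation returned by the new _get_obs mirrors the space built in __init__, with "common" states at the top level and one nested dict per player category. A hypothetical example (names and values are placeholders):

    import numpy as np

    frame = np.zeros((224, 384, 3), dtype=np.uint8)  # stands in for the decoded engine frame
    observation = {
        "frame": frame,
        "stage": 2,                                    # common Discrete state
        "P1": {"health": np.array([187]), "side": 0},  # BOX states wrapped in 1-element arrays
        "P2": {"health": np.array([154]), "side": 1},
    }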
@@ -256,44 +238,17 @@ def get_no_op_action(self):
# Step the environment
def step(self, action: Union[int, List[int]]):
# Defining move and attack actions P1/P2 as a function of action_space
if isinstance(self.action_space, gym.spaces.MultiDiscrete):
self._last_action[0] = action
else:
self._last_action[0] = list(discrete_to_multi_discrete_action(action, self.n_actions[0]))
response = self.arena_engine.step(self._last_action)
observation = self._get_obs(response)
if isinstance(self.action_space, gym.spaces.Discrete):
action = list(discrete_to_multi_discrete_action(action, self.n_actions[0]))
response = self.arena_engine.step([action])

return observation, response.reward, response.info.game_states["episode_done"], False, self._get_info(response)

def _get_obs(self, response):
observation = {}
observation["frame"] = self._get_frame(response)
generic_obs_dict, player_obs_dict = self._player_specific_ram_states_integration(response, 0)
observation.update(generic_obs_dict)
observation.update(player_obs_dict)

return observation
return self._get_obs(response), response.reward, response.info.game_states[model.GameStates.episode_done], False, self._get_info(response)
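Note (not part of the commit): discrete_to_multi_discrete_action comes from the package's utilities; a plausible sketch of its behavior, assuming a combined Discrete space of size n_moves + n_attacks - 1 that shares a single no-op at index 0:

    def discrete_to_multi_discrete_action(action, n_move_actions):
        # Indices [0, n_move_actions - 1] select a move (0 = no-op), the rest select an attack
        if action < n_move_actions:
            return action, 0
        return 0, action - n_move_actions + 1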

class DiambraGym2P(DiambraGymBase):
"""Diambra Environment gymnasium multi-agent interface"""
def __init__(self, env_settings):
super().__init__(env_settings)

# Dictionary observation space
observation_space_dict = {}
observation_space_dict['frame'] = gym.spaces.Box(low=0, high=255,
shape=(self.env_info.frame_shape.h,
self.env_info.frame_shape.w,
self.env_info.frame_shape.c),
dtype=np.uint8)

generic_obs_dict, player_obs_dict = self._get_ram_states_obs_dict()
observation_space_dict.update(generic_obs_dict)
observation_space_dict["agent_0"] = gym.spaces.Dict(player_obs_dict)
observation_space_dict["agent_1"] = gym.spaces.Dict(player_obs_dict)

self.observation_space = gym.spaces.Dict(observation_space_dict)

# Action space
# Dictionary
action_spaces_values = {SpaceType.MULTI_DISCRETE: gym.spaces.MultiDiscrete(self.n_actions),
@@ -312,30 +267,20 @@ def get_no_op_action(self):
def step(self, actions: Dict[str, Union[int, List[int]]]):
# NOTE: the assumption in current interface is that we have actions sorted as agent's order
actions = sorted(actions.items())
action_list = [[],[]]
for idx, action in enumerate(actions):
# Defining move and attack actions P1/P2 as a function of action_space
if isinstance(self.action_space[action[0]], gym.spaces.MultiDiscrete):
self._last_action[idx] = action[1]
action_list[idx] = action[1]
else:
self._last_action[idx] = list(discrete_to_multi_discrete_action(action[1], self.n_actions[0]))
response = self.arena_engine.step(self._last_action)
observation = self._get_obs(response)
action_list[idx] = list(discrete_to_multi_discrete_action(action[1], self.n_actions[0]))
response = self.arena_engine.step(action_list)

return observation, response.reward, response.info.game_states["game_done"], False, self._get_info(response)
return self._get_obs(response), response.reward, response.info.game_states[model.GameStates.game_done], False, self._get_info(response)

def _map_action_spaces_to_agents(self, values_dict):
out_dict = {}
for idx, action_space in enumerate(self.env_settings.action_space):
out_dict["agent_{}".format(idx)] = values_dict[action_space]

return out_dict

def _get_obs(self, response):
observation = {}
observation["frame"] = self._get_frame(response)
for idx in range(self.env_settings.n_players):
generic_obs_dict, player_obs_dict = self._player_specific_ram_states_integration(response, idx)
observation["agent_{}".format(idx)] = player_obs_dict
observation.update(generic_obs_dict)

return observation
return out_dict
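Note (not part of the commit): a hypothetical usage sketch of the multi-agent step interface above, assuming env is an already-constructed DiambraGym2P instance with MultiDiscrete action spaces for both agents:

    actions = {"agent_0": [3, 1], "agent_1": [0, 2]}  # [move, attack] per agent, keyed by agent id
    observation, reward, game_done, truncated, info = env.step(actions)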
2 changes: 0 additions & 2 deletions diambra/arena/env_settings.py
@@ -372,7 +372,6 @@ class WrappersSettings:
scale: bool = False
exclude_image_scaling: bool = False
process_discrete_binary: bool = False
scale_mod: int = 0
frame_shape: Tuple[int, int, int] = (0, 0, 0)
flatten: bool = False
filter_keys: List[str] = None
@@ -384,7 +383,6 @@ def sanity_check(self):
check_num_in_range("frame_stack", self.frame_stack, [1, MAX_STACK_VALUE])
check_num_in_range("dilation", self.dilation, [1, MAX_STACK_VALUE])
check_num_in_range("actions_stack", self.actions_stack, [1, MAX_STACK_VALUE])
check_num_in_range("scale_mod", self.scale_mod, [0, 0])
check_num_in_range("reward_normalization_factor", self.reward_normalization_factor, [0.0, 1000000])

check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1, 3])
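Note (not part of the commit): a plausible sketch of the check_num_in_range helper used above, inferred from its call sites (a settings name, its value, and an inclusive [min, max] range):

    def check_num_in_range(name, value, value_range):
        # Raise if a settings value falls outside its allowed inclusive range
        if not (value_range[0] <= value <= value_range[1]):
            raise ValueError("\"{}\" ({}) must be in range {}".format(name, value, value_range))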