WIP - Optimized actions related wrappers, working on role-relative obs
alexpalms committed Sep 21, 2023
1 parent f09ae9f commit 63573f6
Showing 4 changed files with 147 additions and 119 deletions.
3 changes: 2 additions & 1 deletion diambra/arena/env_settings.py
@@ -372,14 +372,15 @@ class WrappersSettings:
reward_normalization_factor: float = 0.5
clip_rewards: bool = False
no_attack_buttons_combinations: bool = False
frame_shape: Tuple[int, int, int] = (0, 0, 0)
frame_stack: int = 1
dilation: int = 1
add_last_action_to_observation: bool = False
actions_stack: int = 1
scale: bool = False
exclude_image_scaling: bool = False
process_discrete_binary: bool = False
frame_shape: Tuple[int, int, int] = (0, 0, 0)
role_relative_observation: bool = False
flatten: bool = False
filter_keys: List[str] = None
wrappers: List[List[Any]] = None
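
For reference, a minimal sketch of how these settings might be configured after this commit; the fields are the ones shown in the dataclass above, while the values are illustrative only:

from diambra.arena.env_settings import WrappersSettings

settings = WrappersSettings()
settings.frame_shape = (128, 128, 1)        # resize frames (H, W, C)
settings.add_last_action_to_observation = True
settings.actions_stack = 12                 # stack the last 12 actions
settings.scale = True                       # normalize observations
settings.role_relative_observation = True   # new flag introduced in this commit
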
21 changes: 13 additions & 8 deletions diambra/arena/wrappers/arena_wrappers.py
@@ -4,7 +4,8 @@
import logging
from diambra.arena.env_settings import WrappersSettings
from diambra.arena.wrappers.observation import WarpFrame, GrayscaleFrame, FrameStack, ActionsStack, \
ScaledFloatObs, FlattenFilterDictObs, AddLastActionToObservation
NormalizeObservation, FlattenFilterDictObs, \
AddLastActionToObservation, RoleRelativeObservation

# Remove attack buttons combinations
class NoAttackButtonsCombinations(gym.Wrapper):
@@ -29,7 +30,7 @@ def reset(self, **kwargs):
def step(self, action):
return self.env.step(action)

class NoopResetEnv(gym.Wrapper):
class NoopReset(gym.Wrapper):
def __init__(self, env, no_op_max=6):
"""
Sample initial states by taking a random number of no-ops on reset.
@@ -81,7 +82,7 @@ def step(self, action):

return obs, rew, done, info

class ClipRewardEnv(gym.RewardWrapper):
class ClipReward(gym.RewardWrapper):
def __init__(self, env):
"""
Clips the reward to {+1, 0, -1} by its sign.
@@ -96,7 +97,7 @@ def reward(self, reward):
"""
return np.sign(reward)

class NormalizeRewardEnv(gym.RewardWrapper):
class NormalizeReward(gym.RewardWrapper):
def __init__(self, env, reward_normalization_factor):
"""
Normalize the reward, dividing it by the product of
@@ -127,17 +128,17 @@ def env_wrapping(env, wrappers_settings: WrappersSettings):

### Generic wrapper(s)
if wrappers_settings.no_op_max > 0:
env = NoopResetEnv(env, no_op_max=wrappers_settings.no_op_max)
env = NoopReset(env, no_op_max=wrappers_settings.no_op_max)

if wrappers_settings.sticky_actions > 1:
env = StickyActions(env, sticky_actions=wrappers_settings.sticky_actions)

### Reward wrapper(s)
if wrappers_settings.reward_normalization is True:
env = NormalizeRewardEnv(env, wrappers_settings.reward_normalization_factor)
env = NormalizeReward(env, wrappers_settings.reward_normalization_factor)

if wrappers_settings.clip_rewards is True:
env = ClipRewardEnv(env)
env = ClipReward(env)

### Action space wrapper(s)
if wrappers_settings.no_attack_buttons_combinations is True:
@@ -180,7 +181,11 @@ def env_wrapping(env, wrappers_settings: WrappersSettings):

# Scales observations, normalizing them between 0.0 and 1.0
if wrappers_settings.scale is True:
env = ScaledFloatObs(env, wrappers_settings.exclude_image_scaling, wrappers_settings.process_discrete_binary)
env = NormalizeObservation(env, wrappers_settings.exclude_image_scaling, wrappers_settings.process_discrete_binary)

# Convert base observation to role-relative observation
if wrappers_settings.role_relative_observation is True:
env = RoleRelativeObservation(env)

if wrappers_settings.flatten is True:
env = FlattenFilterDictObs(env, wrappers_settings.filter_keys)
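
Order matters in env_wrapping: reward wrappers are applied before observation wrappers, and RoleRelativeObservation runs after scaling but before flattening, so FlattenFilterDictObs sees the new "own"/"opp" keys. A minimal usage sketch, assuming `env` is an already-created base environment:

from diambra.arena.env_settings import WrappersSettings
from diambra.arena.wrappers.arena_wrappers import env_wrapping

settings = WrappersSettings()
settings.role_relative_observation = True

env = env_wrapping(env, settings)  # returns the environment wrapped per the settings
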
216 changes: 114 additions & 102 deletions diambra/arena/wrappers/observation.py
@@ -107,21 +107,22 @@ def __init__(self, env):
gym.Wrapper.__init__(self, env)
if self.unwrapped.env_settings.n_players == 1:
self.observation_space["action"] = self.action_space
def _add_last_action_to_obs_1p(obs, last_action):
obs["action"] = last_action
return obs
self._add_last_action_to_obs = _add_last_action_to_obs_1p
else:
for idx in range(self.unwrapped.env_settings.n_players):
action_dictionary = {}
action_dictionary["action"] = self.action_space["agent_{}".format(idx)]
self.observation_space["agent_{}".format(idx)] = gym.spaces.Dict(action_dictionary)

def _add_last_action_to_obs(self, obs, last_action):
if self.unwrapped.env_settings.n_players == 1:
obs["action"] = last_action
else:
for idx in range(self.unwrapped.env_settings.n_players):
action_dictionary = {}
action_dictionary["action"] = last_action["agent_{}".format(idx)]
obs["agent_{}".format(idx)] = action_dictionary
return obs
def _add_last_action_to_obs_2p(obs, last_action):
for idx in range(self.unwrapped.env_settings.n_players):
action_dictionary = {}
action_dictionary["action"] = last_action["agent_{}".format(idx)]
obs["agent_{}".format(idx)] = action_dictionary
return obs
self._add_last_action_to_obs = _add_last_action_to_obs_2p

def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
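
The restructuring above (and in the ActionsStack hunks below) follows a single pattern: the player-count branch is resolved once in __init__ by binding a closure, rather than re-checking n_players on every reset/step. A toy standalone illustration of the idea, not code from this commit:

# Toy illustration of the dispatch-once pattern used by these wrappers
class Processor:
    def __init__(self, n_players):
        if n_players == 1:
            def process(obs):
                return obs["action"]
        else:
            def process(obs):
                return [obs["agent_{}".format(i)]["action"] for i in range(n_players)]
        self._process = process  # branch decided once, no per-step check

    def process_obs(self, obs):
        return self._process(obs)

p = Processor(n_players=2)
print(p.process_obs({"agent_0": {"action": 3}, "agent_1": {"action": 7}}))  # [3, 7]
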
@@ -154,6 +155,20 @@ def __init__(self, env, n_actions_stack):
self.action_stack = [deque([no_op_action] * n_actions_stack, maxlen=n_actions_stack)]
action_space_size = [self.observation_space["action"].n]
self.observation_space["action"] = gym.spaces.MultiDiscrete(action_space_size * n_actions_stack)

if isinstance(self.action_space, gym.spaces.MultiDiscrete):
def _add_action_to_stack_1p(action):
self.action_stack[0].append(action[0])
self.action_stack[0].append(action[1])
else:
def _add_action_to_stack_1p(action):
self.action_stack[0].append(action)
self._add_action_to_stack = _add_action_to_stack_1p

def _process_obs_1p(obs):
obs["action"] = np.array(self.action_stack[0])
return obs
self._process_obs = _process_obs_1p
else:
self.action_stack = []
assert "action" in self.observation_space["agent_0"].keys(), "ActionsStack wrapper can be activated only "\
@@ -167,112 +182,89 @@ def __init__(self, env, n_actions_stack):
action_space_size = [self.observation_space["agent_{}".format(idx)]["action"].n]
self.observation_space["agent_{}".format(idx)]["action"] = gym.spaces.MultiDiscrete(action_space_size * n_actions_stack)

def fill_stack(self):
# Fill the actions stack with no action after reset
no_op_action = self.unwrapped.get_no_op_action()
if self.unwrapped.env_settings.n_players == 1:
if isinstance(self.action_space, gym.spaces.MultiDiscrete):
for _ in range(self.n_actions_stack):
self.action_stack[0].append(no_op_action[0])
self.action_stack[0].append(no_op_action[1])
else:
for _ in range(self.n_actions_stack):
self.action_stack[0].append(no_op_action)
else:
for idx in range(self.unwrapped.env_settings.n_players):
if isinstance(self.action_space["agent_{}".format(idx)], gym.spaces.MultiDiscrete):
for _ in range(self.n_actions_stack):
self.action_stack[idx].append(no_op_action["agent_{}".format(idx)][0])
self.action_stack[idx].append(no_op_action["agent_{}".format(idx)][1])
else:
for _ in range(self.n_actions_stack):
self.action_stack[idx].append(no_op_action["agent_{}".format(idx)])

def _process_obs(self, obs):
if self.unwrapped.env_settings.n_players == 1:
obs["action"] = np.array(self.action_stack[0])
else:
for idx in range(self.unwrapped.env_settings.n_players):
obs["agent_{}".format(idx)]["action"] = np.array(self.action_stack[idx])
return obs
def _add_action_to_stack_2p(action):
for idx in range(self.unwrapped.env_settings.n_players):
if isinstance(self.action_space["agent_{}".format(idx)], gym.spaces.MultiDiscrete):
self.action_stack[idx].append(action["agent_{}".format(idx)][0])
self.action_stack[idx].append(action["agent_{}".format(idx)][1])
else:
self.action_stack[idx].append(action["agent_{}".format(idx)])
self._add_action_to_stack = _add_action_to_stack_2p

def _process_obs_2p(obs):
for idx in range(self.unwrapped.env_settings.n_players):
obs["agent_{}".format(idx)]["action"] = np.array(self.action_stack[idx])
return obs
self._process_obs = _process_obs_2p

def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
self.fill_stack()
self._fill_stack()
return self._process_obs(obs), info

def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
if self.unwrapped.env_settings.n_players == 1:
if isinstance(self.action_space, gym.spaces.MultiDiscrete):
self.action_stack[0].append(action[0])
self.action_stack[0].append(action[1])
else:
self.action_stack[0].append(action)
else:
for idx in range(self.unwrapped.env_settings.n_players):
if isinstance(self.action_space["agent_{}".format(idx)], gym.spaces.MultiDiscrete):
self.action_stack[idx].append(action["agent_{}".format(idx)][0])
self.action_stack[idx].append(action["agent_{}".format(idx)][1])
else:
self.action_stack[idx].append(action["agent_{}".format(idx)])
self._add_action_to_stack(action)

# Refill the stack with no-op actions
# in case of new round / stage / continueGame
if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not (terminated or truncated)):
self.fill_stack()
self._fill_stack()

return self._process_obs(obs), reward, terminated, truncated, info

class ScaledFloatObs(gym.ObservationWrapper):
def _fill_stack(self):
no_op_action = self.unwrapped.get_no_op_action()
for _ in range(self.n_actions_stack):
self._add_action_to_stack(no_op_action)
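
Note that with a MultiDiscrete action space each step pushes two entries (move and attack) onto the same deque, so the stacked "action" vector interleaves them. A small self-contained illustration, with arbitrary sizes and values:

from collections import deque
import numpy as np

n_actions_stack = 4
stack = deque([0] * n_actions_stack, maxlen=n_actions_stack)

# A MultiDiscrete action such as (move=3, attack=1) is appended as two entries
for move, attack in [(3, 1), (5, 0)]:
    stack.append(move)
    stack.append(attack)

print(np.array(stack))  # [3 1 5 0], the initial no-op entries have been evicted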

class NormalizeObservation(gym.ObservationWrapper):
def __init__(self, env, exclude_image_scaling=False, process_discrete_binary=False):
gym.ObservationWrapper.__init__(self, env)

self.exclude_image_scaling = exclude_image_scaling
self.process_discrete_binary = process_discrete_binary

self.original_observation_space = deepcopy(self.observation_space)
self.scaled_float_obs_space_func(self.observation_space)
self._obs_space_normalization_func(self.observation_space)

def observation(self, observation):
return self._obs_normalization_func(observation, self.original_observation_space)

# Recursive function to modify obs space dict
# FIXME: this can probably be dropped with gym >= 0.21, using only the next one; kept here for SB2 compatibility
def scaled_float_obs_space_func(self, obs_dict):
def _obs_space_normalization_func(self, obs_dict):
# Updating observation space dict
for k, v in obs_dict.spaces.items():

for k, v in obs_dict.items():
if isinstance(v, gym.spaces.Dict):
self.scaled_float_obs_space_func(v)
self._obs_space_normalization_func(v)
else:
if isinstance(v, gym.spaces.MultiDiscrete):
# One hot encoding x nStack
n_val = v.nvec.shape[0]
max_val = v.nvec[0]
obs_dict.spaces[k] = gym.spaces.MultiBinary(n_val * max_val)
obs_dict[k] = gym.spaces.MultiBinary(np.sum(v.nvec))
elif isinstance(v, gym.spaces.Discrete) and (v.n > 2 or self.process_discrete_binary is True):
# One hot encoding
obs_dict.spaces[k] = gym.spaces.MultiBinary(v.n)
obs_dict[k] = gym.spaces.MultiBinary(v.n)
elif isinstance(v, gym.spaces.Box) and (self.exclude_image_scaling is False or len(v.shape) < 3):
obs_dict.spaces[k] = gym.spaces.Box(low=0.0, high=1.0, shape=v.shape, dtype=np.float32)
obs_dict[k] = gym.spaces.Box(low=0.0, high=1.0, shape=v.shape, dtype=np.float32)

# Recursive function to modify obs dict
def scaled_float_obs_func(self, observation, observation_space):

def _obs_normalization_func(self, observation, observation_space):
# Process all observations
for k, v in observation.items():

if isinstance(v, dict):
self.scaled_float_obs_func(v, observation_space.spaces[k])
self._obs_normalization_func(v, observation_space.spaces[k])
else:
v_space = observation_space.spaces[k]
v_space = observation_space[k]
if isinstance(v_space, gym.spaces.MultiDiscrete):
n_act = observation_space.spaces[k].nvec[0]
buf_len = observation_space.spaces[k].nvec.shape[0]
actions_vector = np.zeros((buf_len * n_act), dtype=np.uint8)
for iact in range(buf_len):
actions_vector[iact * n_act + observation[k][iact]] = 1
actions_vector = np.zeros((np.sum(v_space.nvec)), dtype=np.uint8)
column_index = 0
for iact in range(v_space.nvec.shape[0]):
actions_vector[column_index + observation[k][iact]] = 1
column_index += v_space.nvec[iact]
observation[k] = actions_vector
elif isinstance(v_space, gym.spaces.Discrete) and (v_space.n > 2 or self.process_discrete_binary is True):
var_vector = np.zeros((observation_space.spaces[k].n), dtype=np.uint8)
var_vector = np.zeros((v_space.n), dtype=np.uint8)
var_vector[observation[k]] = 1
observation[k] = var_vector
elif isinstance(v_space, gym.spaces.Box) and (self.exclude_image_scaling is False or len(v_space.shape) < 3):
@@ -282,9 +274,54 @@ def scaled_float_obs_func(self, observation, observation_space):

return observation
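
The MultiDiscrete branch above also fixes the sizing of the one-hot vector: the old code used n_val * max_val, which assumed every component had the same cardinality, while np.sum(v.nvec) together with the running column_index handles heterogeneous components. A worked example of the new encoding, with hypothetical values:

import numpy as np

nvec = np.array([9, 3])    # e.g. 9 move actions followed by 3 attack actions
value = np.array([4, 2])   # observed MultiDiscrete value

one_hot = np.zeros(np.sum(nvec), dtype=np.uint8)  # length 12, not 2 * 9 = 18
column_index = 0
for i in range(nvec.shape[0]):
    one_hot[column_index + value[i]] = 1
    column_index += nvec[i]

print(one_hot)  # [0 0 0 0 1 0 0 0 0 0 0 1]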

class RoleRelativeObservation(gym.ObservationWrapper):
def __init__(self, env):
gym.ObservationWrapper.__init__(self, env)

self.observation_space["own"] = self.observation_space["P1"]
self.observation_space["opp"] = self.observation_space["P1"]
del self.observation_space["P1"]
del self.observation_space["P2"]

def observation(self, observation):
return None

"""
def rename_key_recursive(dictionary, old_key, new_key):
if isinstance(dictionary, dict):
new_dict = {}
for key, value in dictionary.items():
if key == old_key:
key = new_key
new_dict[key] = rename_key_recursive(value, old_key, new_key)
return new_dict
else:
return dictionary
"""

class FlattenFilterDictObs(gym.ObservationWrapper):
def __init__(self, env, filter_keys):
gym.ObservationWrapper.__init__(self, env)

return self.scaled_float_obs_func(observation, self.original_observation_space)
self.filter_keys = filter_keys
if (filter_keys is not None):
self.filter_keys = list(set(filter_keys))
if "frame" not in filter_keys:
self.filter_keys += ["frame"]

original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, None)).keys()
self.observation_space = gym.spaces.Dict(flatten_filter_obs_space_func(self.observation_space, self.filter_keys))

if filter_keys is not None:
if (sorted(self.observation_space.spaces.keys()) != sorted(self.filter_keys)):
raise Exception("Specified observation key(s) not found:",
" Available key(s):", sorted(original_obs_space_keys),
" Specified key(s):", sorted(self.filter_keys),
" Key(s) not found:", sorted([key for key in self.filter_keys if key not in original_obs_space_keys]),
)

def observation(self, observation):
return flatten_filter_obs_func(observation, self.filter_keys)

def flatten_filter_obs_space_func(input_dictionary, filter_keys):
_FLAG_FIRST = object()
@@ -336,29 +373,4 @@ def visit(subdict, flattened_dict, partial_key, check_method):
else:
visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check)

return flattened_dict

class FlattenFilterDictObs(gym.ObservationWrapper):
def __init__(self, env, filter_keys):
gym.ObservationWrapper.__init__(self, env)

self.filter_keys = filter_keys
if (filter_keys is not None):
self.filter_keys = list(set(filter_keys))
if "frame" not in filter_keys:
self.filter_keys += ["frame"]

original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, None)).keys()
self.observation_space = gym.spaces.Dict(flatten_filter_obs_space_func(self.observation_space, self.filter_keys))

if filter_keys is not None:
if (sorted(self.observation_space.spaces.keys()) != sorted(self.filter_keys)):
raise Exception("Specified observation key(s) not found:",
" Available key(s):", sorted(original_obs_space_keys),
" Specified key(s):", sorted(self.filter_keys),
" Key(s) not found:", sorted([key for key in self.filter_keys if key not in original_obs_space_keys]),
)

def observation(self, observation):

return flatten_filter_obs_func(observation, self.filter_keys)
return flattened_dict
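
A possible usage sketch for FlattenFilterDictObs; the key names passed to filter_keys are illustrative, since the available keys depend on the game and on the wrappers applied before this one:

# Hypothetical usage: keep only selected flattened observation entries
env = FlattenFilterDictObs(env, filter_keys=["stage", "timer"])
obs, info = env.reset()
print(sorted(obs.keys()))  # the requested keys plus "frame", which is always retained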