diff --git a/diambra/arena/env_settings.py b/diambra/arena/env_settings.py
index 5d304fa5..a3182500 100644
--- a/diambra/arena/env_settings.py
+++ b/diambra/arena/env_settings.py
@@ -372,6 +372,7 @@ class WrappersSettings:
     reward_normalization_factor: float = 0.5
     clip_rewards: bool = False
     no_attack_buttons_combinations: bool = False
+    frame_shape: Tuple[int, int, int] = (0, 0, 0)
     frame_stack: int = 1
     dilation: int = 1
     add_last_action_to_observation: bool = False
@@ -379,7 +380,7 @@ class WrappersSettings:
     scale: bool = False
     exclude_image_scaling: bool = False
     process_discrete_binary: bool = False
-    frame_shape: Tuple[int, int, int] = (0, 0, 0)
+    role_relative_observation: bool = False
     flatten: bool = False
     filter_keys: List[str] = None
     wrappers: List[List[Any]] = None
diff --git a/diambra/arena/wrappers/arena_wrappers.py b/diambra/arena/wrappers/arena_wrappers.py
index ed160e68..9d4f2678 100644
--- a/diambra/arena/wrappers/arena_wrappers.py
+++ b/diambra/arena/wrappers/arena_wrappers.py
@@ -4,7 +4,8 @@ import logging
 
 from diambra.arena.env_settings import WrappersSettings
 from diambra.arena.wrappers.observation import WarpFrame, GrayscaleFrame, FrameStack, ActionsStack, \
-    ScaledFloatObs, FlattenFilterDictObs, AddLastActionToObservation
+    NormalizeObservation, FlattenFilterDictObs, \
+    AddLastActionToObservation, RoleRelativeObservation
 
 # Remove attack buttons combinations
 class NoAttackButtonsCombinations(gym.Wrapper):
@@ -29,7 +30,7 @@ def reset(self, **kwargs):
     def step(self, action):
         return self.env.step(action)
 
-class NoopResetEnv(gym.Wrapper):
+class NoopReset(gym.Wrapper):
     def __init__(self, env, no_op_max=6):
         """
         Sample initial states by taking random number of no-ops on reset.
@@ -81,7 +82,7 @@ def step(self, action):
         return obs, rew, done, info
 
 
-class ClipRewardEnv(gym.RewardWrapper):
+class ClipReward(gym.RewardWrapper):
     def __init__(self, env):
         """
         clips the reward to {+1, 0, -1} by its sign.
@@ -96,7 +97,7 @@ def reward(self, reward):
         """
         return np.sign(reward)
 
-class NormalizeRewardEnv(gym.RewardWrapper):
+class NormalizeReward(gym.RewardWrapper):
     def __init__(self, env, reward_normalization_factor):
         """
         Normalize the reward dividing it by the product of
@@ -127,17 +128,17 @@ def env_wrapping(env, wrappers_settings: WrappersSettings):
 
     ### Generic wrappers(s)
     if wrappers_settings.no_op_max > 0:
-        env = NoopResetEnv(env, no_op_max=wrappers_settings.no_op_max)
+        env = NoopReset(env, no_op_max=wrappers_settings.no_op_max)
 
     if wrappers_settings.sticky_actions > 1:
         env = StickyActions(env, sticky_actions=wrappers_settings.sticky_actions)
 
     ### Reward wrappers(s)
     if wrappers_settings.reward_normalization is True:
-        env = NormalizeRewardEnv(env, wrappers_settings.reward_normalization_factor)
+        env = NormalizeReward(env, wrappers_settings.reward_normalization_factor)
 
     if wrappers_settings.clip_rewards is True:
-        env = ClipRewardEnv(env)
+        env = ClipReward(env)
 
     ### Action space wrapper(s)
     if wrappers_settings.no_attack_buttons_combinations is True:
@@ -180,7 +181,11 @@ def env_wrapping(env, wrappers_settings: WrappersSettings):
 
     # Scales observations normalizing them between 0.0 and 1.0
     if wrappers_settings.scale is True:
-        env = ScaledFloatObs(env, wrappers_settings.exclude_image_scaling, wrappers_settings.process_discrete_binary)
+        env = NormalizeObservation(env, wrappers_settings.exclude_image_scaling, wrappers_settings.process_discrete_binary)
+
+    # Convert base observation to role-relative observation
+    if wrappers_settings.role_relative_observation is True:
+        env = RoleRelativeObservation(env)
 
     if wrappers_settings.flatten is True:
         env = FlattenFilterDictObs(env, wrappers_settings.filter_keys)
diff --git a/diambra/arena/wrappers/observation.py b/diambra/arena/wrappers/observation.py
index 361f54db..f2719b89 100644
--- a/diambra/arena/wrappers/observation.py
+++ b/diambra/arena/wrappers/observation.py
@@ -107,21 +107,22 @@ def __init__(self, env):
         gym.Wrapper.__init__(self, env)
         if self.unwrapped.env_settings.n_players == 1:
             self.observation_space["action"] = self.action_space
+            def _add_last_action_to_obs_1p(obs, last_action):
+                obs["action"] = last_action
+                return obs
+            self._add_last_action_to_obs = _add_last_action_to_obs_1p
         else:
             for idx in range(self.unwrapped.env_settings.n_players):
                 action_dictionary = {}
                 action_dictionary["action"] = self.action_space["agent_{}".format(idx)]
                 self.observation_space["agent_{}".format(idx)] = gym.spaces.Dict(action_dictionary)
-
-    def _add_last_action_to_obs(self, obs, last_action):
-        if self.unwrapped.env_settings.n_players == 1:
-            obs["action"] = last_action
-        else:
-            for idx in range(self.unwrapped.env_settings.n_players):
-                action_dictionary = {}
-                action_dictionary["action"] = last_action["agent_{}".format(idx)]
-                obs["agent_{}".format(idx)] = action_dictionary
-        return obs
+            def _add_last_action_to_obs_2p(obs, last_action):
+                for idx in range(self.unwrapped.env_settings.n_players):
+                    action_dictionary = {}
+                    action_dictionary["action"] = last_action["agent_{}".format(idx)]
+                    obs["agent_{}".format(idx)] = action_dictionary
+                return obs
+            self._add_last_action_to_obs = _add_last_action_to_obs_2p
 
     def reset(self, **kwargs):
         obs, info = self.env.reset(**kwargs)
@@ -154,6 +155,20 @@ def __init__(self, env, n_actions_stack):
             self.action_stack = [deque([no_op_action] * n_actions_stack, maxlen=n_actions_stack)]
             action_space_size = [self.observation_space["action"].n]
             self.observation_space["action"] = gym.spaces.MultiDiscrete(action_space_size * n_actions_stack)
+
+            if isinstance(self.action_space, gym.spaces.MultiDiscrete):
+                def _add_action_to_stack_1p(action):
+                    self.action_stack[0].append(action[0])
+                    self.action_stack[0].append(action[1])
+            else:
+                def _add_action_to_stack_1p(action):
+                    self.action_stack[0].append(action)
+            self._add_action_to_stack = _add_action_to_stack_1p
+
+            def _process_obs_1p(obs):
+                obs["action"] = np.array(self.action_stack[0])
+                return obs
+            self._process_obs = _process_obs_1p
         else:
             self.action_stack = []
             assert "action" in self.observation_space["agent_0"].keys(), "ActionsStack wrapper can be activated only "\
@@ -167,64 +182,43 @@ def __init__(self, env, n_actions_stack):
                 action_space_size = [self.observation_space["agent_{}".format(idx)]["action"].n]
                 self.observation_space["agent_{}".format(idx)]["action"] = gym.spaces.MultiDiscrete(action_space_size * n_actions_stack)
 
-    def fill_stack(self):
-        # Fill the actions stack with no action after reset
-        no_op_action = self.unwrapped.get_no_op_action()
-        if self.unwrapped.env_settings.n_players == 1:
-            if isinstance(self.action_space, gym.spaces.MultiDiscrete):
-                for _ in range(self.n_actions_stack):
-                    self.action_stack[0].append(no_op_action[0])
-                    self.action_stack[0].append(no_op_action[1])
-            else:
-                for _ in range(self.n_actions_stack):
-                    self.action_stack[0].append(no_op_action)
-        else:
-            for idx in range(self.unwrapped.env_settings.n_players):
-                if isinstance(self.action_space["agent_{}".format(idx)], gym.spaces.MultiDiscrete):
-                    for _ in range(self.n_actions_stack):
-                        self.action_stack[idx].append(no_op_action["agent_{}".format(idx)][0])
-                        self.action_stack[idx].append(no_op_action["agent_{}".format(idx)][1])
-                else:
-                    for _ in range(self.n_actions_stack):
-                        self.action_stack[idx].append(no_op_action["agent_{}".format(idx)])
-
-    def _process_obs(self, obs):
-        if self.unwrapped.env_settings.n_players == 1:
-            obs["action"] = np.array(self.action_stack[0])
-        else:
-            for idx in range(self.unwrapped.env_settings.n_players):
-                obs["agent_{}".format(idx)]["action"] = np.array(self.action_stack[idx])
-        return obs
+            def _add_action_to_stack_2p(action):
+                for idx in range(self.unwrapped.env_settings.n_players):
+                    if isinstance(self.action_space["agent_{}".format(idx)], gym.spaces.MultiDiscrete):
+                        self.action_stack[idx].append(action["agent_{}".format(idx)][0])
+                        self.action_stack[idx].append(action["agent_{}".format(idx)][1])
+                    else:
+                        self.action_stack[idx].append(action["agent_{}".format(idx)])
+            self._add_action_to_stack = _add_action_to_stack_2p
+
+            def _process_obs_2p(obs):
+                for idx in range(self.unwrapped.env_settings.n_players):
+                    obs["agent_{}".format(idx)]["action"] = np.array(self.action_stack[idx])
+                return obs
+            self._process_obs = _process_obs_2p
 
     def reset(self, **kwargs):
         obs, info = self.env.reset(**kwargs)
-        self.fill_stack()
+        self._fill_stack()
         return self._process_obs(obs), info
 
     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
-        if self.unwrapped.env_settings.n_players == 1:
-            if isinstance(self.action_space, gym.spaces.MultiDiscrete):
-                self.action_stack[0].append(action[0])
-                self.action_stack[0].append(action[1])
-            else:
-                self.action_stack[0].append(action)
-        else:
-            for idx in range(self.unwrapped.env_settings.n_players):
-                if isinstance(self.action_space["agent_{}".format(idx)], gym.spaces.MultiDiscrete):
-                    self.action_stack[idx].append(action["agent_{}".format(idx)][0])
-                    self.action_stack[idx].append(action["agent_{}".format(idx)][1])
-                else:
-                    self.action_stack[idx].append(action["agent_{}".format(idx)])
+        self._add_action_to_stack(action)
         # Add noAction for n_actions_stack - 1 times
         # in case of new round / stage / continueGame
         if ((info["round_done"] or info["stage_done"] or info["game_done"]) and not (terminated or truncated)):
-            self.fill_stack()
+            self._fill_stack()
         return self._process_obs(obs), reward, terminated, truncated, info
 
-class ScaledFloatObs(gym.ObservationWrapper):
+    def _fill_stack(self):
+        no_op_action = self.unwrapped.get_no_op_action()
+        for _ in range(self.n_actions_stack):
+            self._add_action_to_stack(no_op_action)
+
+class NormalizeObservation(gym.ObservationWrapper):
     def __init__(self, env, exclude_image_scaling=False, process_discrete_binary=False):
         gym.ObservationWrapper.__init__(self, env)
@@ -232,47 +226,45 @@ def __init__(self, env, exclude_image_scaling=False, process_discrete_binary=Fal
         self.process_discrete_binary = process_discrete_binary
         self.original_observation_space = deepcopy(self.observation_space)
 
-        self.scaled_float_obs_space_func(self.observation_space)
+        self._obs_space_normalization_func(self.observation_space)
+
+    def observation(self, observation):
+        return self._obs_normalization_func(observation, self.original_observation_space)
 
     # Recursive function to modify obs space dict
     # FIXME: this can probably be dropped with gym >= 0.21 and only use the next one, here for SB2 compatibility
-    def scaled_float_obs_space_func(self, obs_dict):
+    def _obs_space_normalization_func(self, obs_dict):
         # Updating observation space dict
-        for k, v in obs_dict.spaces.items():
-
+        for k, v in obs_dict.items():
             if isinstance(v, gym.spaces.Dict):
-                self.scaled_float_obs_space_func(v)
+                self._obs_space_normalization_func(v)
             else:
                 if isinstance(v, gym.spaces.MultiDiscrete):
                     # One hot encoding x nStack
-                    n_val = v.nvec.shape[0]
-                    max_val = v.nvec[0]
-                    obs_dict.spaces[k] = gym.spaces.MultiBinary(n_val * max_val)
+                    obs_dict[k] = gym.spaces.MultiBinary(np.sum(v.nvec))
                 elif isinstance(v, gym.spaces.Discrete) and (v.n > 2 or self.process_discrete_binary is True):
                     # One hot encoding
-                    obs_dict.spaces[k] = gym.spaces.MultiBinary(v.n)
+                    obs_dict[k] = gym.spaces.MultiBinary(v.n)
                 elif isinstance(v, gym.spaces.Box) and (self.exclude_image_scaling is False or len(v.shape) < 3):
-                    obs_dict.spaces[k] = gym.spaces.Box(low=0.0, high=1.0, shape=v.shape, dtype=np.float32)
+                    obs_dict[k] = gym.spaces.Box(low=0.0, high=1.0, shape=v.shape, dtype=np.float32)
 
     # Recursive function to modify obs dict
-    def scaled_float_obs_func(self, observation, observation_space):
-
+    def _obs_normalization_func(self, observation, observation_space):
         # Process all observations
         for k, v in observation.items():
-
             if isinstance(v, dict):
-                self.scaled_float_obs_func(v, observation_space.spaces[k])
+                self._obs_normalization_func(v, observation_space.spaces[k])
             else:
-                v_space = observation_space.spaces[k]
+                v_space = observation_space[k]
                 if isinstance(v_space, gym.spaces.MultiDiscrete):
-                    n_act = observation_space.spaces[k].nvec[0]
-                    buf_len = observation_space.spaces[k].nvec.shape[0]
-                    actions_vector = np.zeros((buf_len * n_act), dtype=np.uint8)
-                    for iact in range(buf_len):
-                        actions_vector[iact * n_act + observation[k][iact]] = 1
+                    actions_vector = np.zeros((np.sum(v_space.nvec)), dtype=np.uint8)
+                    column_index = 0
+                    for iact in range(v_space.nvec.shape[0]):
+                        actions_vector[column_index + observation[k][iact]] = 1
+                        column_index += v_space.nvec[iact]
                     observation[k] = actions_vector
                 elif isinstance(v_space, gym.spaces.Discrete) and (v_space.n > 2 or self.process_discrete_binary is True):
-                    var_vector = np.zeros((observation_space.spaces[k].n), dtype=np.uint8)
+                    var_vector = np.zeros((v_space.n), dtype=np.uint8)
                     var_vector[observation[k]] = 1
                     observation[k] = var_vector
                 elif isinstance(v_space, gym.spaces.Box) and (self.exclude_image_scaling is False or len(v_space.shape) < 3):
@@ -282,9 +274,54 @@ def scaled_float_obs_func(self, observation, observation_space):
         return observation
 
+class RoleRelativeObservation(gym.ObservationWrapper):
+    def __init__(self, env):
+        gym.ObservationWrapper.__init__(self, env)
+
+        self.observation_space["own"] = self.observation_space["P1"]
+        self.observation_space["opp"] = self.observation_space["P1"]
+        del self.observation_space["P1"]
+        del self.observation_space["P2"]
+
     def observation(self, observation):
+        return None
+
+    """
+    def rename_key_recursive(dictionary, old_key, new_key):
+        if isinstance(dictionary, dict):
+            new_dict = {}
+            for key, value in dictionary.items():
+                if key == old_key:
+                    key = new_key
+                new_dict[key] = rename_key_recursive(value, old_key, new_key)
+            return new_dict
+        else:
+            return dictionary
+    """
+
+class FlattenFilterDictObs(gym.ObservationWrapper):
+    def __init__(self, env, filter_keys):
+        gym.ObservationWrapper.__init__(self, env)
-        return self.scaled_float_obs_func(observation, self.original_observation_space)
 
+        self.filter_keys = filter_keys
+        if (filter_keys is not None):
+            self.filter_keys = list(set(filter_keys))
+            if "frame" not in filter_keys:
+                self.filter_keys += ["frame"]
+
+        original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, None)).keys()
+        self.observation_space = gym.spaces.Dict(flatten_filter_obs_space_func(self.observation_space, self.filter_keys))
+
+        if filter_keys is not None:
+            if (sorted(self.observation_space.spaces.keys()) != sorted(self.filter_keys)):
+                raise Exception("Specified observation key(s) not found:",
+                                " Available key(s):", sorted(original_obs_space_keys),
+                                " Specified key(s):", sorted(self.filter_keys),
+                                " Key(s) not found:", sorted([key for key in self.filter_keys if key not in original_obs_space_keys]),
+                                )
+
+    def observation(self, observation):
+        return flatten_filter_obs_func(observation, self.filter_keys)
 
 def flatten_filter_obs_space_func(input_dictionary, filter_keys):
     _FLAG_FIRST = object()
@@ -336,29 +373,4 @@ def visit(subdict, flattened_dict, partial_key, check_method):
     else:
         visit(input_dictionary, flattened_dict, _FLAG_FIRST, dummy_check)
 
-    return flattened_dict
-
-class FlattenFilterDictObs(gym.ObservationWrapper):
-    def __init__(self, env, filter_keys):
-        gym.ObservationWrapper.__init__(self, env)
-
-        self.filter_keys = filter_keys
-        if (filter_keys is not None):
-            self.filter_keys = list(set(filter_keys))
-            if "frame" not in filter_keys:
-                self.filter_keys += ["frame"]
-
-        original_obs_space_keys = (flatten_filter_obs_space_func(self.observation_space, None)).keys()
-        self.observation_space = gym.spaces.Dict(flatten_filter_obs_space_func(self.observation_space, self.filter_keys))
-
-        if filter_keys is not None:
-            if (sorted(self.observation_space.spaces.keys()) != sorted(self.filter_keys)):
-                raise Exception("Specified observation key(s) not found:",
-                                " Available key(s):", sorted(original_obs_space_keys),
-                                " Specified key(s):", sorted(self.filter_keys),
-                                " Key(s) not found:", sorted([key for key in self.filter_keys if key not in original_obs_space_keys]),
-                                )
-
-    def observation(self, observation):
-
-        return flatten_filter_obs_func(observation, self.filter_keys)
+    return flattened_dict
\ No newline at end of file
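# Note: the new RoleRelativeObservation wrapper above is still a stub in this patch
# (observation() returns None, and the key-renaming helper only survives inside a
# commented-out block). The following is a hypothetical sketch, not part of the
# patch or of the library, of how the "P1"/"P2" -> "own"/"opp" remapping described
# in examples/wrappers_options.py could look for the single-agent case; the class
# name and the own_role argument are illustrative assumptions (a real version
# would read the agent's role from the environment settings), and the grouping
# under "agent_0"/"agent_1" for 2P environments is omitted.

import gymnasium as gym  # the wrappers in this repository follow the Gymnasium API


class RoleRelativeObservationSketch(gym.ObservationWrapper):
    def __init__(self, env, own_role="P1"):
        gym.ObservationWrapper.__init__(self, env)
        # Placeholder for however the wrapper ends up learning the agent's role
        self.own_role = own_role
        self.opp_role = "P2" if own_role == "P1" else "P1"

        # Rebuild the observation space with role-relative keys
        new_spaces = dict(env.observation_space.spaces)
        new_spaces["own"] = new_spaces.pop(self.own_role)
        new_spaces["opp"] = new_spaces.pop(self.opp_role)
        self.observation_space = gym.spaces.Dict(new_spaces)

    def observation(self, observation):
        # Move the player-specific nesting levels under role-relative keys
        observation["own"] = observation.pop(self.own_role)
        observation["opp"] = observation.pop(self.opp_role)
        return observation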
diff --git a/examples/wrappers_options.py b/examples/wrappers_options.py
index 62aa0f0d..e884ea44 100644
--- a/examples/wrappers_options.py
+++ b/examples/wrappers_options.py
@@ -3,8 +3,8 @@
 
 def main():
     # Environment settings
-    settings = {"n_players": 1, "action_space": SpaceTypes.MULTI_DISCRETE}
-    #settings = {"n_players": 2, "action_space": (SpaceTypes.MULTI_DISCRETE, SpaceTypes.DISCRETE)}
+    #settings = {"n_players": 1, "action_space": SpaceTypes.MULTI_DISCRETE}
+    settings = {"n_players": 2, "action_space": (SpaceTypes.MULTI_DISCRETE, SpaceTypes.DISCRETE)}
 
     # Gym wrappers settings
     wrappers_settings = {}
@@ -57,13 +57,23 @@ def main():
     # NOTE: needs "add_last_action_to_observation" wrapper to be active
     wrappers_settings["actions_stack"] = 6
 
-    """
-    # If to scale observation numerical values (deactivated by default)
-    # optionally exclude images from normalization (deactivated by default)
-    # and optionally perform one-hot encoding also on discrete binary variables (deactivated by default)
+    # If to scale observation numerical values (False by default)
+    # optionally exclude images from normalization (False by default)
+    # and optionally perform one-hot encoding also on discrete binary variables (False by default)
     wrappers_settings["scale"] = True
-    wrappers_settings["exclude_image_scaling"] = True
-    wrappers_settings["process_discrete_binary"] = True
+    #wrappers_settings["exclude_image_scaling"] = True
+    #wrappers_settings["process_discrete_binary"] = True
+
+    """
+    # If to make the observation relative to the agent as a function of its role (P1 or P2) (deactivated by default)
+    # i.e.:
+    # - In 1P environments, if the agent is P1 then the observation "P1" nesting level becomes "own" and "P2" becomes "opp"
+    # - In 2P environments, if "agent_0" role is P1 and "agent_1" role is P2, then the player-specific nesting levels of the observation (P1 - P2)
+    #   are grouped under "agent_0" and "agent_1", and:
+    #   - Under "agent_0", "P1" nesting level becomes "own" and "P2" becomes "opp"
+    #   - Under "agent_1", "P1" nesting level becomes "opp" and "P2" becomes "own"
+    wrappers_settings["role_relative_observation"] = True
+
     # Flattening observation dictionary and filtering
     # a sub-set of the RAM states
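# Hypothetical usage sketch (not part of the patch): enabling the renamed and
# newly added wrapper settings from this diff. It assumes the dictionary-based
# settings and the diambra.arena.make() call used in examples/wrappers_options.py,
# and uses "doapp" purely as an example game id.

import diambra.arena
from diambra.arena import SpaceTypes

settings = {"n_players": 1, "action_space": SpaceTypes.MULTI_DISCRETE}

wrappers_settings = {
    # frame_shape keeps its previous meaning; only its position in WrappersSettings changed
    "frame_shape": (128, 128, 1),
    # "scale" is now handled by NormalizeObservation (formerly ScaledFloatObs)
    "scale": True,
    # New flag introduced by this patch; kept False here because
    # RoleRelativeObservation.observation() is still a stub in this commit
    "role_relative_observation": False,
}

env = diambra.arena.make("doapp", settings, wrappers_settings)
observation, info = env.reset()
env.close()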