From 21d0c316c63ebc11a29e39b3ec5d3878058cb0d6 Mon Sep 17 00:00:00 2001
From: Alessandro Palmas
Date: Sat, 23 Sep 2023 12:55:21 -0400
Subject: [PATCH] Update SB interface with new settings

---
 diambra/arena/stable_baselines/make_sb_env.py | 24 ++++++++---------
 diambra/arena/wrappers/arena_wrappers.py      | 26 ++++++++++++++-----
 diambra/arena/wrappers/observation.py         |  3 +--
 3 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/diambra/arena/stable_baselines/make_sb_env.py b/diambra/arena/stable_baselines/make_sb_env.py
index 06b5e5d1..9b26b847 100644
--- a/diambra/arena/stable_baselines/make_sb_env.py
+++ b/diambra/arena/stable_baselines/make_sb_env.py
@@ -1,6 +1,7 @@
 import os
 import time
 import diambra.arena
+from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings, RecordingSettings
 import gym
 from stable_baselines import logger
@@ -9,9 +10,11 @@
 from stable_baselines.common import set_global_seeds
 
 # Make Stable Baselines Env function
-def make_sb_env(game_id: str, env_settings: dict={}, wrappers_settings: dict={},
-                episode_recording_settings: dict={}, render_mode: str="rgb_array", seed: int=None,
-                start_index: int=0, allow_early_resets: bool=True, start_method: str=None,
+def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSettings(),
+                wrappers_settings: WrappersSettings=WrappersSettings(),
+                episode_recording_settings: RecordingSettings=RecordingSettings(),
+                render_mode: str="rgb_array", seed: int=None, start_index: int=0,
+                allow_early_resets: bool=True, start_method: str=None,
                 no_vec: bool=False, use_subprocess: bool=False):
     """
     Create a wrapped, monitored VecEnv.
@@ -36,15 +39,11 @@ def make_sb_env(game_id: str, env_settings: dict={}, wrappers_settings: dict={},
     # Seed management
     if seed is None:
         seed = int(time.time())
-    env_settings["seed"] = seed
+    env_settings.seed = seed
 
     # Add the conversion from gymnasium to gym
     old_gym_wrapper = [OldGymWrapper, {}]
-    if 'wrappers' in wrappers_settings:
-        wrappers_settings['wrappers'].insert(0, old_gym_wrapper)
-    else:
-        # If it's not present, add the key with a new list containing your custom element
-        wrappers_settings['wrappers'] = [old_gym_wrapper]
+    wrappers_settings.wrappers.insert(0, old_gym_wrapper)
 
     def _make_sb_env(rank):
         def _init():
@@ -78,11 +77,12 @@ def __init__(self, env):
         :param env: (Gym<=0.21 Environment) the resulting environment
         """
         gym.Wrapper.__init__(self, env)
-        if self.env_settings.action_space == diambra.arena.SpaceType.MULTI_DISCRETE:
+        if self.env_settings.action_space == SpaceTypes.MULTI_DISCRETE:
             self.action_space = gym.spaces.MultiDiscrete(self.n_actions)
-        elif self.env_settings.action_space == diambra.arena.SpaceType.DISCRETE:
+        elif self.env_settings.action_space == SpaceTypes.DISCRETE:
             self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
-        self.logger.debug("Using {} action space".format(diambra.arena.SpaceType.Name(self.env_settings.action_space)))
+        self.logger.debug("Using {} action space".format(SpaceTypes.Name(self.env_settings.action_space)))
+
     def reset(self, **kwargs):
         obs, _ = self.env.reset(**kwargs)
diff --git a/diambra/arena/wrappers/arena_wrappers.py b/diambra/arena/wrappers/arena_wrappers.py
index 7e3f4eb7..50ae7d31 100644
--- a/diambra/arena/wrappers/arena_wrappers.py
+++ b/diambra/arena/wrappers/arena_wrappers.py
@@ -2,7 +2,7 @@
 import numpy as np
 import gymnasium as gym
 import logging
-from diambra.arena.env_settings import WrappersSettings
+from diambra.arena import SpaceTypes, WrappersSettings
 from diambra.arena.wrappers.observation import WarpFrame, GrayscaleFrame, FrameStack, ActionsStack, \
     NormalizeObservation, FlattenFilterDictObs, \
     AddLastActionToObservation, RoleRelativeObservation
@@ -17,12 +17,24 @@ def __init__(self, env):
         gym.Wrapper.__init__(self, env)
         # N actions
         self.n_actions = [self.unwrapped.env_info.available_actions.n_moves, self.unwrapped.env_info.available_actions.n_attacks_no_comb]
-        if self.unwrapped.env_settings.action_space == "multi_discrete":
-            self.action_space = gym.spaces.MultiDiscrete(self.n_actions)
-            self.unwrapped.logger.debug("Using MultiDiscrete action space without attack buttons combinations")
-        elif self.unwrapped.env_settings.action_space == "discrete":
-            self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
-            self.unwrapped.logger.debug("Using Discrete action space without attack buttons combinations")
+        if self.unwrapped.env_settings.n_players == 1:
+            if self.unwrapped.env_settings.action_space == SpaceTypes.MULTI_DISCRETE:
+                self.action_space = gym.spaces.MultiDiscrete(self.n_actions)
+            elif self.unwrapped.env_settings.action_space == SpaceTypes.DISCRETE:
+                self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
+            else:
+                raise Exception("Action space not recognized in \"NoAttackButtonsCombinations\" wrapper")
+            self.unwrapped.logger.debug("Using {} action space without attack buttons combinations".format(SpaceTypes.Name(self.unwrapped.env_settings.action_space)))
+        else:
+            self.unwrapped.logger.warning("Warning: \"NoAttackButtonsCombinations\" is by default applied on all agents actions space")
+            for idx in range(self.unwrapped.env_settings.n_players):
+                if self.unwrapped.env_settings.action_space[idx] == SpaceTypes.MULTI_DISCRETE:
+                    self.action_space["agent_{}".format(idx)] = gym.spaces.MultiDiscrete(self.n_actions)
+                elif self.unwrapped.env_settings.action_space[idx] == SpaceTypes.DISCRETE:
+                    self.action_space["agent_{}".format(idx)] = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
+                else:
+                    raise Exception("Action space not recognized in \"NoAttackButtonsCombinations\" wrapper")
+                self.unwrapped.logger.debug("Using {} action space for agent_{} without attack buttons combinations".format(SpaceTypes.Name(self.unwrapped.env_settings.action_space[idx]), idx))
 
     def reset(self, **kwargs):
         return self.env.reset(**kwargs)
diff --git a/diambra/arena/wrappers/observation.py b/diambra/arena/wrappers/observation.py
index c3e85d24..db724142 100644
--- a/diambra/arena/wrappers/observation.py
+++ b/diambra/arena/wrappers/observation.py
@@ -132,7 +132,6 @@ def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
         return self._add_last_action_to_obs(obs, action), reward, terminated, truncated, info
 
-
 class ActionsStack(gym.Wrapper):
     def __init__(self, env, n_actions_stack):
         """Stack n_actions_stack last actions.
@@ -145,7 +144,7 @@ def __init__(self, env, n_actions_stack):
         no_op_action = self.unwrapped.get_no_op_action()
 
         if self.unwrapped.env_settings.n_players == 1:
-            assert "action" in self.observation_space.keys(), "ActionsStack wrapper can be activated only "\
+            assert "action" in self.observation_space.spaces, "ActionsStack wrapper can be activated only "\
                 "when \"action\" info is in the observation space"
             if isinstance(self.action_space, gym.spaces.MultiDiscrete):
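
For context, a minimal usage sketch of the updated interface (not part of the patch): settings objects now replace the plain dicts previously accepted by make_sb_env. The "doapp" game id and the (env, num_envs) return value are assumptions for illustration, not something this change introduces.

# Usage sketch, illustrative only and not part of the patch.
# Assumptions: the DIAMBRA engine is reachable, "doapp" is an installed game id,
# and make_sb_env returns the vectorized env plus the number of parallel envs.
from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings
from diambra.arena.stable_baselines.make_sb_env import make_sb_env

# Settings objects replace the plain dicts used by the previous interface
settings = EnvironmentSettings()
settings.action_space = SpaceTypes.DISCRETE

wrappers_settings = WrappersSettings()  # defaults; OldGymWrapper is prepended internally

env, num_envs = make_sb_env("doapp", settings, wrappers_settings)
print("Created {} environment(s)".format(num_envs))
env.close()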