Skip to content

Commit

Permalink
Update SB interface with new settings
Browse files (browse the repository at this point in the history)
  • Loading branch information
alexpalms committed Sep 23, 2023
1 parent 292e5b0 commit 21d0c31
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 21 deletions.
24 changes: 12 additions & 12 deletions diambra/arena/stable_baselines/make_sb_env.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
import os
import time
import diambra.arena
from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings, RecordingSettings
import gym

from stable_baselines import logger
Expand All @@ -9,9 +10,11 @@
from stable_baselines.common import set_global_seeds

# Make Stable Baselines Env function
def make_sb_env(game_id: str, env_settings: dict={}, wrappers_settings: dict={},
episode_recording_settings: dict={}, render_mode: str="rgb_array", seed: int=None,
start_index: int=0, allow_early_resets: bool=True, start_method: str=None,
def make_sb_env(game_id: str, env_settings: EnvironmentSettings=EnvironmentSettings(),
wrappers_settings: WrappersSettings=WrappersSettings(),
episode_recording_settings: RecordingSettings=RecordingSettings(),
render_mode: str="rgb_array", seed: int=None, start_index: int=0,
allow_early_resets: bool=True, start_method: str=None,
no_vec: bool=False, use_subprocess: bool=False):
"""
Create a wrapped, monitored VecEnv.
Expand All @@ -36,15 +39,11 @@ def make_sb_env(game_id: str, env_settings: dict={}, wrappers_settings: dict={},
# Seed management
if seed is None:
seed = int(time.time())
env_settings["seed"] = seed
env_settings.seed = seed

# Add the conversion from gymnasium to gym
old_gym_wrapper = [OldGymWrapper, {}]
if 'wrappers' in wrappers_settings:
wrappers_settings['wrappers'].insert(0, old_gym_wrapper)
else:
# If it's not present, add the key with a new list containing your custom element
wrappers_settings['wrappers'] = [old_gym_wrapper]
wrappers_settings.wrappers.insert(0, old_gym_wrapper)

def _make_sb_env(rank):
def _init():
Expand Down Expand Up @@ -78,11 +77,12 @@ def __init__(self, env):
:param env: (Gym<=0.21 Environment) the resulting environment
"""
gym.Wrapper.__init__(self, env)
if self.env_settings.action_space == diambra.arena.SpaceType.MULTI_DISCRETE:
if self.env_settings.action_space == SpaceTypes.MULTI_DISCRETE:
self.action_space = gym.spaces.MultiDiscrete(self.n_actions)
elif self.env_settings.action_space == diambra.arena.SpaceType.DISCRETE:
elif self.env_settings.action_space == SpaceTypes.DISCRETE:
self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
self.logger.debug("Using {} action space".format(diambra.arena.SpaceType.Name(self.env_settings.action_space)))
self.logger.debug("Using {} action space".format(SpaceTypes.Name(self.env_settings.action_space)))


def reset(self, **kwargs):
obs, _ = self.env.reset(**kwargs)
Expand Down
26 changes: 19 additions & 7 deletions diambra/arena/wrappers/arena_wrappers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import gymnasium as gym
import logging
from diambra.arena.env_settings import WrappersSettings
from diambra.arena import SpaceTypes, WrappersSettings
from diambra.arena.wrappers.observation import WarpFrame, GrayscaleFrame, FrameStack, ActionsStack, \
NormalizeObservation, FlattenFilterDictObs, \
AddLastActionToObservation, RoleRelativeObservation
Expand All @@ -17,12 +17,24 @@ def __init__(self, env):
gym.Wrapper.__init__(self, env)
# N actions
self.n_actions = [self.unwrapped.env_info.available_actions.n_moves, self.unwrapped.env_info.available_actions.n_attacks_no_comb]
if self.unwrapped.env_settings.action_space == "multi_discrete":
self.action_space = gym.spaces.MultiDiscrete(self.n_actions)
self.unwrapped.logger.debug("Using MultiDiscrete action space without attack buttons combinations")
elif self.unwrapped.env_settings.action_space == "discrete":
self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
self.unwrapped.logger.debug("Using Discrete action space without attack buttons combinations")
if self.unwrapped.env_settings.n_players == 1:
if self.unwrapped.env_settings.action_space == SpaceTypes.MULTI_DISCRETE:
self.action_space = gym.spaces.MultiDiscrete(self.n_actions)
elif self.unwrapped.env_settings.action_space == SpaceTypes.DISCRETE:
self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
else:
raise Exception("Action space not recognized in \"NoAttackButtonsCombinations\" wrapper")
self.unwrapped.logger.debug("Using {} action space without attack buttons combinations".format(SpaceTypes.Name(self.unwrapped.env_settings.action_space)))
else:
self.unwrapped.logger.warning("Warning: \"NoAttackButtonsCombinations\" is by default applied on all agents actions space")
for idx in range(self.unwrapped.env_settings.n_players):
if self.unwrapped.env_settings.action_space[idx] == SpaceTypes.MULTI_DISCRETE:
self.action_space["agent_{}".format(idx)] = gym.spaces.MultiDiscrete(self.n_actions)
elif self.unwrapped.env_settings.action_space[idx] == SpaceTypes.DISCRETE:
self.action_space["agent_{}".format(idx)] = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
else:
raise Exception("Action space not recognized in \"NoAttackButtonsCombinations\" wrapper")
self.unwrapped.logger.debug("Using {} action space for agent_{} without attack buttons combinations".format(SpaceTypes.Name(self.unwrapped.env_settings.action_space[idx]), idx))

def reset(self, **kwargs):
return self.env.reset(**kwargs)
Expand Down
3 changes: 1 addition & 2 deletions diambra/arena/wrappers/observation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -132,7 +132,6 @@ def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._add_last_action_to_obs(obs, action), reward, terminated, truncated, info


class ActionsStack(gym.Wrapper):
def __init__(self, env, n_actions_stack):
"""Stack n_actions_stack last actions.
Expand All @@ -145,7 +144,7 @@ def __init__(self, env, n_actions_stack):
no_op_action = self.unwrapped.get_no_op_action()

if self.unwrapped.env_settings.n_players == 1:
assert "action" in self.observation_space.keys(), "ActionsStack wrapper can be activated only "\
assert "action" in self.observation_space.spaces, "ActionsStack wrapper can be activated only "\
"when \"action\" info is in the observation space"

if isinstance(self.action_space, gym.spaces.MultiDiscrete):
Expand Down

0 comments on commit 21d0c31

Please sign in to comment.