From 2585777e2cd8de4bd542e43ab6f6fe1647c19fe7 Mon Sep 17 00:00:00 2001 From: Alessandro Palmas Date: Sat, 23 Sep 2023 16:30:55 -0400 Subject: [PATCH] Update lib for ray examples and dependencies versions --- diambra/arena/ray_rllib/make_ray_env.py | 6 +++--- diambra/arena/wrappers/observation.py | 10 ++++++++-- setup.py | 8 ++++---- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/diambra/arena/ray_rllib/make_ray_env.py b/diambra/arena/ray_rllib/make_ray_env.py index 54e0fc12..f77cfd3e 100644 --- a/diambra/arena/ray_rllib/make_ray_env.py +++ b/diambra/arena/ray_rllib/make_ray_env.py @@ -1,5 +1,6 @@ import os import diambra.arena +from diambra.arena import EnvironmentSettings, WrappersSettings import logging import gymnasium as gym from ray.rllib.env.env_context import EnvContext @@ -35,8 +36,8 @@ def __init__(self, config: EnvContext): raise Exception(message) self.game_id = config["game_id"] - self.settings = config["settings"] if "settings" in config.keys() else {} - self.wrappers_settings = config["wrappers_settings"] if "wrappers_settings" in config.keys() else {} + self.settings = config["settings"] if "settings" in config.keys() else EnvironmentSettings() + self.wrappers_settings = config["wrappers_settings"] if "wrappers_settings" in config.keys() else WrappersSettings() self.render_mode = config["render_mode"] if "render_mode" in config.keys() else None num_rollout_workers = config["num_workers"] @@ -77,7 +78,6 @@ def __init__(self, config: EnvContext): env_spaces_file = open(self.env_spaces_file_name, "wb") pickle.dump(env_spaces_dict, env_spaces_file) env_spaces_file.close() - else: print("Loading environment spaces from: {}".format(self.env_spaces_file_name)) self.logger.info("Loading environment spaces from: {}".format(self.env_spaces_file_name)) diff --git a/diambra/arena/wrappers/observation.py b/diambra/arena/wrappers/observation.py index db724142..68475fc9 100644 --- a/diambra/arena/wrappers/observation.py +++ b/diambra/arena/wrappers/observation.py @@ -106,7 +106,10 @@ def __init__(self, env): """ gym.Wrapper.__init__(self, env) if self.unwrapped.env_settings.n_players == 1: - self.observation_space["action"] = self.action_space + self.observation_space = gym.spaces.Dict({ + **self.observation_space.spaces, + "action": self.action_space, + }) def _add_last_action_to_obs_1p(obs, last_action): obs["action"] = last_action return obs @@ -115,7 +118,10 @@ def _add_last_action_to_obs_1p(obs, last_action): for idx in range(self.unwrapped.env_settings.n_players): action_dictionary = {} action_dictionary["action"] = self.action_space["agent_{}".format(idx)] - self.observation_space["agent_{}".format(idx)] = gym.spaces.Dict(action_dictionary) + self.observation_space = gym.spaces.Dict({ + **self.observation_space.spaces, + "agent_{}".format(idx): gym.spaces.Dict(action_dictionary), + }) def _add_last_action_to_obs_2p(obs, last_action): for idx in range(self.unwrapped.env_settings.n_players): action_dictionary = {} diff --git a/setup.py b/setup.py index 46287f9a..de633c83 100644 --- a/setup.py +++ b/setup.py @@ -12,9 +12,9 @@ extras= { 'core': [], 'tests': ['pytest', 'pytest-mock', 'testresources'], - 'stable-baselines': ['stable-baselines==2.10.2', 'gym<=0.21.0', "protobuf==3.20.1", "pyyaml"], - 'stable-baselines3': ['stable-baselines3[extra]==2.1.*', "pyyaml"], - 'ray-rllib': ['ray[rllib]==2.6.*', 'tensorflow', 'torch', "pyyaml"], + 'stable-baselines': ['stable-baselines~=2.10.2', 'gym<=0.21.0', "protobuf==3.20.1", "pyyaml"], + 'stable-baselines3': ['stable-baselines3[extra]~=2.1.0', "pyyaml"], + 'ray-rllib': ['ray[rllib]~=2.7.0', 'tensorflow', 'torch', "pyyaml"], } # NOTE Package data is inside MANIFEST.In @@ -40,7 +40,7 @@ 'tk', 'opencv-python>=4.4.0.42', 'grpcio', - 'diambra-engine>=2.2.0', + 'diambra-engine~=2.2.0', 'dacite'], packages=[package for package in setuptools.find_packages() if package.startswith("diambra")], include_package_data=True,