Skip to content

Commit

Permalink
Update lib for ray examples and dependencies versions
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpalms committed Sep 23, 2023
1 parent 21d0c31 commit 2585777
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
6 changes: 3 additions & 3 deletions diambra/arena/ray_rllib/make_ray_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import diambra.arena
from diambra.arena import EnvironmentSettings, WrappersSettings
import logging
import gymnasium as gym
from ray.rllib.env.env_context import EnvContext
Expand Down Expand Up @@ -35,8 +36,8 @@ def __init__(self, config: EnvContext):
raise Exception(message)

self.game_id = config["game_id"]
self.settings = config["settings"] if "settings" in config.keys() else {}
self.wrappers_settings = config["wrappers_settings"] if "wrappers_settings" in config.keys() else {}
self.settings = config["settings"] if "settings" in config.keys() else EnvironmentSettings()
self.wrappers_settings = config["wrappers_settings"] if "wrappers_settings" in config.keys() else WrappersSettings()
self.render_mode = config["render_mode"] if "render_mode" in config.keys() else None

num_rollout_workers = config["num_workers"]
Expand Down Expand Up @@ -77,7 +78,6 @@ def __init__(self, config: EnvContext):
env_spaces_file = open(self.env_spaces_file_name, "wb")
pickle.dump(env_spaces_dict, env_spaces_file)
env_spaces_file.close()

else:
print("Loading environment spaces from: {}".format(self.env_spaces_file_name))
self.logger.info("Loading environment spaces from: {}".format(self.env_spaces_file_name))
Expand Down
10 changes: 8 additions & 2 deletions diambra/arena/wrappers/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def __init__(self, env):
"""
gym.Wrapper.__init__(self, env)
if self.unwrapped.env_settings.n_players == 1:
self.observation_space["action"] = self.action_space
self.observation_space = gym.spaces.Dict({
**self.observation_space.spaces,
"action": self.action_space,
})
def _add_last_action_to_obs_1p(obs, last_action):
obs["action"] = last_action
return obs
Expand All @@ -115,7 +118,10 @@ def _add_last_action_to_obs_1p(obs, last_action):
for idx in range(self.unwrapped.env_settings.n_players):
action_dictionary = {}
action_dictionary["action"] = self.action_space["agent_{}".format(idx)]
self.observation_space["agent_{}".format(idx)] = gym.spaces.Dict(action_dictionary)
self.observation_space = gym.spaces.Dict({
**self.observation_space.spaces,
"agent_{}".format(idx): gym.spaces.Dict(action_dictionary),
})
def _add_last_action_to_obs_2p(obs, last_action):
for idx in range(self.unwrapped.env_settings.n_players):
action_dictionary = {}
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
extras= {
'core': [],
'tests': ['pytest', 'pytest-mock', 'testresources'],
'stable-baselines': ['stable-baselines==2.10.2', 'gym<=0.21.0', "protobuf==3.20.1", "pyyaml"],
'stable-baselines3': ['stable-baselines3[extra]==2.1.*', "pyyaml"],
'ray-rllib': ['ray[rllib]==2.6.*', 'tensorflow', 'torch', "pyyaml"],
'stable-baselines': ['stable-baselines~=2.10.2', 'gym<=0.21.0', "protobuf==3.20.1", "pyyaml"],
'stable-baselines3': ['stable-baselines3[extra]~=2.1.0', "pyyaml"],
'ray-rllib': ['ray[rllib]~=2.7.0', 'tensorflow', 'torch', "pyyaml"],
}

# NOTE Package data is inside MANIFEST.In
Expand All @@ -40,7 +40,7 @@
'tk',
'opencv-python>=4.4.0.42',
'grpcio',
'diambra-engine>=2.2.0',
'diambra-engine~=2.2.0',
'dacite'],
packages=[package for package in setuptools.find_packages() if package.startswith("diambra")],
include_package_data=True,
Expand Down

0 comments on commit 2585777

Please sign in to comment.