From 69c86ad3e9f1076fb3d4d64c882b78c776c056d5 Mon Sep 17 00:00:00 2001 From: Alessandro Palmas Date: Sat, 23 Sep 2023 16:28:57 -0400 Subject: [PATCH] Update Ray Examples --- ray_rllib/agent.py | 23 ++++++++++--------- ray_rllib/basic.py | 8 +++---- ray_rllib/dict_obs_space.py | 21 +++++++++-------- ray_rllib/parallel_envs.py | 7 +++--- .../doapp_sr6_84x5_das_c/.is_checkpoint | 0 .../doapp_sr6_84x5_das_c/.tune_metadata | 3 --- .../results/doapp_sr6_84x5_das_c/checkpoint-1 | 3 --- .../diambra_ray_env_spaces | 3 --- ray_rllib/saving_loading_evaluating.py | 9 ++++---- tests/test_ray_rllib.py | 5 ++-- 10 files changed, 40 insertions(+), 42 deletions(-) delete mode 100644 ray_rllib/results/doapp_sr6_84x5_das_c/.is_checkpoint delete mode 100644 ray_rllib/results/doapp_sr6_84x5_das_c/.tune_metadata delete mode 100644 ray_rllib/results/doapp_sr6_84x5_das_c/checkpoint-1 delete mode 100644 ray_rllib/results/doapp_sr6_84x5_das_c/diambra_ray_env_spaces diff --git a/ray_rllib/agent.py b/ray_rllib/agent.py index 69073b4..9816362 100644 --- a/ray_rllib/agent.py +++ b/ray_rllib/agent.py @@ -1,5 +1,6 @@ import argparse import diambra.arena +from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config from ray.rllib.algorithms.ppo import PPO @@ -13,17 +14,19 @@ def main(trained_model, env_spaces, test=False): # Settings - env_settings = {} - env_settings["frame_shape"] = (84, 84, 1) - env_settings["characters"] = ("Kasumi") - env_settings["action_space"] = "discrete" + env_settings = EnvironmentSettings() + env_settings.frame_shape = (84, 84, 1) + env_settings.characters = ("Kasumi") + env_settings.action_space = SpaceTypes.DISCRETE # Wrappers Settings - wrappers_settings = {} - wrappers_settings["reward_normalization"] = True - wrappers_settings["actions_stack"] = 12 - wrappers_settings["frame_stack"] = 5 - wrappers_settings["scale"] = True + wrappers_settings = WrappersSettings() + wrappers_settings.reward_normalization = True + wrappers_settings.add_last_action_to_observation = True + wrappers_settings.actions_stack = 12 + wrappers_settings.frame_stack = 5 + wrappers_settings.scale = True + wrappers_settings.role_relative_observation = True config = { # Define and configure the environment @@ -50,14 +53,12 @@ def main(trained_model, env_spaces, test=False): print("Policy architecture =\n{}".format(agent.get_policy().model)) env = diambra.arena.make("doapp", env_settings, wrappers_settings, render_mode="human") - obs, info = env.reset() while True: env.render() action = agent.compute_single_action(observation=obs, explore=True, policy_id="default_policy") - obs, reward, terminated, truncated, info = env.step(action) if terminated or truncated: diff --git a/ray_rllib/basic.py b/ray_rllib/basic.py index 26870fb..8799513 100644 --- a/ray_rllib/basic.py +++ b/ray_rllib/basic.py @@ -1,4 +1,5 @@ import diambra.arena +from diambra.arena import SpaceTypes, EnvironmentSettings import gymnasium as gym from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config from ray.rllib.algorithms.ppo import PPO, PPOConfig @@ -6,9 +7,9 @@ def main(): # Environment Settings - env_settings = {} - env_settings["frame_shape"] = (84, 84, 1) - env_settings["action_space"] = "discrete" + env_settings = EnvironmentSettings() + env_settings.frame_shape = (84, 84, 1) + env_settings.action_space = SpaceTypes.DISCRETE # env_config env_config = { @@ -48,7 +49,6 @@ def main(): env.render() action = agent.compute_single_action(observation) - observation, reward, terminated, truncated, info = env.step(action) if terminated or truncated: diff --git a/ray_rllib/dict_obs_space.py b/ray_rllib/dict_obs_space.py index f01fce0..bca0425 100644 --- a/ray_rllib/dict_obs_space.py +++ b/ray_rllib/dict_obs_space.py @@ -1,20 +1,23 @@ +from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config from ray.rllib.algorithms.ppo import PPO from ray.tune.logger import pretty_print def main(): # Settings - env_settings = {} - env_settings["frame_shape"] = (84, 84, 1) - env_settings["characters"] = ("Kasumi") - env_settings["action_space"] = "discrete" + env_settings = EnvironmentSettings() + env_settings.frame_shape = (84, 84, 1) + env_settings.characters = ("Kasumi") + env_settings.action_space = SpaceTypes.DISCRETE # Wrappers Settings - wrappers_settings = {} - wrappers_settings["reward_normalization"] = True - wrappers_settings["actions_stack"] = 12 - wrappers_settings["frame_stack"] = 5 - wrappers_settings["scale"] = True + wrappers_settings = WrappersSettings() + wrappers_settings.reward_normalization = True + wrappers_settings.add_last_action_to_observation = True + wrappers_settings.actions_stack = 12 + wrappers_settings.frame_stack = 5 + wrappers_settings.scale = True + wrappers_settings.role_relative_observation = True config = { # Define and configure the environment diff --git a/ray_rllib/parallel_envs.py b/ray_rllib/parallel_envs.py index bc2e9f3..1ee2c72 100644 --- a/ray_rllib/parallel_envs.py +++ b/ray_rllib/parallel_envs.py @@ -1,12 +1,13 @@ +from diambra.arena import SpaceTypes, EnvironmentSettings from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config from ray.rllib.algorithms.ppo import PPO from ray.tune.logger import pretty_print def main(): # Settings - env_settings = {} - env_settings["frame_shape"] = (84, 84, 1) - env_settings["action_space"] = "discrete" + env_settings = EnvironmentSettings() + env_settings.frame_shape = (84, 84, 1) + env_settings.action_space = SpaceTypes.DISCRETE config = { # Define and configure the environment diff --git a/ray_rllib/results/doapp_sr6_84x5_das_c/.is_checkpoint b/ray_rllib/results/doapp_sr6_84x5_das_c/.is_checkpoint deleted file mode 100644 index e69de29..0000000 diff --git a/ray_rllib/results/doapp_sr6_84x5_das_c/.tune_metadata b/ray_rllib/results/doapp_sr6_84x5_das_c/.tune_metadata deleted file mode 100644 index 7171c65..0000000 --- a/ray_rllib/results/doapp_sr6_84x5_das_c/.tune_metadata +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:88c810103f98be99d5966b4bdc59bbdce1c23ce395717457ac53f1a316505f77 -size 791953 diff --git a/ray_rllib/results/doapp_sr6_84x5_das_c/checkpoint-1 b/ray_rllib/results/doapp_sr6_84x5_das_c/checkpoint-1 deleted file mode 100644 index 47b6e9a..0000000 --- a/ray_rllib/results/doapp_sr6_84x5_das_c/checkpoint-1 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:18efb15fa42fc64e66952a431595ea38a02c3e6eaef1ee602f4b2553a59b7bfd -size 17565845 diff --git a/ray_rllib/results/doapp_sr6_84x5_das_c/diambra_ray_env_spaces b/ray_rllib/results/doapp_sr6_84x5_das_c/diambra_ray_env_spaces deleted file mode 100644 index d0e88bf..0000000 --- a/ray_rllib/results/doapp_sr6_84x5_das_c/diambra_ray_env_spaces +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9b6a09fdb4ada5a26807ac60197db185a1d27c18a21ffd19c9d2efd0349e07ff -size 354934 diff --git a/ray_rllib/saving_loading_evaluating.py b/ray_rllib/saving_loading_evaluating.py index 63a711e..5f7056a 100644 --- a/ray_rllib/saving_loading_evaluating.py +++ b/ray_rllib/saving_loading_evaluating.py @@ -1,12 +1,13 @@ +from diambra.arena import SpaceTypes, EnvironmentSettings from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config from ray.rllib.algorithms.ppo import PPO from ray.tune.logger import pretty_print def main(): # Settings - env_settings = {} - env_settings["frame_shape"] = (84, 84, 1) - env_settings["action_space"] = "discrete" + env_settings = EnvironmentSettings() + env_settings.frame_shape = (84, 84, 1) + env_settings.action_space = SpaceTypes.DISCRETE config = { # Define and configure the environment @@ -36,7 +37,7 @@ def main(): print("Training results:\n{}".format(pretty_print(results))) # Save the agent - checkpoint = agent.save() + checkpoint = agent.save().checkpoint.path print("Checkpoint saved at {}".format(checkpoint)) del agent # delete trained model to demonstrate loading diff --git a/tests/test_ray_rllib.py b/tests/test_ray_rllib.py index f702ae6..23d4ecb 100755 --- a/tests/test_ray_rllib.py +++ b/tests/test_ray_rllib.py @@ -27,10 +27,11 @@ def func(script, mocker, *args): print(e) return 1 -trained_model_folder = os.path.join(root_dir, "results/doapp_sr6_84x5_das_c/") +trained_model_folder = os.path.join(root_dir, "results/test/") env_spaces_descriptor_path = os.path.join(trained_model_folder, "diambra_ray_env_spaces") #[parallel_envs, ()] # Not possible to test parallel_envs script as it requires multiple envs and the mocker does not work with child processes / threads -scripts = [[basic, ()], [saving_loading_evaluating, ()], [dict_obs_space, ()], [agent, (trained_model_folder, env_spaces_descriptor_path, True)]] +#[agent, (trained_model_folder, env_spaces_descriptor_path, True)] # Removed because of too big save file +scripts = [[basic, ()], [saving_loading_evaluating, ()], [dict_obs_space, ()]] @pytest.mark.parametrize("script", scripts) def test_ray_rllib_scripts(script, mocker):