Update Ray Examples
alexpalms committed Sep 29, 2023
1 parent d2becf8 commit 69c86ad
Showing 10 changed files with 40 additions and 42 deletions.
23 changes: 12 additions & 11 deletions ray_rllib/agent.py
@@ -1,5 +1,6 @@
 import argparse
 import diambra.arena
+from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings
 from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config
 from ray.rllib.algorithms.ppo import PPO

@@ -13,17 +14,19 @@

 def main(trained_model, env_spaces, test=False):
     # Settings
-    env_settings = {}
-    env_settings["frame_shape"] = (84, 84, 1)
-    env_settings["characters"] = ("Kasumi")
-    env_settings["action_space"] = "discrete"
+    env_settings = EnvironmentSettings()
+    env_settings.frame_shape = (84, 84, 1)
+    env_settings.characters = ("Kasumi")
+    env_settings.action_space = SpaceTypes.DISCRETE

     # Wrappers Settings
-    wrappers_settings = {}
-    wrappers_settings["reward_normalization"] = True
-    wrappers_settings["actions_stack"] = 12
-    wrappers_settings["frame_stack"] = 5
-    wrappers_settings["scale"] = True
+    wrappers_settings = WrappersSettings()
+    wrappers_settings.reward_normalization = True
+    wrappers_settings.add_last_action_to_observation = True
+    wrappers_settings.actions_stack = 12
+    wrappers_settings.frame_stack = 5
+    wrappers_settings.scale = True
+    wrappers_settings.role_relative_observation = True

     config = {
         # Define and configure the environment
@@ -50,14 +53,12 @@ def main(trained_model, env_spaces, test=False):
     print("Policy architecture =\n{}".format(agent.get_policy().model))

     env = diambra.arena.make("doapp", env_settings, wrappers_settings, render_mode="human")
-
     obs, info = env.reset()

     while True:
         env.render()

         action = agent.compute_single_action(observation=obs, explore=True, policy_id="default_policy")
-
         obs, reward, terminated, truncated, info = env.step(action)

         if terminated or truncated:
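Note: the elided config block above wires these settings into Ray RLlib through DiambraArena. A minimal sketch of that wiring, assuming the keys used in the DIAMBRA Arena docs (the commit's actual config contents are not shown in this hunk):

# Sketch only: "game_id" and the env_config keys are assumptions based on the
# DIAMBRA Arena docs, not on the elided lines of this commit.
config = {
    "env": DiambraArena,
    "env_config": {
        "game_id": "doapp",
        "settings": env_settings,              # EnvironmentSettings object from above
        "wrappers_settings": wrappers_settings,
    },
}
config = preprocess_ray_config(config)         # fill in DIAMBRA-specific Ray defaults
agent = PPO(config=config)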
8 changes: 4 additions & 4 deletions ray_rllib/basic.py
@@ -1,14 +1,15 @@
 import diambra.arena
+from diambra.arena import SpaceTypes, EnvironmentSettings
 import gymnasium as gym
 from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config
 from ray.rllib.algorithms.ppo import PPO, PPOConfig
 from ray.tune.logger import pretty_print

 def main():
     # Environment Settings
-    env_settings = {}
-    env_settings["frame_shape"] = (84, 84, 1)
-    env_settings["action_space"] = "discrete"
+    env_settings = EnvironmentSettings()
+    env_settings.frame_shape = (84, 84, 1)
+    env_settings.action_space = SpaceTypes.DISCRETE

     # env_config
     env_config = {
@@ -48,7 +49,6 @@ def main():
         env.render()

         action = agent.compute_single_action(observation)
-
         observation, reward, terminated, truncated, info = env.step(action)

         if terminated or truncated:
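Note: the truncated middle of basic.py trains the agent before the evaluation loop shown above. A hedged sketch of the standard RLlib calls it presumably uses (iteration count is illustrative, not from the commit):

config = preprocess_ray_config(config)  # assumes the elided config dict above
agent = PPO(config=config)
results = agent.train()                 # one training iteration
print(pretty_print(results))            # pretty_print imported from ray.tune.logger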
21 changes: 12 additions & 9 deletions ray_rllib/dict_obs_space.py
@@ -1,20 +1,23 @@
+from diambra.arena import SpaceTypes, EnvironmentSettings, WrappersSettings
 from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config
 from ray.rllib.algorithms.ppo import PPO
 from ray.tune.logger import pretty_print

 def main():
     # Settings
-    env_settings = {}
-    env_settings["frame_shape"] = (84, 84, 1)
-    env_settings["characters"] = ("Kasumi")
-    env_settings["action_space"] = "discrete"
+    env_settings = EnvironmentSettings()
+    env_settings.frame_shape = (84, 84, 1)
+    env_settings.characters = ("Kasumi")
+    env_settings.action_space = SpaceTypes.DISCRETE

     # Wrappers Settings
-    wrappers_settings = {}
-    wrappers_settings["reward_normalization"] = True
-    wrappers_settings["actions_stack"] = 12
-    wrappers_settings["frame_stack"] = 5
-    wrappers_settings["scale"] = True
+    wrappers_settings = WrappersSettings()
+    wrappers_settings.reward_normalization = True
+    wrappers_settings.add_last_action_to_observation = True
+    wrappers_settings.actions_stack = 12
+    wrappers_settings.frame_stack = 5
+    wrappers_settings.scale = True
+    wrappers_settings.role_relative_observation = True

     config = {
         # Define and configure the environment
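Note: the two new wrapper flags change the dict observation: add_last_action_to_observation appends the previous action to the observation, and role_relative_observation keys player entries relative to the agent's role rather than as P1/P2. A quick, illustrative way to inspect the resulting dict observation (not part of the diff):

import diambra.arena

env = diambra.arena.make("doapp", env_settings, wrappers_settings)
obs, info = env.reset()
for key, value in obs.items():  # frame plus RAM-state entries such as health and stage
    print(key, getattr(value, "shape", value))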
7 changes: 4 additions & 3 deletions ray_rllib/parallel_envs.py
@@ -1,12 +1,13 @@
+from diambra.arena import SpaceTypes, EnvironmentSettings
 from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config
 from ray.rllib.algorithms.ppo import PPO
 from ray.tune.logger import pretty_print

 def main():
     # Settings
-    env_settings = {}
-    env_settings["frame_shape"] = (84, 84, 1)
-    env_settings["action_space"] = "discrete"
+    env_settings = EnvironmentSettings()
+    env_settings.frame_shape = (84, 84, 1)
+    env_settings.action_space = SpaceTypes.DISCRETE

     config = {
         # Define and configure the environment
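Note: parallel_envs.py differs from basic.py mainly in the elided config, which scales rollouts across workers. A sketch assuming the standard RLlib dict-config keys; the actual worker counts in the script are not shown:

config = {
    "env": DiambraArena,
    "env_config": {"game_id": "doapp", "settings": env_settings},
    "num_workers": 2,           # parallel rollout workers, one DIAMBRA env each
    "num_envs_per_worker": 1,   # vectorized envs per worker
}
config = preprocess_ray_config(config)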
Empty file.
3 changes: 0 additions & 3 deletions ray_rllib/results/doapp_sr6_84x5_das_c/.tune_metadata

This file was deleted.

3 changes: 0 additions & 3 deletions ray_rllib/results/doapp_sr6_84x5_das_c/checkpoint-1

This file was deleted.

3 changes: 0 additions & 3 deletions ray_rllib/results/doapp_sr6_84x5_das_c/diambra_ray_env_spaces

This file was deleted.

9 changes: 5 additions & 4 deletions ray_rllib/saving_loading_evaluating.py
@@ -1,12 +1,13 @@
+from diambra.arena import SpaceTypes, EnvironmentSettings
 from diambra.arena.ray_rllib.make_ray_env import DiambraArena, preprocess_ray_config
 from ray.rllib.algorithms.ppo import PPO
 from ray.tune.logger import pretty_print

 def main():
     # Settings
-    env_settings = {}
-    env_settings["frame_shape"] = (84, 84, 1)
-    env_settings["action_space"] = "discrete"
+    env_settings = EnvironmentSettings()
+    env_settings.frame_shape = (84, 84, 1)
+    env_settings.action_space = SpaceTypes.DISCRETE

     config = {
         # Define and configure the environment
@@ -36,7 +37,7 @@ def main():
     print("Training results:\n{}".format(pretty_print(results)))

     # Save the agent
-    checkpoint = agent.save()
+    checkpoint = agent.save().checkpoint.path
     print("Checkpoint saved at {}".format(checkpoint))
     del agent # delete trained model to demonstrate loading
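Note: the one-line change tracks newer Ray releases, where Algorithm.save() returns a result object rather than a path string, with the checkpoint directory at .checkpoint.path. Loading then works as before, sketched here (the restore step is elided from this hunk):

checkpoint = agent.save().checkpoint.path  # assumption: Ray 2.7-era save() return value
del agent                                  # drop the trained model
agent = PPO(config=config)                 # rebuild the algorithm
agent.restore(checkpoint)                  # reload weights from the checkpoint directory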
5 changes: 3 additions & 2 deletions tests/test_ray_rllib.py
@@ -27,10 +27,11 @@ def func(script, mocker, *args):
         print(e)
         return 1

-trained_model_folder = os.path.join(root_dir, "results/doapp_sr6_84x5_das_c/")
+trained_model_folder = os.path.join(root_dir, "results/test/")
 env_spaces_descriptor_path = os.path.join(trained_model_folder, "diambra_ray_env_spaces")
 #[parallel_envs, ()] # Not possible to test parallel_envs script as it requires multiple envs and the mocker does not work with child processes / threads
-scripts = [[basic, ()], [saving_loading_evaluating, ()], [dict_obs_space, ()], [agent, (trained_model_folder, env_spaces_descriptor_path, True)]]
+#[agent, (trained_model_folder, env_spaces_descriptor_path, True)] # Removed because of too big save file
+scripts = [[basic, ()], [saving_loading_evaluating, ()], [dict_obs_space, ()]]

 @pytest.mark.parametrize("script", scripts)
 def test_ray_rllib_scripts(script, mocker):
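Note: func wraps each example's main() in a try/except and returns a nonzero code on failure, so the parametrized test body presumably reduces to an assertion like this sketch (the actual body is elided; the success return value of 0 is an assumption):

@pytest.mark.parametrize("script", scripts)
def test_ray_rllib_scripts(script, mocker):
    assert func(script[0], mocker, *script[1]) == 0  # 0 means the script ran cleanly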
