
Commit

Merge branch 'main' of github.com:diambra/agents into main
alexpalms committed Apr 26, 2024
2 parents c62b8d1 + 099920a commit 81fbd70
Showing 8 changed files with 816 additions and 12 deletions.
sheeprl/agent-dreamer_v3.py: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
import argparse
import json

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.dreamer_v3.agent import build_agent
from sheeprl.algos.dreamer_v3.utils import prepare_obs
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

"""This is an example agent based on SheepRL.
Usage:
cd sheeprl
diambra run python agent-dreamer_v3.py --cfg_path "/absolute/path/to/example-logs/runs/dreamer_v3/doapp/experiment/version_0/config.yaml" --checkpoint_path "/absolute/path/to/example-logs/runs/dreamer_v3/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
"""


def main(cfg_path: str, checkpoint_path: str, test=False):
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation
    # There is no need to capture video: the agent is being submitted and DIAMBRA records the video
    cfg.env.capture_video = False
    # Only one environment is used for evaluation
    cfg.env.num_envs = 1

    # Instantiate Fabric
    # You must use the same precision and plugins used for training
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create the environment
    env = make_env(cfg, 0, 0)()
    observation_space = env.observation_space
    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
    )
    cnn_keys = cfg.algo.cnn_keys.encoder

    # Load the trained agent
    state = fabric.load(checkpoint_path)
    # You need to retrieve only the player (the model used to act)
    # Check for each algorithm which models the `build_agent()` function returns
    # (placed in the `agent.py` file of the algorithm) and which arguments it needs.
    # Also check the keys of the checkpoint: if the `build_agent()` parameter
    # is called `model_state`, then you retrieve the model state with `state["model"]`.
    agent = build_agent(
        fabric=fabric,
        actions_dim=actions_dim,
        is_continuous=False,
        cfg=cfg,
        obs_space=observation_space,
        world_model_state=state["world_model"],
        actor_state=state["actor"],
        critic_state=state["critic"],
        target_critic_state=state["target_critic"],
    )[-1]
    agent.eval()

    # Print the policy network architecture
    print("Policy architecture:")
    print(agent)

    obs, info = env.reset()
    # Every time you reset the environment, you must reset the initial states of the model
    agent.init_states()

    while True:
        # Convert numpy observations into torch observations and normalize image observations
        # Every algorithm has its own way to do it; you must import the correct method
        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)

        # Select actions: the agent returns a one-hot categorical sample, or
        # multiple one-hot categorical samples for multi-discrete action spaces
        actions = agent.get_actions(torch_obs, greedy=False)
        # Convert actions from one-hot categorical to categorical (integer) indices
        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

        obs, _, terminated, truncated, info = env.step(
            actions.cpu().numpy().reshape(env.action_space.shape)
        )

        if terminated or truncated:
            obs, info = env.reset()
            # Every time you reset the environment, you must reset the initial states of the model
            agent.init_states()
            if info["env_done"] or test is True:
                break

    # Close the environment
    env.close()

    # Return success
    return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg_path", type=str, required=True, help="Configuration file"
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="model", help="Model checkpoint"
    )
    parser.add_argument("--test", action="store_true", help="Test mode")
    opt = parser.parse_args()
    print(opt)

    main(opt.cfg_path, opt.checkpoint_path, opt.test)
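
Note: the comments above say to match the `*_state` arguments of `build_agent()` against the checkpoint keys. A minimal sketch of how you might inspect those keys first (the checkpoint path is a placeholder):

from lightning import Fabric

# Load the checkpoint on CPU and list its top-level entries; for Dreamer-V3 these
# include "world_model", "actor", "critic" and "target_critic", plus training state
fabric = Fabric(accelerator="cpu", devices=1)
state = fabric.load("/absolute/path/to/checkpoint/ckpt_1024_0.ckpt")
for key, value in state.items():
    # Model states are dicts of tensors; other entries (e.g. counters) print as-is
    print(key, "->", f"{len(value)} tensors" if isinstance(value, dict) else value)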
sheeprl/agent-ppo.py: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import argparse
import json

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.ppo.agent import build_agent
from sheeprl.algos.ppo.utils import prepare_obs
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

"""This is an example agent based on SheepRL.
Usage:
cd sheeprl
diambra run python agent-ppo.py --cfg_path "/absolute/path/to/example-logs/runs/ppo/doapp/experiment/version_0/config.yaml" --checkpoint_path "/absolute/path/to/example-logs/runs/ppo/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
"""


def main(cfg_path: str, checkpoint_path: str, test=False):
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation
    # There is no need to capture video: the agent is being submitted and DIAMBRA records the video
    cfg.env.capture_video = False
    # Only one environment is used for evaluation
    cfg.env.num_envs = 1

    # Instantiate Fabric
    # You must use the same precision and plugins used for training
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create the environment
    env = make_env(cfg, 0, 0)()
    observation_space = env.observation_space
    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
    )
    cnn_keys = cfg.algo.cnn_keys.encoder

    # Load the trained agent
    state = fabric.load(checkpoint_path)
    # You need to retrieve only the player (the model used to act)
    # Check for each algorithm which models the `build_agent()` function returns
    # (placed in the `agent.py` file of the algorithm) and which arguments it needs.
    # Also check the keys of the checkpoint: if the `build_agent()` parameter
    # is called `model_state`, then you retrieve the model state with `state["model"]`.
    agent = build_agent(
        fabric=fabric,
        actions_dim=actions_dim,
        is_continuous=False,
        cfg=cfg,
        obs_space=observation_space,
        agent_state=state["agent"],
    )[-1]
    agent.eval()

    # Print the policy network architecture
    print("Policy architecture:")
    print(agent)

    obs, info = env.reset()

    while True:
        # Convert numpy observations into torch observations and normalize image observations
        # Every algorithm has its own way to do it; you must import the correct method
        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)

        # Select actions: the agent returns a one-hot categorical sample, or
        # multiple one-hot categorical samples for multi-discrete action spaces
        actions = agent.get_actions(torch_obs, greedy=True)
        # Convert actions from one-hot categorical to categorical (integer) indices
        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

        obs, _, terminated, truncated, info = env.step(
            actions.cpu().numpy().reshape(env.action_space.shape)
        )

        if terminated or truncated:
            obs, info = env.reset()
            if info["env_done"] or test is True:
                break

    # Close the environment
    env.close()

    # Return success
    return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg_path", type=str, required=True, help="Configuration file"
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="model", help="Model checkpoint"
    )
    parser.add_argument("--test", action="store_true", help="Test mode")
    opt = parser.parse_args()
    print(opt)

    main(opt.cfg_path, opt.checkpoint_path, opt.test)
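
Note: the one-hot to integer conversion used in both loops above can be seen in isolation with dummy tensors standing in for the agent's outputs (the sizes 3 and 5 are invented for the example):

import torch

# Hypothetical multi-discrete output: one one-hot tensor per action dimension
# (batch size 1), e.g. 3 move actions and 5 attack actions
one_hot_moves = torch.tensor([[0.0, 1.0, 0.0]])
one_hot_attacks = torch.tensor([[0.0, 0.0, 0.0, 1.0, 0.0]])
actions = [one_hot_moves, one_hot_attacks]

# argmax recovers the integer index encoded by each one-hot vector, and
# concatenating yields one integer per action dimension
actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
print(actions)  # tensor([1, 3])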
sheeprl/configs/env/custom_env.yaml: 2 additions & 3 deletions
@@ -23,16 +23,15 @@ wrapper:
   # class to be instantiated
   _target_: sheeprl.envs.diambra.DiambraWrapper
   id: ${env.id}
-  action_space: diambra.arena.SpaceTypes.DISCRETE
-  # or `diambra.arena.SpaceTypes.MULTI_DISCRETE`
+  action_space: DISCRETE # or "MULTI_DISCRETE"
   screen_size: ${env.screen_size}
   grayscale: ${env.grayscale}
   repeat_action: ${env.action_repeat}
   rank: null
   log_level: 0
   increase_performance: True
   diambra_settings:
-    role: diambra.arena.Roles.P1 # or `diambra.arena.Roles.P2` or `null`
+    role: P1 # or "P2" or null
     step_ratio: 6
     difficulty: 4
     continue_game: 0.0
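
Note: whether the wrapper is configured with DISCRETE or MULTI_DISCRETE, the agent scripts above only see the resulting gymnasium space. A small sketch of the `actions_dim` computation they share, with hypothetical action counts:

import gymnasium as gym

for space in (gym.spaces.Discrete(8), gym.spaces.MultiDiscrete([9, 4])):
    is_multidiscrete = isinstance(space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(space.nvec.tolist() if is_multidiscrete else [space.n])
    print(type(space).__name__, actions_dim)

# Prints: Discrete (8,) -> a single head with 8 actions, and
# MultiDiscrete (9, 4) -> one head per sub-action (e.g. moves and attacks)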
Binary file not shown.
