diff --git a/sheeprl/agent-dreamer_v3.py b/sheeprl/agent-dreamer_v3.py
new file mode 100644
index 0000000..b38e35c
--- /dev/null
+++ b/sheeprl/agent-dreamer_v3.py
@@ -0,0 +1,123 @@
+import argparse
+import json
+
+import gymnasium as gym
+import torch
+from lightning import Fabric
+from omegaconf import OmegaConf
+from sheeprl.algos.dreamer_v3.agent import build_agent
+from sheeprl.algos.dreamer_v3.utils import prepare_obs
+from sheeprl.utils.env import make_env
+from sheeprl.utils.utils import dotdict
+
+"""This is an example agent based on SheepRL.
+
+Usage:
+cd sheeprl
+diambra run python agent-dreamer_v3.py --cfg_path "/absolute/path/to/example-logs/runs/dreamer_v3/doapp/experiment/version_0/config.yaml" --checkpoint_path "/absolute/path/to/example-logs/runs/dreamer_v3/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
+"""
+
+
+def main(cfg_path: str, checkpoint_path: str, test=False):
+    # Read the cfg file
+    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
+    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))
+
+    # Override configs for evaluation
+    # You do not need to capture the video since you are submitting the agent and the video is recorded by DIAMBRA
+    cfg.env.capture_video = False
+    # Only one environment is used for evaluation
+    cfg.env.num_envs = 1
+
+    # Instantiate Fabric
+    # You must use the same precision and plugins used for training.
+    precision = getattr(cfg.fabric, "precision", None)
+    plugins = getattr(cfg.fabric, "plugins", None)
+    fabric = Fabric(
+        accelerator="auto",
+        devices=1,
+        num_nodes=1,
+        precision=precision,
+        plugins=plugins,
+        strategy="auto",
+    )
+
+    # Create Environment
+    env = make_env(cfg, 0, 0)()
+    observation_space = env.observation_space
+    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
+    actions_dim = tuple(
+        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
+    )
+    cnn_keys = cfg.algo.cnn_keys.encoder
+
+    # Load the trained agent
+    state = fabric.load(checkpoint_path)
+    # You need to retrieve only the player
+    # Check for each algorithm what models the `build_agent()` function returns
+    # (placed in the `agent.py` file of the algorithm), and which arguments it needs.
+    # Check also which are the keys of the checkpoint: if the `build_agent()` parameter
+    # is called `model_state`, then you retrieve the model state with `state["model"]`.
+    agent = build_agent(
+        fabric=fabric,
+        actions_dim=actions_dim,
+        is_continuous=False,
+        cfg=cfg,
+        obs_space=observation_space,
+        world_model_state=state["world_model"],
+        actor_state=state["actor"],
+        critic_state=state["critic"],
+        target_critic_state=state["target_critic"],
+    )[-1]
+    agent.eval()
+
+    # Print policy network architecture
+    print("Policy architecture:")
+    print(agent)
+
+    obs, info = env.reset()
+    # Every time you reset the environment, you must reset the initial states of the model
+    agent.init_states()
+
+    while True:
+        # Convert numpy observations into torch observations and normalize image observations
+        # Every algorithm has its own way to do it, you must import the correct method
+        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)
+
+        # Select actions: the agent returns a one-hot categorical distribution, or
+        # multiple one-hot categorical distributions for multi-discrete action spaces
+        actions = agent.get_actions(torch_obs, greedy=False)
+        # Convert actions from one-hot categorical to categorical
+        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
+
+        obs, _, terminated, truncated, info = env.step(
+            actions.cpu().numpy().reshape(env.action_space.shape)
+        )
+
+        if terminated or truncated:
+            obs, info = env.reset()
+            # Every time you reset the environment, you must reset the initial states of the model
+            agent.init_states()
+            if info["env_done"] or test is True:
+                break
+
+    # Close the environment
+    env.close()
+
+    # Return success
+    return 0
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--cfg_path", type=str, required=True, help="Configuration file"
+    )
+    parser.add_argument(
+        "--checkpoint_path", type=str, default="model", help="Model checkpoint"
+    )
+    parser.add_argument("--test", action="store_true", help="Test mode")
+    opt = parser.parse_args()
+    print(opt)
+
+    main(opt.cfg_path, opt.checkpoint_path, opt.test)
diff --git a/sheeprl/agent-ppo.py b/sheeprl/agent-ppo.py
new file mode 100644
index 0000000..2e5088e
--- /dev/null
+++ b/sheeprl/agent-ppo.py
@@ -0,0 +1,116 @@
+import argparse
+import json
+
+import gymnasium as gym
+import torch
+from lightning import Fabric
+from omegaconf import OmegaConf
+from sheeprl.algos.ppo.agent import build_agent
+from sheeprl.algos.ppo.utils import prepare_obs
+from sheeprl.utils.env import make_env
+from sheeprl.utils.utils import dotdict
+
+"""This is an example agent based on SheepRL.
+
+Usage:
+cd sheeprl
+diambra run python agent-ppo.py --cfg_path "/absolute/path/to/example-logs/runs/ppo/doapp/experiment/version_0/config.yaml" --checkpoint_path "/absolute/path/to/example-logs/runs/ppo/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
+"""
+
+
+def main(cfg_path: str, checkpoint_path: str, test=False):
+    # Read the cfg file
+    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
+    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))
+
+    # Override configs for evaluation
+    # You do not need to capture the video since you are submitting the agent and the video is recorded by DIAMBRA
+    cfg.env.capture_video = False
+    # Only one environment is used for evaluation
+    cfg.env.num_envs = 1
+
+    # Instantiate Fabric
+    # You must use the same precision and plugins used for training.
+    precision = getattr(cfg.fabric, "precision", None)
+    plugins = getattr(cfg.fabric, "plugins", None)
+    fabric = Fabric(
+        accelerator="auto",
+        devices=1,
+        num_nodes=1,
+        precision=precision,
+        plugins=plugins,
+        strategy="auto",
+    )
+
+    # Create Environment
+    env = make_env(cfg, 0, 0)()
+    observation_space = env.observation_space
+    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
+    actions_dim = tuple(
+        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
+    )
+    cnn_keys = cfg.algo.cnn_keys.encoder
+
+    # Load the trained agent
+    state = fabric.load(checkpoint_path)
+    # You need to retrieve only the player
+    # Check for each algorithm what models the `build_agent()` function returns
+    # (placed in the `agent.py` file of the algorithm), and which arguments it needs.
+    # Check also which are the keys of the checkpoint: if the `build_agent()` parameter
+    # is called `model_state`, then you retrieve the model state with `state["model"]`.
+    agent = build_agent(
+        fabric=fabric,
+        actions_dim=actions_dim,
+        is_continuous=False,
+        cfg=cfg,
+        obs_space=observation_space,
+        agent_state=state["agent"],
+    )[-1]
+    agent.eval()
+
+    # Print policy network architecture
+    print("Policy architecture:")
+    print(agent)
+
+    obs, info = env.reset()
+
+    while True:
+        # Convert numpy observations into torch observations and normalize image observations
+        # Every algorithm has its own way to do it, you must import the correct method
+        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)
+
+        # Select actions: the agent returns a one-hot categorical distribution, or
+        # multiple one-hot categorical distributions for multi-discrete action spaces
+        actions = agent.get_actions(torch_obs, greedy=True)
+        # Convert actions from one-hot categorical to categorical
+        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
+
+        obs, _, terminated, truncated, info = env.step(
+            actions.cpu().numpy().reshape(env.action_space.shape)
+        )
+
+        if terminated or truncated:
+            obs, info = env.reset()
+            if info["env_done"] or test is True:
+                break
+
+    # Close the environment
+    env.close()
+
+    # Return success
+    return 0
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--cfg_path", type=str, required=True, help="Configuration file"
+    )
+    parser.add_argument(
+        "--checkpoint_path", type=str, default="model", help="Model checkpoint"
+    )
+    parser.add_argument("--test", action="store_true", help="Test mode")
+    opt = parser.parse_args()
+    print(opt)
+
+    main(opt.cfg_path, opt.checkpoint_path, opt.test)
diff --git a/sheeprl/configs/env/custom_env.yaml b/sheeprl/configs/env/custom_env.yaml
index d5b2fde..1ac205e 100644
--- a/sheeprl/configs/env/custom_env.yaml
+++ b/sheeprl/configs/env/custom_env.yaml
@@ -23,8 +23,7 @@ wrapper:
   # class to be instantiated
   _target_: sheeprl.envs.diambra.DiambraWrapper
   id: ${env.id}
-  action_space: diambra.arena.SpaceTypes.DISCRETE
-  # or `diambra.arena.SpaceTypes.MULTI_DISCRETE`
+  action_space: DISCRETE # or "MULTI_DISCRETE"
   screen_size: ${env.screen_size}
   grayscale: ${env.grayscale}
   repeat_action: ${env.action_repeat}
@@ -32,7 +31,7 @@
   log_level: 0
   increase_performance: True
   diambra_settings:
-    role: diambra.arena.Roles.P1 # or `diambra.arena.Roles.P1` or `null`
+    role: P1 # or "P2" or null
     step_ratio: 6
     difficulty: 4
     continue_game: 0.0
diff --git a/sheeprl/example-logs/runs/dreamer_v3/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt
b/sheeprl/example-logs/runs/dreamer_v3/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt new file mode 100644 index 0000000..452d227 Binary files /dev/null and b/sheeprl/example-logs/runs/dreamer_v3/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt differ diff --git a/sheeprl/example-logs/runs/dreamer_v3/doapp/experiment/version_0/config.yaml b/sheeprl/example-logs/runs/dreamer_v3/doapp/experiment/version_0/config.yaml new file mode 100644 index 0000000..9c8874c --- /dev/null +++ b/sheeprl/example-logs/runs/dreamer_v3/doapp/experiment/version_0/config.yaml @@ -0,0 +1,354 @@ +num_threads: 1 +float32_matmul_precision: high +dry_run: false +seed: 42 +torch_use_deterministic_algorithms: false +torch_backends_cudnn_benchmark: true +torch_backends_cudnn_deterministic: false +cublas_workspace_config: null +exp_name: dreamer_v3_doapp +run_name: 2024-04-16_17-34-17_dreamer_v3_doapp_42 +root_dir: dreamer_v3/doapp +algo: + name: dreamer_v3 + total_steps: 1024 + per_rank_batch_size: 2 + run_test: false + cnn_keys: + encoder: + - frame + decoder: + - frame + mlp_keys: + encoder: + - own_character + - own_health + - own_side + - own_wins + - opp_character + - opp_health + - opp_side + - opp_wins + - stage + - timer + - action + decoder: + - own_character + - own_health + - own_side + - own_wins + - opp_character + - opp_health + - opp_side + - opp_wins + - stage + - timer + - action + world_model: + optimizer: + _target_: torch.optim.Adam + lr: 0.0001 + eps: 1.0e-08 + weight_decay: 0 + betas: + - 0.9 + - 0.999 + discrete_size: 4 + stochastic_size: 4 + kl_dynamic: 0.5 + kl_representation: 0.1 + kl_free_nats: 1.0 + kl_regularizer: 1.0 + continue_scale_factor: 1.0 + clip_gradients: 1000.0 + decoupled_rssm: false + learnable_initial_recurrent_state: true + encoder: + cnn_channels_multiplier: 2 + cnn_act: torch.nn.SiLU + dense_act: torch.nn.SiLU + mlp_layers: 1 + cnn_layer_norm: + cls: sheeprl.models.models.LayerNormChannelLast + kw: + eps: 0.001 + mlp_layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + dense_units: 8 + recurrent_model: + recurrent_state_size: 8 + layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + dense_units: 8 + transition_model: + hidden_size: 8 + dense_act: torch.nn.SiLU + layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + representation_model: + hidden_size: 8 + dense_act: torch.nn.SiLU + layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + observation_model: + cnn_channels_multiplier: 2 + cnn_act: torch.nn.SiLU + dense_act: torch.nn.SiLU + mlp_layers: 1 + cnn_layer_norm: + cls: sheeprl.models.models.LayerNormChannelLast + kw: + eps: 0.001 + mlp_layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + dense_units: 8 + reward_model: + dense_act: torch.nn.SiLU + mlp_layers: 1 + layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + dense_units: 8 + bins: 255 + discount_model: + learnable: true + dense_act: torch.nn.SiLU + mlp_layers: 1 + layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + dense_units: 8 + actor: + optimizer: + _target_: torch.optim.Adam + lr: 8.0e-05 + eps: 1.0e-05 + weight_decay: 0 + betas: + - 0.9 + - 0.999 + cls: sheeprl.algos.dreamer_v3.agent.Actor + ent_coef: 0.0003 + min_std: 0.1 + max_std: 1.0 + init_std: 2.0 + dense_act: torch.nn.SiLU + mlp_layers: 1 + layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + dense_units: 8 + clip_gradients: 100.0 + unimix: 0.01 + action_clip: 1.0 + moments: + decay: 0.99 + 
max: 1.0 + percentile: + low: 0.05 + high: 0.95 + critic: + optimizer: + _target_: torch.optim.Adam + lr: 8.0e-05 + eps: 1.0e-05 + weight_decay: 0 + betas: + - 0.9 + - 0.999 + dense_act: torch.nn.SiLU + mlp_layers: 1 + layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + dense_units: 8 + per_rank_target_network_update_freq: 1 + tau: 0.02 + bins: 255 + clip_gradients: 100.0 + gamma: 0.996996996996997 + lmbda: 0.95 + horizon: 15 + replay_ratio: 0.0625 + learning_starts: 1024 + per_rank_pretrain_steps: 0 + per_rank_sequence_length: 64 + cnn_layer_norm: + cls: sheeprl.models.models.LayerNormChannelLast + kw: + eps: 0.001 + mlp_layer_norm: + cls: sheeprl.models.models.LayerNorm + kw: + eps: 0.001 + dense_units: 8 + mlp_layers: 1 + dense_act: torch.nn.SiLU + cnn_act: torch.nn.SiLU + unimix: 0.01 + hafner_initialization: true + player: + discrete_size: 4 +buffer: + size: 1024 + memmap: true + validate_args: false + from_numpy: false + checkpoint: true +checkpoint: + every: 10000 + resume_from: null + save_last: true + keep_last: 5 +distribution: + validate_args: false + type: auto +env: + id: doapp + num_envs: 1 + frame_stack: -1 + sync_env: true + screen_size: 64 + action_repeat: 1 + grayscale: false + clip_rewards: false + capture_video: true + frame_stack_dilation: 1 + max_episode_steps: null + reward_as_observation: false + wrapper: + _target_: sheeprl.envs.diambra.DiambraWrapper + id: doapp + action_space: DISCRETE + screen_size: 64 + grayscale: false + repeat_action: 1 + rank: null + log_level: 0 + increase_performance: true + diambra_settings: + role: P1 + step_ratio: 6 + difficulty: 4 + continue_game: 0.0 + show_final: false + outfits: 2 + splash_screen: false + diambra_wrappers: + stack_actions: 1 + no_op_max: 0 + no_attack_buttons_combinations: false + add_last_action: true + scale: false + exclude_image_scaling: false + process_discrete_binary: false + role_relative: true +fabric: + _target_: lightning.fabric.Fabric + devices: 1 + num_nodes: 1 + strategy: auto + accelerator: cpu + precision: 32-true + callbacks: + - _target_: sheeprl.utils.callback.CheckpointCallback + keep_last: 5 +metric: + log_every: 5000 + disable_timer: false + log_level: 1 + sync_on_compute: false + aggregator: + _target_: sheeprl.utils.metric.MetricAggregator + raise_on_missing: false + metrics: + Rewards/rew_avg: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Game/ep_len_avg: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Loss/world_model_loss: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Loss/value_loss: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Loss/policy_loss: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Loss/observation_loss: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Loss/reward_loss: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Loss/state_loss: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Loss/continue_loss: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + State/kl: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + State/post_entropy: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + State/prior_entropy: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Grads/world_model: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Grads/actor: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + Grads/critic: + _target_: torchmetrics.MeanMetric + sync_on_compute: false + logger: 
+ _target_: lightning.fabric.loggers.TensorBoardLogger + name: 2024-04-16_17-34-17_dreamer_v3_doapp_42 + root_dir: logs/runs/dreamer_v3/doapp + version: null + default_hp_metric: true + prefix: '' + sub_dir: null +model_manager: + disabled: true + models: + world_model: + model_name: dreamer_v3_doapp_world_model + description: DreamerV3 World Model used in doapp Environment + tags: {} + actor: + model_name: dreamer_v3_doapp_actor + description: DreamerV3 Actor used in doapp Environment + tags: {} + critic: + model_name: dreamer_v3_doapp_critic + description: DreamerV3 Critic used in doapp Environment + tags: {} + target_critic: + model_name: dreamer_v3_doapp_target_critic + description: DreamerV3 Target Critic used in doapp Environment + tags: {} + moments: + model_name: dreamer_v3_doapp_moments + description: DreamerV3 Moments used in doapp Environment + tags: {} diff --git a/sheeprl/example-logs/runs/ppo/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt b/sheeprl/example-logs/runs/ppo/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt new file mode 100644 index 0000000..ad9032f Binary files /dev/null and b/sheeprl/example-logs/runs/ppo/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt differ diff --git a/sheeprl/example-logs/runs/ppo/doapp/experiment/version_0/config.yaml b/sheeprl/example-logs/runs/ppo/doapp/experiment/version_0/config.yaml new file mode 100644 index 0000000..fc8784f --- /dev/null +++ b/sheeprl/example-logs/runs/ppo/doapp/experiment/version_0/config.yaml @@ -0,0 +1,164 @@ +num_threads: 1 +float32_matmul_precision: high +dry_run: false +seed: 42 +torch_use_deterministic_algorithms: false +torch_backends_cudnn_benchmark: true +torch_backends_cudnn_deterministic: false +cublas_workspace_config: null +exp_name: ppo_doapp +run_name: 2024-04-15_15-25-55_ppo_doapp_42 +root_dir: ppo/doapp +algo: + name: ppo + total_steps: 1024 + per_rank_batch_size: 16 + run_test: true + cnn_keys: + encoder: + - frame + mlp_keys: + encoder: + - own_character + - own_health + - own_side + - own_wins + - opp_character + - opp_health + - opp_side + - opp_wins + - stage + - timer + - action + optimizer: + _target_: torch.optim.Adam + lr: 0.005 + eps: 1.0e-06 + weight_decay: 0 + betas: + - 0.9 + - 0.999 + anneal_lr: false + gamma: 0.99 + gae_lambda: 0.95 + update_epochs: 1 + loss_reduction: mean + normalize_advantages: true + clip_coef: 0.2 + anneal_clip_coef: false + clip_vloss: false + ent_coef: 0.0 + anneal_ent_coef: false + vf_coef: 1.0 + rollout_steps: 32 + dense_units: 16 + mlp_layers: 1 + dense_act: torch.nn.Tanh + layer_norm: false + max_grad_norm: 1.0 + encoder: + cnn_features_dim: 128 + mlp_features_dim: 32 + dense_units: 16 + mlp_layers: 1 + dense_act: torch.nn.Tanh + layer_norm: false + actor: + dense_units: 16 + mlp_layers: 1 + dense_act: torch.nn.Tanh + layer_norm: false + critic: + dense_units: 16 + mlp_layers: 1 + dense_act: torch.nn.Tanh + layer_norm: false +buffer: + size: 32 + memmap: true + validate_args: false + from_numpy: false + share_data: false +checkpoint: + every: 100 + resume_from: null + save_last: true + keep_last: 5 +distribution: + validate_args: false +env: + id: doapp + num_envs: 1 + frame_stack: 1 + sync_env: true + screen_size: 64 + action_repeat: 1 + grayscale: false + clip_rewards: false + capture_video: true + frame_stack_dilation: 1 + max_episode_steps: null + reward_as_observation: false + wrapper: + _target_: sheeprl.envs.diambra.DiambraWrapper + id: doapp + action_space: DISCRETE + screen_size: 64 + grayscale: false + repeat_action: 1 + rank: 
null
+    log_level: 0
+    increase_performance: true
+    diambra_settings:
+      role: P1
+      step_ratio: 6
+      difficulty: 4
+      continue_game: 0.0
+      show_final: false
+      outfits: 2
+      splash_screen: false
+    diambra_wrappers:
+      stack_actions: 1
+      no_op_max: 0
+      no_attack_buttons_combinations: false
+      add_last_action: true
+      scale: false
+      exclude_image_scaling: false
+      process_discrete_binary: false
+      role_relative: true
+fabric:
+  _target_: lightning.fabric.Fabric
+  devices: 1
+  num_nodes: 1
+  strategy: auto
+  accelerator: cpu
+  precision: 32-true
+  callbacks:
+  - _target_: sheeprl.utils.callback.CheckpointCallback
+    keep_last: 5
+metric:
+  log_every: 5000
+  disable_timer: false
+  log_level: 1
+  sync_on_compute: false
+  aggregator:
+    _target_: sheeprl.utils.metric.MetricAggregator
+    raise_on_missing: false
+    metrics:
+      Rewards/rew_avg:
+        _target_: torchmetrics.MeanMetric
+        sync_on_compute: false
+      Game/ep_len_avg:
+        _target_: torchmetrics.MeanMetric
+        sync_on_compute: false
+  logger:
+    _target_: lightning.fabric.loggers.TensorBoardLogger
+    name: 2024-04-15_15-25-55_ppo_doapp_42
+    root_dir: logs/runs/ppo/doapp
+    version: null
+    default_hp_metric: true
+    prefix: ''
+    sub_dir: null
+model_manager:
+  disabled: true
+  models: {}
diff --git a/tests/test_sheeprl.py b/tests/test_sheeprl.py
index ace819c..e6bdab4 100644
--- a/tests/test_sheeprl.py
+++ b/tests/test_sheeprl.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+import importlib
 import os
 import shutil
 import sys
@@ -24,7 +25,17 @@
 ]
 
 
-def func(mocker, n_envs, args, evaluation=False, root_dir=None, run_name=None):
+def _test_agent(mocker, agent, kwargs):
+    load_mocker(mocker)
+
+    agent = importlib.import_module(f"agent-{agent}")
+    os.environ["DIAMBRA_ENVS"] = "127.0.0.1:32781"
+    return agent.main(**kwargs)
+
+
+def _test_train_eval(
+    mocker, n_envs, args, evaluation=False, root_dir=None, run_name=None
+):
     load_mocker(mocker)
 
     try:
@@ -36,9 +47,9 @@ def func(mocker, n_envs, args, evaluation=False, root_dir=None, run_name=None):
         os.environ["DIAMBRA_ENVS"] = envs
 
         # SheepRL config folder setup
-        os.environ[
-            "SHEEPRL_SEARCH_PATH"
-        ] = "file://sheeprl/configs;pkg://sheeprl.configs"
+        os.environ["SHEEPRL_SEARCH_PATH"] = (
+            "file://sheeprl/configs;pkg://sheeprl.configs"
+        )
 
         # Execution of the train script
         with mock.patch.object(sys, "argv", STANDARD_ARGS + args):
@@ -70,7 +81,7 @@ def func(mocker, n_envs, args, evaluation=False, root_dir=None, run_name=None):
 
 def test_sheeprl_train_base(mocker):
     assert (
-        func(
+        _test_train_eval(
             mocker,
             2,
             ["exp=custom_exp", "checkpoint.save_last=False"],
@@ -81,7 +92,7 @@ def test_sheeprl_train_base(mocker):
 
 def test_sheeprl_train_parallel_envs(mocker):
     assert (
-        func(
+        _test_train_eval(
             mocker,
             6,
             ["exp=custom_parallel_env_exp", "checkpoint.save_last=False"],
@@ -92,7 +103,7 @@ def test_sheeprl_train_parallel_envs(mocker):
 
 def test_sheeprl_train_fabric(mocker):
     assert (
-        func(
+        _test_train_eval(
             mocker,
             2,
             [
@@ -108,7 +119,7 @@ def test_sheeprl_train_fabric(mocker):
 
 def test_sheeprl_train_metrics(mocker):
     assert (
-        func(
+        _test_train_eval(
             mocker,
             2,
             [
@@ -124,7 +135,7 @@ def test_sheeprl_train_metrics(mocker):
 
 def test_sheeprl_evaluation(mocker):
     assert (
-        func(
+        _test_train_eval(
             mocker,
             3,
             [
@@ -139,3 +150,40 @@ def test_sheeprl_evaluation(mocker):
         )
         == 0
     )
+
+
+def test_sheeprl_ppo_agent(mocker):
+    cfg_path = os.path.join(
+        ROOT_DIR, "example-logs/runs/ppo/doapp/experiment/version_0/config.yaml"
+    )
+    checkpoint_path = os.path.join(
+        ROOT_DIR,
"example-logs/runs/ppo/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt", + ) + assert ( + _test_agent( + mocker, + "ppo", + {"cfg_path": cfg_path, "checkpoint_path": checkpoint_path, "test": True}, + ) + == 0 + ) + + +def test_sheeprl_dreamer_v3_agent(mocker): + cfg_path = os.path.join( + ROOT_DIR, + "example-logs/runs/dreamer_v3/doapp/experiment/version_0/config.yaml", + ) + checkpoint_path = os.path.join( + ROOT_DIR, + "example-logs/runs/dreamer_v3/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt", + ) + assert ( + _test_agent( + mocker, + "dreamer_v3", + {"cfg_path": cfg_path, "checkpoint_path": checkpoint_path, "test": True}, + ) + == 0 + )