From 099920a0f0aa8eeeb56da3599c19c9c55610d14d Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Mon, 22 Apr 2024 19:15:37 +0200 Subject: [PATCH] feat: update tests --- tests/test_sheeprl.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/tests/test_sheeprl.py b/tests/test_sheeprl.py index eb4c07b..e6bdab4 100644 --- a/tests/test_sheeprl.py +++ b/tests/test_sheeprl.py @@ -30,7 +30,7 @@ def _test_agent(mocker, agent, kwargs): agent = importlib.import_module(f"agent-{agent}") os.environ["DIAMBRA_ENVS"] = "127.0.0.1:32781" - agent.main(**kwargs) + return agent.main(**kwargs) def _test_train_eval( @@ -154,30 +154,36 @@ def test_sheeprl_evaluation(mocker): def test_sheeprl_ppo_agent(mocker): cfg_path = os.path.join( - ROOT_DIR, "/fake-logs/runs/ppo/doapp/fake-experiment/version_0/config.yaml" + ROOT_DIR, "example-logs/runs/ppo/doapp/experiment/version_0/config.yaml" ) checkpoint_path = os.path.join( ROOT_DIR, - "/fake-logs/runs/ppo/doapp/fake-experiment/version_0/checkpoint/ckpt_1024_0.ckpt", + "example-logs/runs/ppo/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt", ) - assert _test_agent( - mocker, - "ppo", - {"cfg_path": cfg_path, "checkpoint_path": checkpoint_path, "test": True}, + assert ( + _test_agent( + mocker, + "ppo", + {"cfg_path": cfg_path, "checkpoint_path": checkpoint_path, "test": True}, + ) + == 0 ) def test_sheeprl_dreamer_v3_agent(mocker): cfg_path = os.path.join( ROOT_DIR, - "/fake-logs/runs/dreamer_v3/doapp/fake-experiment/version_0/config.yaml", + "example-logs/runs/dreamer_v3/doapp/experiment/version_0/config.yaml", ) checkpoint_path = os.path.join( ROOT_DIR, - "/fake-logs/runs/dreamer_v3/doapp/fake-experiment/version_0/checkpoint/ckpt_1024_0.ckpt", + "example-logs/runs/dreamer_v3/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt", ) - assert _test_agent( - mocker, - "dreamer_v3", - {"cfg_path": cfg_path, "checkpoint_path": checkpoint_path, "test": True}, + assert ( + _test_agent( + mocker, + "dreamer_v3", + {"cfg_path": cfg_path, "checkpoint_path": checkpoint_path, "test": True}, + ) + == 0 )