From e8ddf96630c530d5d31cbd60c92686e336ff7344 Mon Sep 17 00:00:00 2001 From: thibaultdvx <154365476+thibaultdvx@users.noreply.github.com> Date: Tue, 4 Jun 2024 09:40:55 +0200 Subject: [PATCH] fix issues in predict and train --- tests/test_predict.py | 2 +- tests/test_train_ae.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_predict.py b/tests/test_predict.py index a25b92f43..c6b6a39fa 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -117,6 +117,6 @@ def test_predict(cmdopt, tmp_path, test_name): assert compare_folders( tmp_out_dir / maps_name, - ref_dir / maps_name, + input_dir / maps_name, tmp_out_dir, ) diff --git a/tests/test_train_ae.py b/tests/test_train_ae.py index b20749258..c7fbcb276 100644 --- a/tests/test_train_ae.py +++ b/tests/test_train_ae.py @@ -33,8 +33,10 @@ def test_train_ae(cmdopt, tmp_path, test_name): labels_path = str(input_dir / "labels_list" / "2_fold") config_path = str(input_dir / "train_config.toml") + split = 0 + if test_name == "image_ae": - split = [0, 0] + split = 1 test_input = [ "train", "reconstruction", @@ -45,10 +47,9 @@ def test_train_ae(cmdopt, tmp_path, test_name): "-c", config_path, "--split", - "1", + str(split), ] elif test_name == "patch_multi_ae": - split = [0, 0] test_input = [ "train", "reconstruction", @@ -61,7 +62,6 @@ def test_train_ae(cmdopt, tmp_path, test_name): "--multi_network", ] elif test_name == "roi_ae": - split = [0, 0] test_input = [ "train", "reconstruction", @@ -73,7 +73,6 @@ def test_train_ae(cmdopt, tmp_path, test_name): config_path, ] elif test_name == "slice_ae": - split = [0, 0] test_input = [ "train", "reconstruction", @@ -116,7 +115,7 @@ def test_train_ae(cmdopt, tmp_path, test_name): tmp_path, ) assert compare_folders( - tmp_out_dir / f"split-{split[0]}" / "best-loss", - ref_dir / ("maps_" + test_name) / f"split-{split[1]}" / "best-loss", + tmp_out_dir / f"split-{split}" / "best-loss", + ref_dir / ("maps_" + test_name) / f"split-{split}" / "best-loss", tmp_path, )