Skip to content

Commit

Permalink
fix issues in predict and train
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultdvx committed Jun 4, 2024
1 parent 4efbcac commit e8ddf96
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,6 @@ def test_predict(cmdopt, tmp_path, test_name):

assert compare_folders(
tmp_out_dir / maps_name,
ref_dir / maps_name,
input_dir / maps_name,
tmp_out_dir,
)
13 changes: 6 additions & 7 deletions tests/test_train_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ def test_train_ae(cmdopt, tmp_path, test_name):

labels_path = str(input_dir / "labels_list" / "2_fold")
config_path = str(input_dir / "train_config.toml")
split = 0

if test_name == "image_ae":
split = [0, 0]
split = 1
test_input = [
"train",
"reconstruction",
Expand All @@ -45,10 +47,9 @@ def test_train_ae(cmdopt, tmp_path, test_name):
"-c",
config_path,
"--split",
"1",
str(split),
]
elif test_name == "patch_multi_ae":
split = [0, 0]
test_input = [
"train",
"reconstruction",
Expand All @@ -61,7 +62,6 @@ def test_train_ae(cmdopt, tmp_path, test_name):
"--multi_network",
]
elif test_name == "roi_ae":
split = [0, 0]
test_input = [
"train",
"reconstruction",
Expand All @@ -73,7 +73,6 @@ def test_train_ae(cmdopt, tmp_path, test_name):
config_path,
]
elif test_name == "slice_ae":
split = [0, 0]
test_input = [
"train",
"reconstruction",
Expand Down Expand Up @@ -116,7 +115,7 @@ def test_train_ae(cmdopt, tmp_path, test_name):
tmp_path,
)
assert compare_folders(
tmp_out_dir / f"split-{split[0]}" / "best-loss",
ref_dir / ("maps_" + test_name) / f"split-{split[1]}" / "best-loss",
tmp_out_dir / f"split-{split}" / "best-loss",
ref_dir / ("maps_" + test_name) / f"split-{split}" / "best-loss",
tmp_path,
)

0 comments on commit e8ddf96

Please sign in to comment.