Skip to content

Commit

Permalink
correction in predict test
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultdvx committed Jun 3, 2024
1 parent c38875c commit 7a59036
Showing 1 changed file with 47 additions and 22 deletions.
69 changes: 47 additions & 22 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# coding: utf8
import json
import os
import shutil
from os.path import exists
from pathlib import Path

import pytest

from clinicadl import MapsManager
from tests.testing_tools import clean_folder, compare_folders

from .testing_tools import compare_folders, modify_maps


@pytest.fixture(
Expand All @@ -33,46 +33,71 @@ def test_predict(cmdopt, tmp_path, test_name):
tmp_out_dir.mkdir(parents=True)

if test_name == "predict_image_classification":
model_folder = input_dir / "maps_image_cnn"
maps_name = "maps_image_cnn"
modes = ["image"]
use_labels = True
elif test_name == "predict_slice_classification":
model_folder = input_dir / "maps_slice_cnn"
maps_name = "maps_slice_cnn"
modes = ["image", "slice"]
use_labels = True
elif test_name == "predict_patch_regression":
model_folder = input_dir / "maps_patch_cnn"
maps_name = "maps_patch_cnn"
modes = ["image", "patch"]
use_labels = False
elif test_name == "predict_roi_regression":
model_folder = input_dir / "maps_roi_cnn"
maps_name = "maps_roi_cnn"
modes = ["image", "roi"]
use_labels = False
elif test_name == "predict_patch_multi_classification":
model_folder = input_dir / "maps_patch_multi_cnn"
maps_name = "maps_patch_multi_cnn"
modes = ["image", "patch"]
use_labels = False
elif test_name == "predict_roi_reconstruction":
model_folder = input_dir / "maps_roi_ae"
maps_name = "maps_roi_ae"
modes = ["roi"]
use_labels = False
else:
raise NotImplementedError(f"Test {test_name} is not implemented.")

out_dir = str(model_folder / "split-0/best-loss/test-RANDOM")
shutil.copytree(input_dir / maps_name, tmp_out_dir / maps_name)
model_folder = tmp_out_dir / maps_name

if cmdopt["adapt-base-dir"]:
with open(model_folder / "maps.json", "r") as f:
config = json.load(f)
config = modify_maps(
maps=config,
base_dir=base_dir,
no_gpu=cmdopt["no-gpu"],
adapt_base_dir=cmdopt["adapt-base-dir"],
)
with open(model_folder / "maps.json", "w") as f:
json.dump(config, f, skipkeys=True, indent=4)

with open(model_folder / "groups/test-RANDOM/maps.json", "r") as f:
config = json.load(f)
config = modify_maps(
maps=config,
base_dir=base_dir,
no_gpu=False,
adapt_base_dir=cmdopt["adapt-base-dir"],
)
with open(model_folder / "groups/test-RANDOM/maps.json", "w") as f:
json.dump(config, f, skipkeys=True, indent=4)

if exists(out_dir):
shutil.rmtree(out_dir)
tmp_out_subdir = str(model_folder / "split-0/best-loss/test-RANDOM")
if exists(tmp_out_subdir):
shutil.rmtree(tmp_out_subdir)

# Correction of JSON file for ROI
if "roi" in modes:
json_path = model_folder / "maps.json"
with open(json_path, "r") as f:
parameters = json.load(f)
parameters["roi_list"] = ["leftHippocampusBox", "rightHippocampusBox"]
json_data = json.dumps(parameters, skipkeys=True, indent=4)
with open(json_path, "w") as f:
f.write(json_data)
# # Correction of JSON file for ROI
# if "roi" in modes:
# json_path = model_folder / "maps.json"
# with open(json_path, "r") as f:
# parameters = json.load(f)
# parameters["roi_list"] = ["leftHippocampusBox", "rightHippocampusBox"]
# json_data = json.dumps(parameters, skipkeys=True, indent=4)
# with open(json_path, "w") as f:
# f.write(json_data)

maps_manager = MapsManager(model_folder, verbose="debug")
maps_manager.predict(
Expand All @@ -91,7 +116,7 @@ def test_predict(cmdopt, tmp_path, test_name):
maps_manager.get_metrics(data_group="test-RANDOM", mode=mode)

assert compare_folders(
tmp_out_dir / test_name,
ref_dir / test_name,
tmp_out_dir / maps_name,
ref_dir / maps_name,
tmp_out_dir,
)

0 comments on commit 7a59036

Please sign in to comment.