From 7a59036cc737ea5f88b1d40651d2177f5e9b1a4f Mon Sep 17 00:00:00 2001 From: thibaultdvx <154365476+thibaultdvx@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:12:36 +0200 Subject: [PATCH] correction in predict test --- tests/test_predict.py | 69 +++++++++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/tests/test_predict.py b/tests/test_predict.py index 34427eeeb..a25b92f43 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,6 +1,5 @@ # coding: utf8 import json -import os import shutil from os.path import exists from pathlib import Path @@ -8,7 +7,8 @@ import pytest from clinicadl import MapsManager -from tests.testing_tools import clean_folder, compare_folders + +from .testing_tools import compare_folders, modify_maps @pytest.fixture( @@ -33,46 +33,71 @@ def test_predict(cmdopt, tmp_path, test_name): tmp_out_dir.mkdir(parents=True) if test_name == "predict_image_classification": - model_folder = input_dir / "maps_image_cnn" + maps_name = "maps_image_cnn" modes = ["image"] use_labels = True elif test_name == "predict_slice_classification": - model_folder = input_dir / "maps_slice_cnn" + maps_name = "maps_slice_cnn" modes = ["image", "slice"] use_labels = True elif test_name == "predict_patch_regression": - model_folder = input_dir / "maps_patch_cnn" + maps_name = "maps_patch_cnn" modes = ["image", "patch"] use_labels = False elif test_name == "predict_roi_regression": - model_folder = input_dir / "maps_roi_cnn" + maps_name = "maps_roi_cnn" modes = ["image", "roi"] use_labels = False elif test_name == "predict_patch_multi_classification": - model_folder = input_dir / "maps_patch_multi_cnn" + maps_name = "maps_patch_multi_cnn" modes = ["image", "patch"] use_labels = False elif test_name == "predict_roi_reconstruction": - model_folder = input_dir / "maps_roi_ae" + maps_name = "maps_roi_ae" modes = ["roi"] use_labels = False else: raise NotImplementedError(f"Test {test_name} is not implemented.") - out_dir = str(model_folder / "split-0/best-loss/test-RANDOM") + shutil.copytree(input_dir / maps_name, tmp_out_dir / maps_name) + model_folder = tmp_out_dir / maps_name + + if cmdopt["adapt-base-dir"]: + with open(model_folder / "maps.json", "r") as f: + config = json.load(f) + config = modify_maps( + maps=config, + base_dir=base_dir, + no_gpu=cmdopt["no-gpu"], + adapt_base_dir=cmdopt["adapt-base-dir"], + ) + with open(model_folder / "maps.json", "w") as f: + json.dump(config, f, skipkeys=True, indent=4) + + with open(model_folder / "groups/test-RANDOM/maps.json", "r") as f: + config = json.load(f) + config = modify_maps( + maps=config, + base_dir=base_dir, + no_gpu=False, + adapt_base_dir=cmdopt["adapt-base-dir"], + ) + with open(model_folder / "groups/test-RANDOM/maps.json", "w") as f: + json.dump(config, f, skipkeys=True, indent=4) - if exists(out_dir): - shutil.rmtree(out_dir) + tmp_out_subdir = str(model_folder / "split-0/best-loss/test-RANDOM") + if exists(tmp_out_subdir): + shutil.rmtree(tmp_out_subdir) - # Correction of JSON file for ROI - if "roi" in modes: - json_path = model_folder / "maps.json" - with open(json_path, "r") as f: - parameters = json.load(f) - parameters["roi_list"] = ["leftHippocampusBox", "rightHippocampusBox"] - json_data = json.dumps(parameters, skipkeys=True, indent=4) - with open(json_path, "w") as f: - f.write(json_data) + # # Correction of JSON file for ROI + # if "roi" in modes: + # json_path = model_folder / "maps.json" + # with open(json_path, "r") as f: + # parameters = json.load(f) + # parameters["roi_list"] = ["leftHippocampusBox", "rightHippocampusBox"] + # json_data = json.dumps(parameters, skipkeys=True, indent=4) + # with open(json_path, "w") as f: + # f.write(json_data) maps_manager = MapsManager(model_folder, verbose="debug") maps_manager.predict( @@ -91,7 +116,7 @@ def test_predict(cmdopt, tmp_path, test_name): maps_manager.get_metrics(data_group="test-RANDOM", mode=mode) assert compare_folders( - tmp_out_dir / test_name, - ref_dir / test_name, + tmp_out_dir / maps_name, + ref_dir / maps_name, tmp_out_dir, )