diff --git a/.github/workflows/run_tests_nnunet.yml b/.github/workflows/run_tests_nnunet.yml index 37511e5c5..7926d9a80 100644 --- a/.github/workflows/run_tests_nnunet.yml +++ b/.github/workflows/run_tests_nnunet.yml @@ -29,8 +29,6 @@ jobs: unzip -o $HOME/github_actions_nnunet/results/tmp_download_file.zip -d $HOME/github_actions_nnunet/results rm $HOME/github_actions_nnunet/results/tmp_download_file.zip - python -c "import urllib.request; urllib.request.urlretrieve('https://github.com/wasserth/TotalSegmentator/raw/refs/tags/v2.4.0/tests/reference_files/example_ct_sm.nii.gz', '$HOME/github_actions_nnunet/example_ct_sm.nii.gz')" - - name: Install dependencies on Ubuntu if: runner.os == 'Linux' run: | diff --git a/.gitignore b/.gitignore index ee76df517..acceb2804 100644 --- a/.gitignore +++ b/.gitignore @@ -99,7 +99,7 @@ ENV/ *.txt .idea/* *.png -*.nii.gz +# *.nii.gz # nifti files needed for example_data for github actions tests *.nii *.tif *.bmp diff --git a/nnunetv2/tests/example_data/example_ct_sm.nii.gz b/nnunetv2/tests/example_data/example_ct_sm.nii.gz new file mode 100644 index 000000000..ec5668e1e Binary files /dev/null and b/nnunetv2/tests/example_data/example_ct_sm.nii.gz differ diff --git a/nnunetv2/tests/example_data/example_ct_sm_T300_output.nii.gz b/nnunetv2/tests/example_data/example_ct_sm_T300_output.nii.gz new file mode 100644 index 000000000..22973a2b0 Binary files /dev/null and b/nnunetv2/tests/example_data/example_ct_sm_T300_output.nii.gz differ diff --git a/nnunetv2/tests/integration_tests/run_nnunet_inference.py b/nnunetv2/tests/integration_tests/run_nnunet_inference.py index 0a1de0882..116d0490d 100755 --- a/nnunetv2/tests/integration_tests/run_nnunet_inference.py +++ b/nnunetv2/tests/integration_tests/run_nnunet_inference.py @@ -3,29 +3,42 @@ import subprocess from pathlib import Path -""" -To run these tests do -python tests/tests_nnunet.py -""" -def run_tests_and_exit_on_failure(): +import nibabel as nib +import numpy as np + + +def dice_score(y_true, y_pred): + intersect = np.sum(y_true * y_pred) + denominator = np.sum(y_true) + np.sum(y_pred) + f1 = (2 * intersect) / (denominator + 1e-6) + return f1 + +def run_tests_and_exit_on_failure(): + """ + Runs inference of a simple nnU-Net for CT body segmentation on a small example CT image + and checks if the output is correct. + """ # Set nnUNet_results env var weights_dir = Path.home() / "github_actions_nnunet" / "results" os.environ["nnUNet_results"] = str(weights_dir) - print(f"Using weights directory: {weights_dir}") # Copy example file - os.makedirs("tests/nnunet_input_files", exist_ok=True) - shutil.copy(Path.home() / "github_actions_nnunet" / "example_ct_sm.nii.gz", "tests/nnunet_input_files/example_ct_sm_0000.nii.gz") + os.makedirs("nnunetv2/tests/github_actions_output", exist_ok=True) + shutil.copy("nnunetv2/tests/example_data/example_ct_sm.nii.gz", "nnunetv2/tests/github_actions_output/example_ct_sm_0000.nii.gz") # Run nnunet - subprocess.call(f"nnUNetv2_predict -i tests/nnunet_input_files -o tests/nnunet_input_files -d 300 -tr nnUNetTrainer -c 3d_fullres -f 0 -device cpu", shell=True) + subprocess.call(f"nnUNetv2_predict -i nnunetv2/tests/github_actions_output -o nnunetv2/tests/github_actions_output -d 300 -tr nnUNetTrainer -c 3d_fullres -f 0 -device cpu", shell=True) - # Check if output file exists - assert os.path.exists("tests/nnunet_input_files/example_ct_sm.nii.gz"), "A nnunet output file was not generated." + # Check if the nnunet segmentation is correct + img_gt = nib.load(f"nnunetv2/tests/example_data/example_ct_sm_T300_output.nii.gz").get_fdata() + img_pred = nib.load(f"nnunetv2/tests/github_actions_output/example_ct_sm.nii.gz").get_fdata() + dice = dice_score(img_gt, img_pred) + images_equal = dice > 0.99 + assert images_equal, f"The nnunet segmentation is not correct (dice: {dice:.5f})." # Clean up - shutil.rmtree("tests/nnunet_input_files") + shutil.rmtree("nnunetv2/tests/github_actions_output") shutil.rmtree(Path.home() / "github_actions_nnunet")