Skip to content

Commit

Permalink
add inference test - add reference groundtruth and calc dice score
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Dec 17, 2024
1 parent 6fd717f commit 507abd2
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 15 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/run_tests_nnunet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ jobs:
unzip -o $HOME/github_actions_nnunet/results/tmp_download_file.zip -d $HOME/github_actions_nnunet/results
rm $HOME/github_actions_nnunet/results/tmp_download_file.zip
python -c "import urllib.request; urllib.request.urlretrieve('https://github.com/wasserth/TotalSegmentator/raw/refs/tags/v2.4.0/tests/reference_files/example_ct_sm.nii.gz', '$HOME/github_actions_nnunet/example_ct_sm.nii.gz')"
- name: Install dependencies on Ubuntu
if: runner.os == 'Linux'
run: |
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ ENV/
*.txt
.idea/*
*.png
*.nii.gz
# *.nii.gz # nifti files needed for example_data for github actions tests
*.nii
*.tif
*.bmp
Expand Down
Binary file added nnunetv2/tests/example_data/example_ct_sm.nii.gz
Binary file not shown.
Binary file not shown.
37 changes: 25 additions & 12 deletions nnunetv2/tests/integration_tests/run_nnunet_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,42 @@
import subprocess
from pathlib import Path

"""
To run these tests do
python tests/tests_nnunet.py
"""
def run_tests_and_exit_on_failure():
import nibabel as nib
import numpy as np


def dice_score(y_true, y_pred):
intersect = np.sum(y_true * y_pred)
denominator = np.sum(y_true) + np.sum(y_pred)
f1 = (2 * intersect) / (denominator + 1e-6)
return f1


def run_tests_and_exit_on_failure():
"""
Runs inference of a simple nnU-Net for CT body segmentation on a small example CT image
and checks if the output is correct.
"""
# Set nnUNet_results env var
weights_dir = Path.home() / "github_actions_nnunet" / "results"
os.environ["nnUNet_results"] = str(weights_dir)
print(f"Using weights directory: {weights_dir}")

# Copy example file
os.makedirs("tests/nnunet_input_files", exist_ok=True)
shutil.copy(Path.home() / "github_actions_nnunet" / "example_ct_sm.nii.gz", "tests/nnunet_input_files/example_ct_sm_0000.nii.gz")
os.makedirs("nnunetv2/tests/github_actions_output", exist_ok=True)
shutil.copy("nnunetv2/tests/example_data/example_ct_sm.nii.gz", "nnunetv2/tests/github_actions_output/example_ct_sm_0000.nii.gz")

# Run nnunet
subprocess.call(f"nnUNetv2_predict -i tests/nnunet_input_files -o tests/nnunet_input_files -d 300 -tr nnUNetTrainer -c 3d_fullres -f 0 -device cpu", shell=True)
subprocess.call(f"nnUNetv2_predict -i nnunetv2/tests/github_actions_output -o nnunetv2/tests/github_actions_output -d 300 -tr nnUNetTrainer -c 3d_fullres -f 0 -device cpu", shell=True)

# Check if output file exists
assert os.path.exists("tests/nnunet_input_files/example_ct_sm.nii.gz"), "A nnunet output file was not generated."
# Check if the nnunet segmentation is correct
img_gt = nib.load(f"nnunetv2/tests/example_data/example_ct_sm_T300_output.nii.gz").get_fdata()
img_pred = nib.load(f"nnunetv2/tests/github_actions_output/example_ct_sm.nii.gz").get_fdata()
dice = dice_score(img_gt, img_pred)
images_equal = dice > 0.99
assert images_equal, f"The nnunet segmentation is not correct (dice: {dice:.5f})."

# Clean up
shutil.rmtree("tests/nnunet_input_files")
shutil.rmtree("nnunetv2/tests/github_actions_output")
shutil.rmtree(Path.home() / "github_actions_nnunet")


Expand Down

0 comments on commit 507abd2

Please sign in to comment.