Skip to content

Commit

Permalink
fix unit tests and add some more
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasGensollen committed Oct 2, 2024
1 parent 238601a commit f3c8405
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 20 deletions.
1 change: 1 addition & 0 deletions clinica/utils/input_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def query_factory(name: Union[str, QueryName], *args, **kwargs) -> Query:


def get_dwi_file(filetype: Union[str, DWIFileType]) -> Query:
"""Return the query to get DWI files (nii, json, bvec, bval)."""
filetype = DWIFileType(filetype)
return Query(
f"dwi/sub-*_ses-*_dwi.{filetype.value}*", f"DWI {filetype.value} files.", ""
Expand Down
151 changes: 131 additions & 20 deletions test/unittests/utils/test_input_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from clinica.pipelines.dwi.dti.utils import DTIBasedMeasure
from clinica.utils.input_files import query_factory
from clinica.utils.input_files import Query, query_factory
from clinica.utils.pet import ReconstructionMethod, Tracer


Expand Down Expand Up @@ -44,13 +44,49 @@ def toy_func_3(x, y=2, z=3):
@pytest.mark.parametrize(
"query_name,expected_pattern,expected_description,expected_pipelines",
[
("T1W", "sub-*_ses-*_t1w.nii*", "T1w MRI", ()),
("T2W", "sub-*_ses-*_flair.nii*", "FLAIR T2w MRI", ()),
("T1W", "sub-*_ses-*_t1w.nii*", "T1w MRI", ""),
("T2W", "sub-*_ses-*_flair.nii*", "FLAIR T2w MRI", ""),
(
"T1_FS_WM",
"t1/freesurfer_cross_sectional/sub-*_ses-*/mri/wm.seg.mgz",
"segmentation of white matter (mri/wm.seg.mgz).",
("t1-freesurfer",),
"t1-freesurfer",
),
(
"T1_FS_BRAIN",
"t1/freesurfer_cross_sectional/sub-*_ses-*/mri/brain.mgz",
"extracted brain from T1w MRI (mri/brain.mgz).",
"t1-freesurfer",
),
(
"T1_FS_ORIG_NU",
"t1/freesurfer_cross_sectional/sub-*_ses-*/mri/orig_nu.mgz",
(
"intensity normalized volume generated after correction for"
" non-uniformity in FreeSurfer (mri/orig_nu.mgz)."
),
"t1-freesurfer",
),
(
"T1_FS_LONG_ORIG_NU",
"t1/long-*/freesurfer_longitudinal/sub-*_ses-*.long.sub-*_*/mri/orig_nu.mgz",
(
"intensity normalized volume generated after correction for non-uniformity "
"in FreeSurfer (orig_nu.mgz) in longitudinal"
),
"t1-freesurfer and t1-freesurfer longitudinal",
),
(
"T1W_TO_MNI_TRANSFORM",
"*space-MNI152NLin2009cSym_res-1x1x1_affine.mat",
"Transformation matrix from T1W image to MNI space using t1-linear pipeline",
"t1-linear",
),
(
"DWI_PREPROC_BRAINMASK",
"dwi/preprocessing/sub-*_ses-*_space-*_brainmask.nii*",
"b0 brainmask",
"dwi-preprocessing-using-t1 or dwi-preprocessing-using-fieldmap",
),
],
)
Expand All @@ -64,27 +100,99 @@ def test_query_factory(
assert query.needed_pipeline == expected_pipelines


@pytest.mark.parametrize(
"filetype,expected_pattern,expected_description,expected_pipelines",
[
("nii", "dwi/sub-*_ses-*_dwi.nii*", "DWI nii files.", ""),
("json", "dwi/sub-*_ses-*_dwi.json*", "DWI json files.", ""),
("bvec", "dwi/sub-*_ses-*_dwi.bvec*", "DWI bvec files.", ""),
("bval", "dwi/sub-*_ses-*_dwi.bval*", "DWI bval files.", ""),
],
)
def test_get_dwi_file(
filetype: str,
expected_pattern: str,
expected_description: str,
expected_pipelines: str,
):
from clinica.utils.input_files import get_dwi_file

query = get_dwi_file(filetype)

assert query.pattern == expected_pattern
assert query.description == expected_description
assert query.needed_pipeline == expected_pipelines


@pytest.mark.parametrize(
"filetype,expected_pattern,expected_description,expected_pipelines",
[
(
"nii",
"dwi/preprocessing/sub-*_ses-*_space-*_desc-preproc_dwi.nii*",
"preprocessed nii files",
"dwi-preprocessing-using-t1 or dwi-preprocessing-using-fieldmap",
),
(
"json",
"dwi/preprocessing/sub-*_ses-*_space-*_desc-preproc_dwi.json*",
"preprocessed json files",
"dwi-preprocessing-using-t1 or dwi-preprocessing-using-fieldmap",
),
(
"bvec",
"dwi/preprocessing/sub-*_ses-*_space-*_desc-preproc_dwi.bvec*",
"preprocessed bvec files",
"dwi-preprocessing-using-t1 or dwi-preprocessing-using-fieldmap",
),
(
"bval",
"dwi/preprocessing/sub-*_ses-*_space-*_desc-preproc_dwi.bval*",
"preprocessed bval files",
"dwi-preprocessing-using-t1 or dwi-preprocessing-using-fieldmap",
),
],
)
def test_get_dwi_preprocessed_file(
filetype: str,
expected_pattern: str,
expected_description: str,
expected_pipelines: str,
):
from clinica.utils.input_files import get_dwi_preprocessed_file

query = get_dwi_preprocessed_file(filetype)

assert query.pattern == expected_pattern
assert query.description == expected_description
assert query.needed_pipeline == expected_pipelines


def test_bids_pet_nii_empty():
from clinica.utils.input_files import bids_pet_nii

assert bids_pet_nii() == {
"pattern": Path("pet") / "*_pet.nii*",
"description": "PET data",
}
query = bids_pet_nii()

assert query.pattern == str(Path("pet") / "*_pet.nii*")
assert query.description == "PET data"


@pytest.fixture
def expected_bids_pet_query(tracer, reconstruction):
return {
"pattern": Path("pet")
/ f"*_trc-{tracer.value}_rec-{reconstruction.value}_pet.nii*",
"description": f"PET data with {tracer.value} tracer and reconstruction method {reconstruction.value}",
}
def expected_bids_pet_query(
tracer: Tracer, reconstruction: ReconstructionMethod
) -> Query:
return Query(
str(Path("pet") / f"*_trc-{tracer.value}_rec-{reconstruction.value}_pet.nii*"),
f"PET data with {tracer.value} tracer and reconstruction method {reconstruction.value}",
"",
)


@pytest.mark.parametrize("tracer", Tracer)
@pytest.mark.parametrize("reconstruction", ReconstructionMethod)
def test_bids_pet_nii(tracer, reconstruction, expected_bids_pet_query):
def test_bids_pet_nii(
tracer: Tracer, reconstruction: ReconstructionMethod, expected_bids_pet_query: Query
):
from clinica.utils.input_files import bids_pet_nii

assert bids_pet_nii(tracer, reconstruction) == expected_bids_pet_query
Expand All @@ -96,11 +204,14 @@ def test_dwi_dti_query(dti_measure, space):
from clinica.utils.input_files import dwi_dti

space = space or "*"
assert dwi_dti(dti_measure, space=space) == {
"pattern": f"dwi/dti_based_processing/*/*_space-{space}_{dti_measure.value}.nii.gz",
"description": f"DTI-based {dti_measure.value} in space {space}.",
"needed_pipeline": "dwi_dti",
}
query = dwi_dti(dti_measure, space=space)

assert (
query.pattern
== f"dwi/dti_based_processing/*/*_space-{space}_{dti_measure.value}.nii.gz"
)
assert query.description == f"DTI-based {dti_measure.value} in space {space}."
assert query.needed_pipeline == "dwi_dti"


def test_dwi_dti_query_error():
Expand Down

0 comments on commit f3c8405

Please sign in to comment.