Dataset module for v2 #685

Merged

Changes from 8 commits
167 changes: 167 additions & 0 deletions clinicadl/API/dataset_test.py
@@ -0,0 +1,167 @@
from pathlib import Path

import torchio.transforms as transforms

from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.dataset.datasets.concat import ConcatDataset
from clinicadl.dataset.preprocessing import (
    BasePreprocessing,
    PreprocessingFlair,
    PreprocessingPET,
    PreprocessingT1,
)
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.factory import (
    ConvEncoderOptions,
    create_network_config,
    get_network_from_config,
)
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.transforms.extraction import ROI, Image, Patch, Slice
from clinicadl.transforms.transforms import Transforms

sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv")
sub_ses_pet_45 = Path(
    "/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_18FAV45.tsv"
)
sub_ses_flair = Path(
    "/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_flair.tsv"
)
sub_ses_pet_11 = Path(
    "/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_11CPIB.tsv"
)

caps_directory = Path(
    "/Users/camille.brianceau/aramis/CLINICADL/caps"
)  # output of clinica pipelines

preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2")
preprocessing_pet_11 = PreprocessingPET(tracer="11CPIB", suvr_reference_region="pons2")

preprocessing_t1 = PreprocessingT1()
preprocessing_flair = PreprocessingFlair()


transforms_patch = Transforms(
    object_augmentation=[transforms.Ghosting(2, 1, 0.1, 0.1)],
    image_augmentation=[transforms.RandomMotion()],
    extraction=Patch(patch_size=60),
    image_transforms=[transforms.Blur((0.5, 0.6, 0.3))],
    object_transforms=[transforms.RandomMotion()],
)  # not mandatory

transforms_slice = Transforms(extraction=Slice())

transforms_roi = Transforms(
    object_augmentation=[transforms.Ghosting(2, 1, 0.1, 0.1)],
    object_transforms=[transforms.RandomMotion()],
    extraction=ROI(
        roi_list=["leftHippocampusBox", "rightHippocampusBox"],
        roi_mask_location=Path(
            "/Users/camille.brianceau/aramis/CLINICADL/caps/masks/tpl-MNI152NLin2009cSym"
        ),
        roi_crop_input=True,
    ),
)

transforms_image = Transforms(
    image_augmentation=[transforms.RandomMotion()],
    extraction=Image(),
    image_transforms=[transforms.Blur((0.5, 0.6, 0.3))],
)


print("Pet 45 and Patch ")
dataset_pet_45_patch = CapsDataset(
    caps_directory=caps_directory,
    data=sub_ses_pet_45,
    preprocessing=preprocessing_pet_45,
    transforms=transforms_patch,
)
dataset_pet_45_patch.prepare_data(n_proc=2)

print(dataset_pet_45_patch)
print(dataset_pet_45_patch.__len__())
print(dataset_pet_45_patch._get_meta_data(3))
print(dataset_pet_45_patch._get_meta_data(80))
# print(dataset_pet_45_patch._get_full_image())
print(dataset_pet_45_patch.__getitem__(80).elem_idx)
print(dataset_pet_45_patch.elem_per_image)

dataset_pet_45_patch.caps_reader._write_caps_json(
    transforms_patch, preprocessing_pet_45, sub_ses_pet_45, name="tfsdklsqfh"
)


print("Pet 11 and ROI ")

dataset_pet_11_roi = CapsDataset(
    caps_directory=caps_directory,
    data=sub_ses_pet_11,
    preprocessing=preprocessing_pet_11,
    transforms=transforms_roi,
)
dataset_pet_11_roi.prepare_data(
    n_proc=2
)  # to extract the tensor of the PET file this time

print(dataset_pet_11_roi)
print(dataset_pet_11_roi.__len__())
print(dataset_pet_11_roi._get_meta_data(0))
print(dataset_pet_11_roi._get_meta_data(1))
# print(dataset_pet_11_roi._get_full_image())
print(dataset_pet_11_roi.__getitem__(1).elem_idx)
print(dataset_pet_11_roi.elem_per_image)


print("T1 and image ")

dataset_t1_image = CapsDataset(
    caps_directory=caps_directory,
    data=sub_ses_t1,
    preprocessing=preprocessing_t1,
    transforms=transforms_image,
)
dataset_t1_image.prepare_data(
    n_proc=2
)  # to extract the tensor of the T1w image this time

print(dataset_t1_image)
print(dataset_t1_image.__len__())
print(dataset_t1_image._get_meta_data(3))
print(dataset_t1_image._get_meta_data(5))
# print(dataset_t1_image._get_full_image())
print(dataset_t1_image.__getitem__(5).elem_idx)
print(dataset_t1_image.elem_per_image)


print("Flair and slice ")

dataset_flair_slice = CapsDataset(
    caps_directory=caps_directory,
    data=sub_ses_flair,
    preprocessing=preprocessing_flair,
    transforms=transforms_slice,
)
dataset_flair_slice.prepare_data(
    n_proc=2
)  # to extract the tensor of the FLAIR image this time

print(dataset_flair_slice)
print(dataset_flair_slice.__len__())
print(dataset_flair_slice._get_meta_data(3))
print(dataset_flair_slice._get_meta_data(80))
# print(dataset_flair_slice._get_full_image())
print(dataset_flair_slice.__getitem__(80).elem_idx)
print(dataset_flair_slice.elem_per_image)


modality_multi_extract = ConcatDataset(
    [
        dataset_t1_image,
        dataset_pet_45_patch,
    ]
)  # 3 train.tsv files as input that must be concatenated; the same goes for the transforms, handle with care
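Usage note: a minimal sketch of how the concatenated dataset could be consumed afterwards, assuming it keeps the same length, indexing and elem_idx interface exercised on the individual CapsDataset objects above (illustrative only, not part of this diff):

print(len(modality_multi_extract))  # total number of elements across the concatenated datasets
sample = modality_multi_extract[0]  # samples keep the per-modality dataset interface
print(sample.elem_idx)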
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/data.py
@@ -2,7 +2,7 @@

 from clinicadl.config.config_utils import get_default_from_config_class as get_default
 from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.dataset.data_config import DataConfig
+from clinicadl.dataset.config.data import DataConfig

 # Data
 baseline = click.option(
34 changes: 17 additions & 17 deletions clinicadl/commandline/modules_options/preprocessing.py
@@ -2,17 +2,17 @@

 from clinicadl.config.config_utils import get_default_from_config_class as get_default
 from clinicadl.config.config_utils import get_type_from_config_class as get_type
-from clinicadl.dataset.config.preprocessing import (
-    CustomPreprocessingConfig,
-    DTIPreprocessingConfig,
-    PETPreprocessingConfig,
-    PreprocessingConfig,
+from clinicadl.dataset.preprocessing import (
+    BasePreprocessing,
+    PreprocessingCustom,
+    PreprocessingDTI,
+    PreprocessingPET,
 )

 tracer = click.option(
     "--tracer",
-    default=get_default("tracer", PETPreprocessingConfig),
-    type=get_type("tracer", PETPreprocessingConfig),
+    default=get_default("tracer", PreprocessingPET),
+    type=get_type("tracer", PreprocessingPET),
     help=(
         "Acquisition label if MODALITY is `pet-linear`. "
         "Name of the tracer used for the PET acquisition (trc-<tracer>). "
@@ -22,8 +22,8 @@
 suvr_reference_region = click.option(
     "-suvr",
     "--suvr_reference_region",
-    default=get_default("suvr_reference_region", PETPreprocessingConfig),
-    type=get_type("suvr_reference_region", PETPreprocessingConfig),
+    default=get_default("suvr_reference_region", PreprocessingPET),
+    type=get_type("suvr_reference_region", PreprocessingPET),
     help=(
         "Regions used for normalization if MODALITY is `pet-linear`. "
         "Intensity normalization using the average PET uptake in reference regions resulting in a standardized uptake "
@@ -34,8 +34,8 @@
 custom_suffix = click.option(
     "-cn",
     "--custom_suffix",
-    default=get_default("custom_suffix", CustomPreprocessingConfig),
-    type=get_type("custom_suffix", CustomPreprocessingConfig),
+    default=get_default("custom_suffix", PreprocessingCustom),
+    type=get_type("custom_suffix", PreprocessingCustom),
     help=(
         "Suffix of output files if MODALITY is `custom`. "
         "Suffix to append to filenames, for instance "
@@ -46,21 +46,21 @@
 dti_measure = click.option(
     "--dti_measure",
     "-dm",
-    type=get_type("dti_measure", DTIPreprocessingConfig),
+    type=get_type("dti_measure", PreprocessingDTI),
     help="Possible DTI measures.",
-    default=get_default("dti_measure", DTIPreprocessingConfig),
+    default=get_default("dti_measure", PreprocessingDTI),
 )
 dti_space = click.option(
     "--dti_space",
     "-ds",
-    type=get_type("dti_space", DTIPreprocessingConfig),
+    type=get_type("dti_space", PreprocessingDTI),
     help="Possible DTI space.",
-    default=get_default("dti_space", DTIPreprocessingConfig),
+    default=get_default("dti_space", PreprocessingDTI),
 )
 preprocessing = click.option(
     "--preprocessing",
-    type=get_type("preprocessing", PreprocessingConfig),
-    default=get_default("preprocessing", PreprocessingConfig),
+    type=get_type("preprocessing", BasePreprocessing),
+    default=get_default("preprocessing", BasePreprocessing),
     required=True,
     help="Extraction used to generate synthetic data.",
     show_default=True,
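For context, these CLI options resolve their defaults and expected types from the preprocessing classes through get_default and get_type. A rough sketch of how such a lookup could work, assuming the preprocessing classes are pydantic v2 models (hypothetical helpers for illustration, not the actual implementation in clinicadl.config.config_utils):

from typing import Any, Type

from pydantic import BaseModel


def get_default(field_name: str, config_class: Type[BaseModel]) -> Any:
    # Hypothetical sketch: return the declared default value of a field of a pydantic model.
    return config_class.model_fields[field_name].default


def get_type(field_name: str, config_class: Type[BaseModel]) -> Any:
    # Hypothetical sketch: return the declared annotation (type) of the same field.
    return config_class.model_fields[field_name].annotation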