Skip to content

Commit

Permalink
Dataset module for v2 (#685)
Browse files Browse the repository at this point in the history
* new dataset module
  • Loading branch information
camillebrianceau authored Dec 5, 2024
1 parent 0cf01a6 commit c28d4d5
Show file tree
Hide file tree
Showing 64 changed files with 3,522 additions and 2,441 deletions.
167 changes: 167 additions & 0 deletions clinicadl/API/dataset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from pathlib import Path

import torchio.transforms as transforms

from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.dataset.datasets.concat import ConcatDataset
from clinicadl.dataset.preprocessing import (
BasePreprocessing,
PreprocessingFlair,
PreprocessingPET,
PreprocessingT1,
)
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.factory import (
ConvEncoderOptions,
create_network_config,
get_network_from_config,
)
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.transforms.extraction import ROI, Image, Patch, Slice
from clinicadl.transforms.transforms import Transforms

sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv")
sub_ses_pet_45 = Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_18FAV45.tsv"
)
sub_ses_flair = Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_flair.tsv"
)
sub_ses_pet_11 = Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_11CPIB.tsv"
)

caps_directory = Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps"
) # output of clinica pipelines

preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2")
preprocessing_pet_11 = PreprocessingPET(tracer="11CPIB", suvr_reference_region="pons2")

preprocessing_t1 = PreprocessingT1()
preprocessing_flair = PreprocessingFlair()


transforms_patch = Transforms(
object_augmentation=[transforms.Ghosting(2, 1, 0.1, 0.1)],
image_augmentation=[transforms.RandomMotion()],
extraction=Patch(patch_size=60),
image_transforms=[transforms.Blur((0.5, 0.6, 0.3))],
object_transforms=[transforms.RandomMotion()],
) # not mandatory

transforms_slice = Transforms(extraction=Slice())

transforms_roi = Transforms(
object_augmentation=[transforms.Ghosting(2, 1, 0.1, 0.1)],
object_transforms=[transforms.RandomMotion()],
extraction=ROI(
roi_list=["leftHippocampusBox", "rightHippocampusBox"],
roi_mask_location=Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps/masks/tpl-MNI152NLin2009cSym"
),
roi_crop_input=True,
),
)

transforms_image = Transforms(
image_augmentation=[transforms.RandomMotion()],
extraction=Image(),
image_transforms=[transforms.Blur((0.5, 0.6, 0.3))],
)


print("Pet 45 and Patch ")
dataset_pet_45_patch = CapsDataset(
caps_directory=caps_directory,
data=sub_ses_pet_45,
preprocessing=preprocessing_pet_45,
transforms=transforms_patch,
)
dataset_pet_45_patch.prepare_data(n_proc=2)

print(dataset_pet_45_patch)
print(dataset_pet_45_patch.__len__())
print(dataset_pet_45_patch._get_meta_data(3))
print(dataset_pet_45_patch._get_meta_data(80))
# print(dataset_pet_45_patch._get_full_image())
print(dataset_pet_45_patch.__getitem__(80).elem_idx)
print(dataset_pet_45_patch.elem_per_image)

dataset_pet_45_patch.caps_reader._write_caps_json(
transforms_patch, preprocessing_pet_45, sub_ses_pet_45, name="tfsdklsqfh"
)


print("Pet 11 and ROI ")

dataset_pet_11_roi = CapsDataset(
caps_directory=caps_directory,
data=sub_ses_pet_11,
preprocessing=preprocessing_pet_11,
transforms=transforms_roi,
)
dataset_pet_11_roi.prepare_data(
n_proc=2
) # to extract the tensor of the PET file this time

print(dataset_pet_11_roi)
print(dataset_pet_11_roi.__len__())
print(dataset_pet_11_roi._get_meta_data(0))
print(dataset_pet_11_roi._get_meta_data(1))
# print(dataset_pet_11_roi._get_full_image())
print(dataset_pet_11_roi.__getitem__(1).elem_idx)
print(dataset_pet_11_roi.elem_per_image)


print("T1 and image ")

dataset_t1_image = CapsDataset(
caps_directory=caps_directory,
data=sub_ses_t1,
preprocessing=preprocessing_t1,
transforms=transforms_image,
)
dataset_t1_image.prepare_data(
n_proc=2
) # to extract the tensor of the PET file this time

print(dataset_t1_image)
print(dataset_t1_image.__len__())
print(dataset_t1_image._get_meta_data(3))
print(dataset_t1_image._get_meta_data(5))
# print(dataset_t1_image._get_full_image())
print(dataset_t1_image.__getitem__(5).elem_idx)
print(dataset_t1_image.elem_per_image)


print("Flair and slice ")

dataset_flair_slice = CapsDataset(
caps_directory=caps_directory,
data=sub_ses_flair,
preprocessing=preprocessing_flair,
transforms=transforms_slice,
)
dataset_flair_slice.prepare_data(
n_proc=2
) # to extract the tensor of the PET file this time

print(dataset_flair_slice)
print(dataset_flair_slice.__len__())
print(dataset_flair_slice._get_meta_data(3))
print(dataset_flair_slice._get_meta_data(80))
# print(dataset_flair_slice._get_full_image())
print(dataset_flair_slice.__getitem__(80).elem_idx)
print(dataset_flair_slice.elem_per_image)


lity_multi_extract = ConcatDataset(
[
dataset_t1,
dataset_pet,
]
) # 3 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.dataset.data_config import DataConfig
from clinicadl.dataset.config.data import DataConfig

# Data
baseline = click.option(
Expand Down
34 changes: 17 additions & 17 deletions clinicadl/commandline/modules_options/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@

from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.dataset.config.preprocessing import (
CustomPreprocessingConfig,
DTIPreprocessingConfig,
PETPreprocessingConfig,
PreprocessingConfig,
from clinicadl.dataset.preprocessing import (
BasePreprocessing,
PreprocessingCustom,
PreprocessingDTI,
PreprocessingPET,
)

tracer = click.option(
"--tracer",
default=get_default("tracer", PETPreprocessingConfig),
type=get_type("tracer", PETPreprocessingConfig),
default=get_default("tracer", PreprocessingPET),
type=get_type("tracer", PreprocessingPET),
help=(
"Acquisition label if MODALITY is `pet-linear`. "
"Name of the tracer used for the PET acquisition (trc-<tracer>). "
Expand All @@ -22,8 +22,8 @@
suvr_reference_region = click.option(
"-suvr",
"--suvr_reference_region",
default=get_default("suvr_reference_region", PETPreprocessingConfig),
type=get_type("suvr_reference_region", PETPreprocessingConfig),
default=get_default("suvr_reference_region", PreprocessingPET),
type=get_type("suvr_reference_region", PreprocessingPET),
help=(
"Regions used for normalization if MODALITY is `pet-linear`. "
"Intensity normalization using the average PET uptake in reference regions resulting in a standardized uptake "
Expand All @@ -34,8 +34,8 @@
custom_suffix = click.option(
"-cn",
"--custom_suffix",
default=get_default("custom_suffix", CustomPreprocessingConfig),
type=get_type("custom_suffix", CustomPreprocessingConfig),
default=get_default("custom_suffix", PreprocessingCustom),
type=get_type("custom_suffix", PreprocessingCustom),
help=(
"Suffix of output files if MODALITY is `custom`. "
"Suffix to append to filenames, for instance "
Expand All @@ -46,21 +46,21 @@
dti_measure = click.option(
"--dti_measure",
"-dm",
type=get_type("dti_measure", DTIPreprocessingConfig),
type=get_type("dti_measure", PreprocessingDTI),
help="Possible DTI measures.",
default=get_default("dti_measure", DTIPreprocessingConfig),
default=get_default("dti_measure", PreprocessingDTI),
)
dti_space = click.option(
"--dti_space",
"-ds",
type=get_type("dti_space", DTIPreprocessingConfig),
type=get_type("dti_space", PreprocessingDTI),
help="Possible DTI space.",
default=get_default("dti_space", DTIPreprocessingConfig),
default=get_default("dti_space", PreprocessingDTI),
)
preprocessing = click.option(
"--preprocessing",
type=get_type("preprocessing", PreprocessingConfig),
default=get_default("preprocessing", PreprocessingConfig),
type=get_type("preprocessing", BasePreprocessing),
default=get_default("preprocessing", BasePreprocessing),
required=True,
help="Extraction used to generate synthetic data.",
show_default=True,
Expand Down
Loading

0 comments on commit c28d4d5

Please sign in to comment.