Commit 3d81738
change class CapsDataset
camillebrianceau committed Jun 5, 2024
1 parent 8d9d30d commit 3d81738
Showing 6 changed files with 79 additions and 58 deletions.
26 changes: 16 additions & 10 deletions clinicadl/caps_dataset/caps_dataset_config.py
@@ -67,6 +67,15 @@ def get_generate(generate: Union[str, GenerateType]):
     raise ValueError(f"GenerateType {generate.value} is not available.")


+class CapsDatasetBase(BaseModel):
+    data: DataConfig
+    modality: modality.ModalityConfig
+    preprocessing: preprocessing.PreprocessingConfig
+
+    # pydantic config
+    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
+
+
 def create_caps_dataset_config(
     preprocessing: Union[str, Preprocessing], extract: Union[str, ExtractionMethod]
 ):
@@ -80,14 +89,11 @@ def create_caps_dataset_config(
     except ClinicaDLArgumentError:
         print("Invalid preprocessing configuration")

-    class CapsDatasetBase(
-        DataConfig,
-        get_modality(preprocessing_type),
-        get_preprocessing(extract_method),
-    ):
-        # pydantic config
-        model_config = ConfigDict(
-            validate_assignment=True, arbitrary_types_allowed=True
-        )
+    class CapsDatasetConfig(CapsDatasetBase):
+        modality: get_modality(preprocessing_type)
+        preprocessing: get_preprocessing(extract_method)
+
+        def __init__(self, **kwargs):
+            super().__init__(data=kwargs, modality=kwargs, preprocessing=kwargs)

-    return CapsDatasetBase
+    return CapsDatasetConfig
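In short, the factory no longer builds the config class by multiple inheritance from DataConfig and the dynamically selected modality/preprocessing classes; CapsDatasetBase now composes the three sub-configs as nested pydantic fields, and the generated subclass fans the flat CLI kwargs out to each sub-config. A minimal runnable sketch of the pattern, using hypothetical simplified stand-ins for the sub-configs (only the structure follows the diff, not the real field sets):

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict


# Hypothetical, simplified stand-ins for clinicadl's DataConfig,
# ModalityConfig and PreprocessingConfig.
class DataConfig(BaseModel):
    caps_directory: str = ""
    n_subjects: int = 1


class ModalityConfig(BaseModel):
    tracer: Optional[str] = None


class PreprocessingConfig(BaseModel):
    use_uncropped_image: bool = False


class CapsDatasetBase(BaseModel):
    data: DataConfig
    modality: ModalityConfig
    preprocessing: PreprocessingConfig

    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)


class CapsDatasetConfig(CapsDatasetBase):
    def __init__(self, **kwargs):
        # Pass the same flat kwargs to every sub-config; each pydantic model
        # keeps the fields it declares and ignores the rest (pydantic v2's
        # default extra="ignore" behaviour).
        super().__init__(data=kwargs, modality=kwargs, preprocessing=kwargs)


config = CapsDatasetConfig(caps_directory="/data/caps", n_subjects=2, tracer="18FFDG")
print(config.data.caps_directory)  # "/data/caps" -- fields are now nested
print(config.modality.tracer)      # "18FFDG"
```

This is why the generate pipelines below switch from flat access such as caps_config.caps_directory to nested paths like caps_config.data.caps_directory.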
13 changes: 8 additions & 5 deletions clinicadl/commandline/pipelines/generate/artifacts/cli.py
@@ -70,14 +70,14 @@ def cli(generated_caps_directory, n_proc, **kwargs):
     commandline_to_json(
         {
             "output_dir": generated_caps_directory,
-            "caps_dir": caps_config.caps_directory,
-            "preprocessing": caps_config.preprocessing,
+            "caps_dir": caps_config.data.caps_directory,
+            "preprocessing": caps_config.preprocessing.preprocessing.value,
         }
     )

     # Read DataFrame
     data_df = load_and_check_tsv(
-        caps_config.data_tsv, caps_config.caps_dict, generated_caps_directory
+        caps_config.data.data_tsv, caps_config.data.caps_dict, generated_caps_directory
     )
     # data_df = extract_baseline(data_df)
     # if caps_config.n_subjects > len(data_df):
@@ -98,7 +98,10 @@ def create_artifacts_image(data_idx: int) -> pd.DataFrame:
     cohort = data_df.at[data_idx, "cohort"]
     image_path = Path(
         clinicadl_file_reader(
-            [participant_id], [session_id], caps_config.caps_dict[cohort], file_type
+            [participant_id],
+            [session_id],
+            caps_config.data.caps_dict[cohort],
+            file_type,
         )[0][0]
     )
     from clinicadl.utils.read_utils import get_info_from_filename
@@ -115,7 +118,7 @@ def create_artifacts_image(data_idx: int) -> pd.DataFrame:
         / "subjects"
         / subject_name
         / session_name
-        / caps_config.preprocessing.value
+        / caps_config.preprocessing.preprocessing.value
     )
     artif_image_nii_dir.mkdir(parents=True, exist_ok=True)

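All four generate pipelines (artifacts, hypometabolic, random, trivial) are updated the same way: flat attributes move under the data sub-config, and the preprocessing enum now sits one level deeper, which is why the path reads preprocessing.preprocessing.value. A small runnable illustration of the doubled name, with hypothetical stand-in classes (the enum value is an assumption):

```python
from enum import Enum

from pydantic import BaseModel


class Preprocessing(str, Enum):  # stand-in for clinicadl's Preprocessing enum
    T1_LINEAR = "t1-linear"


class PreprocessingConfig(BaseModel):
    # The sub-config itself has a field named `preprocessing` holding the enum.
    preprocessing: Preprocessing = Preprocessing.T1_LINEAR
    use_uncropped_image: bool = False


class CapsConfig(BaseModel):  # stand-in for CapsDatasetConfig
    preprocessing: PreprocessingConfig = PreprocessingConfig()


caps_config = CapsConfig()
# Before the refactor the access path was caps_config.preprocessing.value;
# now the enum lives on the nested sub-config, hence the doubled name:
print(caps_config.preprocessing.preprocessing.value)  # "t1-linear"
```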
22 changes: 11 additions & 11 deletions clinicadl/commandline/pipelines/generate/hypometabolic/cli.py
@@ -64,9 +64,9 @@ def cli(generated_caps_directory, n_proc, **kwargs):
     commandline_to_json(
         {
             "output_dir": generated_caps_directory,
-            "caps_dir": caps_config.caps_directory,
-            "preprocessing": caps_config.preprocessing.value,
-            "n_subjects": caps_config.n_subjects,
+            "caps_dir": caps_config.data.caps_directory,
+            "preprocessing": caps_config.preprocessing.preprocessing.value,
+            "n_subjects": caps_config.data.n_subjects,
             "n_proc": n_proc,
             "pathology": generate_config.pathology.value,
             "anomaly_degree": generate_config.anomaly_degree,
@@ -75,12 +75,12 @@ def cli(generated_caps_directory, n_proc, **kwargs):

     # Read DataFrame
     data_df = load_and_check_tsv(
-        caps_config.data_tsv, caps_config.caps_dict, generated_caps_directory
+        caps_config.data.data_tsv, caps_config.data.caps_dict, generated_caps_directory
     )
     data_df = extract_baseline(data_df)
-    if caps_config.n_subjects > len(data_df):
+    if caps_config.data.n_subjects > len(data_df):
         raise IndexError(
-            f"The number of subjects {caps_config.n_subjects} cannot be higher "
+            f"The number of subjects {caps_config.data.n_subjects} cannot be higher "
             f"than the number of subjects in the baseline dataset of size {len(data_df)}"
             f"Please add the '--n_subjects' option and re-run the command."
         )
@@ -97,10 +97,10 @@ def cli(generated_caps_directory, n_proc, **kwargs):

     # Output tsv file
     participants = [
-        data_df.at[i, "participant_id"] for i in range(caps_config.n_subjects)
+        data_df.at[i, "participant_id"] for i in range(caps_config.data.n_subjects)
     ]
-    sessions = [data_df.at[i, "session_id"] for i in range(caps_config.n_subjects)]
-    cohort = caps_config.caps_directory
+    sessions = [data_df.at[i, "session_id"] for i in range(caps_config.data.n_subjects)]
+    cohort = caps_config.data.caps_directory

     images_paths = clinicadl_file_reader(participants, sessions, cohort, file_type)[0]
     image_nii = nib.loadsave.load(images_paths[0])
@@ -124,7 +124,7 @@ def generate_hypometabolic_image(
         / "subjects"
         / participants[subject_id]
         / sessions[subject_id]
-        / caps_config.preprocessing.value
+        / caps_config.preprocessing.preprocessing.value
     )
     hypo_image_nii_filename = f"{input_filename}pat-{generate_config.pathology.value}_deg-{int(generate_config.anomaly_degree)}_pet.nii.gz"
     hypo_image_nii_dir.mkdir(parents=True, exist_ok=True)
@@ -146,7 +146,7 @@ def generate_hypometabolic_image(

     results_list = Parallel(n_jobs=n_proc)(
         delayed(generate_hypometabolic_image)(subject_id)
-        for subject_id in range(caps_config.n_subjects)
+        for subject_id in range(caps_config.data.n_subjects)
     )
     output_df = pd.DataFrame()
     for result_df in results_list:
26 changes: 13 additions & 13 deletions clinicadl/commandline/pipelines/generate/random/cli.py
@@ -59,9 +59,9 @@ def cli(generated_caps_directory, n_proc, **kwargs):
     commandline_to_json(
         {
             "output_dir": generated_caps_directory,
-            "caps_dir": caps_config.caps_directory,
-            "preprocessing": caps_config.preprocessing.value,
-            "n_subjects": caps_config.n_subjects,
+            "caps_dir": caps_config.data.caps_directory,
+            "preprocessing": caps_config.preprocessing.preprocessing.value,
+            "n_subjects": caps_config.data.n_subjects,
             "n_proc": n_proc,
             "mean": generate_config.mean,
             "sigma": generate_config.sigma,
@@ -73,15 +73,15 @@ def cli(generated_caps_directory, n_proc, **kwargs):

     # Read DataFrame
     data_df = load_and_check_tsv(
-        caps_config.data_tsv,
-        caps_config.caps_dict,
+        caps_config.data.data_tsv,
+        caps_config.data.caps_dict,
         generated_caps_directory,
     )

     data_df = extract_baseline(data_df)
-    if caps_config.n_subjects > len(data_df):
+    if caps_config.data.n_subjects > len(data_df):
         raise IndexError(
-            f"The number of subjects {caps_config.n_subjects} cannot be higher "
+            f"The number of subjects {caps_config.data.n_subjects} cannot be higher "
             f"than the number of subjects in the baseline dataset of size {len(data_df)}"
         )

@@ -96,19 +96,19 @@ def cli(generated_caps_directory, n_proc, **kwargs):
     session_id = data_df.at[0, "session_id"]
     cohort = data_df.at[0, "cohort"]
     image_paths = clinicadl_file_reader(
-        [participant_id], [session_id], caps_config.caps_dict[cohort], file_type
+        [participant_id], [session_id], caps_config.data.caps_dict[cohort], file_type
     )
     image_nii = nib.loadsave.load(image_paths[0][0])
     # assert isinstance(image_nii, nib.nifti1.Nifti1Image)
     image = image_nii.get_fdata()
     output_df = pd.DataFrame(
         {
             "participant_id": [
-                f"sub-RAND{i}" for i in range(2 * caps_config.n_subjects)
+                f"sub-RAND{i}" for i in range(2 * caps_config.data.n_subjects)
             ],
-            "session_id": [SESSION_ID] * 2 * caps_config.n_subjects,
-            "diagnosis": ["AD"] * caps_config.n_subjects
-            + ["CN"] * caps_config.n_subjects,
+            "session_id": [SESSION_ID] * 2 * caps_config.data.n_subjects,
+            "diagnosis": ["AD"] * caps_config.data.n_subjects
+            + ["CN"] * caps_config.data.n_subjects,
             "age_bl": AGE_BL_DEFAULT,
             "sex": SEX_DEFAULT,
         }
@@ -142,7 +142,7 @@ def create_random_image(subject_id: int) -> None:

     Parallel(n_jobs=n_proc)(
         delayed(create_random_image)(subject_id)
-        for subject_id in range(2 * caps_config.n_subjects)
+        for subject_id in range(2 * caps_config.data.n_subjects)
     )
     write_missing_mods(generated_caps_directory, output_df)
     logger.info(f"Random dataset was generated at {generated_caps_directory}")
27 changes: 15 additions & 12 deletions clinicadl/commandline/pipelines/generate/trivial/cli.py
@@ -62,20 +62,20 @@ def cli(generated_caps_directory, n_proc, **kwargs):
     commandline_to_json(
         {
             "output_dir": generated_caps_directory,
-            "caps_dir": caps_config.caps_directory,
-            "preprocessing": caps_config.preprocessing,
+            "caps_dir": caps_config.data.caps_directory,
+            "preprocessing": caps_config.preprocessing.preprocessing.value,
         }
     )
     # Read DataFrame
     data_df = load_and_check_tsv(
-        caps_config.data_tsv,
-        caps_config.caps_dict,
+        caps_config.data.data_tsv,
+        caps_config.data.caps_dict,
         generated_caps_directory,
     )
     data_df = extract_baseline(data_df)
-    if caps_config.n_subjects > len(data_df):
+    if caps_config.data.n_subjects > len(data_df):
         raise IndexError(
-            f"The number of subjects {caps_config.n_subjects} cannot be higher "
+            f"The number of subjects {caps_config.data.n_subjects} cannot be higher "
             f"than the number of subjects in the baseline dataset of size {len(data_df)}"
         )

@@ -97,7 +97,10 @@ def create_trivial_image(subject_id: int) -> pd.DataFrame:
     cohort = data_df.at[data_idx, "cohort"]
     image_path = Path(
         clinicadl_file_reader(
-            [participant_id], [session_id], caps_config.caps_dict[cohort], file_type
+            [participant_id],
+            [session_id],
+            caps_config.data.caps_dict[cohort],
+            file_type,
         )[0][0]
     )

@@ -113,13 +116,13 @@ def create_trivial_image(subject_id: int) -> pd.DataFrame:
         / "subjects"
         / f"sub-TRIV{subject_id}"
         / session_id
-        / caps_config.preprocessing.value
+        / caps_config.preprocessing.preprocessing.value
     )
     trivial_image_nii_dir.mkdir(parents=True, exist_ok=True)

-    if caps_config.mask_path is None:
-        caps_config.mask_path = get_mask_path()
-    path_to_mask = caps_config.mask_path / f"mask-{label + 1}.nii"
+    if caps_config.data.mask_path is None:
+        caps_config.data.mask_path = get_mask_path()
+    path_to_mask = caps_config.data.mask_path / f"mask-{label + 1}.nii"
     print(path_to_mask)
     if path_to_mask.is_file():
         atlas_to_mask = nib.loadsave.load(path_to_mask).get_fdata()
@@ -153,7 +156,7 @@ def create_trivial_image(subject_id: int) -> pd.DataFrame:

     results_df = Parallel(n_jobs=n_proc)(
         delayed(create_trivial_image)(subject_id)
-        for subject_id in range(2 * caps_config.n_subjects)
+        for subject_id in range(2 * caps_config.data.n_subjects)
     )
     output_df = pd.DataFrame()
     for result in results_df:
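Note that the trivial pipeline also mutates the nested config in place (caps_config.data.mask_path = get_mask_path()); that assignment is re-validated by pydantic because the models set validate_assignment=True. A small sketch of that behaviour with hypothetical stand-ins (get_mask_path is replaced by a fixed path here):

```python
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, ConfigDict


class DataConfig(BaseModel):  # stand-in; only mask_path is modelled
    mask_path: Optional[Path] = None

    model_config = ConfigDict(validate_assignment=True)


class CapsConfig(BaseModel):  # stand-in for CapsDatasetConfig
    data: DataConfig = DataConfig()

    model_config = ConfigDict(validate_assignment=True)


caps_config = CapsConfig()
if caps_config.data.mask_path is None:
    # Assigning to a nested field is validated thanks to validate_assignment=True.
    caps_config.data.mask_path = Path("/opt/masks")
print(caps_config.data.mask_path / "mask-1.nii")  # /opt/masks/mask-1.nii
```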
23 changes: 16 additions & 7 deletions clinicadl/generate/generate_utils.py
@@ -11,7 +11,9 @@
 from scipy.ndimage import gaussian_filter
 from skimage.draw import ellipse

+from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetBase
 from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv
+from clinicadl.config.config.modality import PETModalityConfig
 from clinicadl.utils.clinica_utils import (
     create_subs_sess_list,
     linear_nii,
@@ -32,22 +34,29 @@
 )


-def find_file_type(config: BaseModel) -> Dict[str, str]:
+def find_file_type(config: CapsDatasetBase) -> Dict[str, str]:
     # preprocessing = Preprocessing(preprocessing)
-    if config.preprocessing == Preprocessing.T1_LINEAR:
-        file_type = linear_nii(LinearModality.T1W, config.use_uncropped_image)
-    elif config.preprocessing == Preprocessing.PET_LINEAR:
-        if config.tracer is None or config.suvr_reference_region is None:
+    if config.preprocessing.preprocessing == Preprocessing.T1_LINEAR:
+        file_type = linear_nii(
+            LinearModality.T1W, config.preprocessing.use_uncropped_image
+        )
+    elif isinstance(config.modality, PETModalityConfig):
+        if (
+            config.modality.tracer is None
+            or config.modality.suvr_reference_region is None
+        ):
             raise ClinicaDLArgumentError(
                 "`tracer` and `suvr_reference_region` must be defined "
                 "when using `pet-linear` preprocessing."
             )
         file_type = pet_linear_nii(
-            config.tracer, config.suvr_reference_region, config.use_uncropped_image
+            config.modality.tracer,
+            config.modality.suvr_reference_region,
+            config.preprocessing.use_uncropped_image,
         )
     else:
         raise NotImplementedError(
-            f"Generation of synthetic data is not implemented for preprocessing {config.preprocessing.value}"
+            f"Generation of synthetic data is not implemented for preprocessing {config.preprocessing.preprocessing.value}"
         )

     return file_type
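find_file_type now takes the composed CapsDatasetBase and dispatches the PET branch on the type of the nested modality config (isinstance(config.modality, PETModalityConfig)) rather than comparing the preprocessing enum a second time. A minimal runnable sketch of that type-based dispatch, with hypothetical stand-in models (the error type is simplified to ValueError):

```python
from typing import Optional

from pydantic import BaseModel


# Hypothetical stand-ins for clinicadl's modality config hierarchy.
class ModalityConfig(BaseModel):
    pass


class PETModalityConfig(ModalityConfig):
    tracer: Optional[str] = None
    suvr_reference_region: Optional[str] = None


def describe_modality(modality: ModalityConfig) -> str:
    # Type-based dispatch: PET-specific fields are only read after the
    # isinstance check guarantees the sub-config actually declares them.
    if isinstance(modality, PETModalityConfig):
        if modality.tracer is None or modality.suvr_reference_region is None:
            raise ValueError(
                "`tracer` and `suvr_reference_region` must be defined "
                "when using `pet-linear` preprocessing."
            )
        return f"pet-linear ({modality.tracer}, {modality.suvr_reference_region})"
    return "non-PET modality"


print(describe_modality(PETModalityConfig(tracer="18FFDG", suvr_reference_region="pons")))
```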
