Skip to content

Commit

Permalink
Change file_type dict for a class (#627)
Browse files Browse the repository at this point in the history
 Create a FileType class instead of having a dictionary.
  • Loading branch information
camillebrianceau authored Jun 24, 2024
1 parent 5625612 commit 2e5bfa0
Show file tree
Hide file tree
Showing 30 changed files with 329 additions and 406 deletions.
67 changes: 59 additions & 8 deletions clinicadl/caps_dataset/caps_dataset_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
from typing import Optional, Union
from pathlib import Path
from typing import Optional, Tuple, Union

from pydantic import BaseModel, ConfigDict

from clinicadl.caps_dataset.data_config import DataConfig
from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
from clinicadl.caps_dataset.extraction import config as extraction
from clinicadl.caps_dataset.preprocessing import config as preprocessing
from clinicadl.caps_dataset.preprocessing.config import (
CustomPreprocessingConfig,
DTIPreprocessingConfig,
FlairPreprocessingConfig,
PETPreprocessingConfig,
PreprocessingConfig,
T1PreprocessingConfig,
)
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils.clinica_utils import (
FileType,
bids_nii,
dwi_dti,
linear_nii,
pet_linear_nii,
)
from clinicadl.utils.enum import ExtractionMethod, Preprocessing


Expand All @@ -25,15 +40,15 @@ def get_extraction(extract_method: ExtractionMethod):

def get_preprocessing(preprocessing_type: Preprocessing):
if preprocessing_type == Preprocessing.T1_LINEAR:
return preprocessing.T1PreprocessingConfig
return T1PreprocessingConfig
elif preprocessing_type == Preprocessing.PET_LINEAR:
return preprocessing.PETPreprocessingConfig
return PETPreprocessingConfig
elif preprocessing_type == Preprocessing.FLAIR_LINEAR:
return preprocessing.FlairPreprocessingConfig
return FlairPreprocessingConfig
elif preprocessing_type == Preprocessing.CUSTOM:
return preprocessing.CustomPreprocessingConfig
return CustomPreprocessingConfig
elif preprocessing_type == Preprocessing.DWI_DTI:
return preprocessing.DTIPreprocessingConfig
return DTIPreprocessingConfig
else:
raise ValueError(
f"Preprocessing {preprocessing_type.value} is not implemented."
Expand All @@ -52,7 +67,7 @@ class CapsDatasetConfig(BaseModel):
data: DataConfig
dataloader: DataLoaderConfig
extraction: extraction.ExtractionConfig
preprocessing: preprocessing.PreprocessingConfig
preprocessing: PreprocessingConfig
transforms: TransformsConfig

# pydantic config
Expand All @@ -74,3 +89,39 @@ def from_preprocessing_and_extraction_method(
extraction=get_extraction(ExtractionMethod(extraction))(**kwargs),
transforms=TransformsConfig(**kwargs),
)

def compute_folder_and_file_type(
    self, from_bids: Optional[Path] = None
) -> Tuple[str, FileType]:
    """Resolve the modality sub-folder and the file type used to query images.

    Parameters
    ----------
    from_bids : Optional[Path]
        When not ``None``, images are looked up in a raw BIDS tree
        instead of a CAPS derivatives folder.

    Returns
    -------
    Tuple[str, FileType]
        The modality sub-folder name (e.g. ``t1_linear``) and the
        ``FileType`` (glob pattern + description) matching the images.

    Raises
    ------
    NotImplementedError
        If the preprocessing is not supported for the requested source.
    """
    preprocessing = self.preprocessing.preprocessing
    if from_bids is not None:
        # BIDS source: only a raw-modality pattern (or custom suffix) applies.
        if isinstance(self.preprocessing, CustomPreprocessingConfig):
            mod_subfolder = Preprocessing.CUSTOM.value
            file_type = FileType(
                pattern=f"*{self.preprocessing.custom_suffix}",
                description="Custom suffix",
            )
        else:
            mod_subfolder = preprocessing
            file_type = bids_nii(self.preprocessing)

    elif preprocessing not in Preprocessing:
        raise NotImplementedError(
            f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory."
        )
    else:
        # CAPS folder names use underscores where the enum value uses dashes.
        mod_subfolder = preprocessing.value.replace("-", "_")
        if isinstance(
            self.preprocessing, (T1PreprocessingConfig, FlairPreprocessingConfig)
        ):
            file_type = linear_nii(self.preprocessing)
        elif isinstance(self.preprocessing, PETPreprocessingConfig):
            file_type = pet_linear_nii(self.preprocessing)
        elif isinstance(self.preprocessing, DTIPreprocessingConfig):
            file_type = dwi_dti(self.preprocessing)
        elif isinstance(self.preprocessing, CustomPreprocessingConfig):
            file_type = FileType(
                pattern=f"*{self.preprocessing.custom_suffix}",
                description="Custom suffix",
            )
        else:
            # Bug fix: previously this fell through with `file_type` unbound,
            # so the return line raised UnboundLocalError. Fail explicitly.
            raise NotImplementedError(
                f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory."
            )
    return mod_subfolder, file_type
65 changes: 25 additions & 40 deletions clinicadl/caps_dataset/caps_dataset_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, Optional, Tuple
from typing import Optional, Tuple

from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
from clinicadl.caps_dataset.preprocessing.config import (
Expand All @@ -9,63 +9,48 @@
PETPreprocessingConfig,
T1PreprocessingConfig,
)
from clinicadl.utils.enum import LinearModality, Preprocessing
from clinicadl.utils.clinica_utils import (
FileType,
bids_nii,
dwi_dti,
linear_nii,
pet_linear_nii,
)
from clinicadl.utils.enum import Preprocessing


def compute_folder_and_file_type(
config: CapsDatasetConfig, from_bids: Optional[Path] = None
) -> Tuple[str, Dict[str, str]]:
from clinicadl.utils.clinica_utils import (
bids_nii,
dwi_dti,
linear_nii,
pet_linear_nii,
)

) -> Tuple[str, FileType]:
preprocessing = config.preprocessing.preprocessing
if from_bids is not None:
if isinstance(config.preprocessing, CustomPreprocessingConfig):
mod_subfolder = Preprocessing.CUSTOM.value
file_type = {
"pattern": f"*{config.preprocessing.custom_suffix}",
"description": "Custom suffix",
}
file_type = FileType(
pattern=f"*{config.preprocessing.custom_suffix}",
description="Custom suffix",
)
else:
mod_subfolder = preprocessing
file_type = bids_nii(preprocessing)
file_type = bids_nii(config.preprocessing)

elif preprocessing not in Preprocessing:
raise NotImplementedError(
f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory."
)
else:
mod_subfolder = preprocessing.value.replace("-", "_")
if isinstance(config.preprocessing, T1PreprocessingConfig):
file_type = linear_nii(
LinearModality.T1W, config.extraction.use_uncropped_image
)

elif isinstance(config.preprocessing, FlairPreprocessingConfig):
file_type = linear_nii(
LinearModality.FLAIR, config.extraction.use_uncropped_image
)

if isinstance(config.preprocessing, T1PreprocessingConfig) or isinstance(
config.preprocessing, FlairPreprocessingConfig
):
file_type = linear_nii(config.preprocessing)
elif isinstance(config.preprocessing, PETPreprocessingConfig):
file_type = pet_linear_nii(
config.preprocessing.tracer,
config.preprocessing.suvr_reference_region,
config.extraction.use_uncropped_image,
)
file_type = pet_linear_nii(config.preprocessing)
elif isinstance(config.preprocessing, DTIPreprocessingConfig):
file_type = dwi_dti(
config.preprocessing.dti_measure,
config.preprocessing.dti_space,
)
file_type = dwi_dti(config.preprocessing)
elif isinstance(config.preprocessing, CustomPreprocessingConfig):
file_type = {
"pattern": f"*{config.preprocessing.custom_suffix}",
"description": "Custom suffix",
}
# custom_suffix["use_uncropped_image"] = None

file_type = FileType(
pattern=f"*{config.preprocessing.custom_suffix}",
description="Custom suffix",
)
return mod_subfolder, file_type
30 changes: 18 additions & 12 deletions clinicadl/caps_dataset/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,18 @@ def _get_image_path(self, participant: str, session: str, cohort: str) -> Path:

# Try to find .nii.gz file
try:
file_type = self.preprocessing_dict["file_type"]
folder, file_type = self.config.compute_folder_and_file_type()

results = clinicadl_file_reader(
[participant], [session], self.config.data.caps_dict[cohort], file_type
[participant],
[session],
self.config.data.caps_dict[cohort],
file_type.model_dump(),
)
logger.debug(f"clinicadl_file_reader output: {results}")
filepath = Path(results[0][0])
image_filename = filepath.name.replace(".nii.gz", ".pt")

folder, _ = compute_folder_and_file_type(self.config)
image_dir = (
self.config.data.caps_dict[cohort]
/ "subjects"
Expand All @@ -158,10 +161,13 @@ def _get_image_path(self, participant: str, session: str, cohort: str) -> Path:
image_path = image_dir / image_filename
# Try to find .pt file
except ClinicaDLCAPSError:
file_type = self.preprocessing_dict["file_type"]
file_type["pattern"] = file_type["pattern"].replace(".nii.gz", ".pt")
folder, file_type = self.config.compute_folder_and_file_type()
file_type.pattern = file_type.pattern.replace(".nii.gz", ".pt")
results = clinicadl_file_reader(
[participant], [session], self.config.data.caps_dict[cohort], file_type
[participant],
[session],
self.config.data.caps_dict[cohort],
file_type.model_dump(),
)
filepath = results[0]
image_path = Path(filepath[0])
Expand Down Expand Up @@ -225,12 +231,12 @@ def _get_full_image(self) -> torch.Tensor:
image_path = self._get_image_path(participant_id, session_id, cohort)
image = torch.load(image_path)
except IndexError:
file_type = self.preprocessing_dict["file_type"]
file_type = self.config.extraction.file_type
results = clinicadl_file_reader(
[participant_id],
[session_id],
self.config.data.caps_dict[cohort],
file_type,
file_type.model_dump(),
)
image_nii = nib.loadsave.load(results[0])
image_np = image_nii.get_fdata()
Expand Down Expand Up @@ -741,7 +747,7 @@ def return_dataset(

if preprocessing_dict["mode"] == "image":
config.extraction.save_features = preprocessing_dict["prepare_dl"]
config.extraction.use_uncropped_image = preprocessing_dict[
config.preprocessing.use_uncropped_image = preprocessing_dict[
"use_uncropped_image"
]
return CapsDatasetImage(
Expand All @@ -755,7 +761,7 @@ def return_dataset(
config.extraction.patch_size = preprocessing_dict["patch_size"]
config.extraction.stride_size = preprocessing_dict["stride_size"]
config.extraction.save_features = preprocessing_dict["prepare_dl"]
config.extraction.use_uncropped_image = preprocessing_dict[
config.preprocessing.use_uncropped_image = preprocessing_dict[
"use_uncropped_image"
]
return CapsDatasetPatch(
Expand All @@ -770,7 +776,7 @@ def return_dataset(
config.extraction.roi_list = preprocessing_dict["roi_list"]
config.extraction.roi_uncrop_output = preprocessing_dict["uncropped_roi"]
config.extraction.save_features = preprocessing_dict["prepare_dl"]
config.extraction.use_uncropped_image = preprocessing_dict[
config.preprocessing.use_uncropped_image = preprocessing_dict[
"use_uncropped_image"
]
return CapsDatasetRoi(
Expand All @@ -795,7 +801,7 @@ def return_dataset(
else preprocessing_dict["num_slices"]
)
config.extraction.save_features = preprocessing_dict["prepare_dl"]
config.extraction.use_uncropped_image = preprocessing_dict[
config.preprocessing.use_uncropped_image = preprocessing_dict[
"use_uncropped_image"
]
return CapsDatasetSlice(
Expand Down
47 changes: 23 additions & 24 deletions clinicadl/caps_dataset/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def load_data_test(test_path: Path, diagnoses_list, baseline=True, multi_cohort=
# TODO: computes baseline sessions on-the-fly to manager TSV file case

if multi_cohort:
if not test_path.suffix == ".tsv":
if test_path.suffix != ".tsv":
raise ClinicaDLArgumentError(
"If multi_cohort is given, the TSV_DIRECTORY argument should be a path to a TSV file."
)
Expand Down Expand Up @@ -77,6 +77,27 @@ def load_data_test(test_path: Path, diagnoses_list, baseline=True, multi_cohort=
return test_df


def check_test_path(test_path: Path, baseline: bool = True) -> Path:
    """Return the train TSV located next to *test_path*.

    Looks for ``train[_baseline].tsv`` first and falls back to
    ``labels[_baseline].tsv`` when the train file is absent.

    Parameters
    ----------
    test_path : Path
        A path whose parent directory is searched.
    baseline : bool
        If True, search the ``*_baseline.tsv`` variants.

    Returns
    -------
    Path
        The path of the TSV file that was found.

    Raises
    ------
    ClinicaDLTSVError
        If neither the train nor the labels TSV exists.
    """
    suffix = "_baseline" if baseline else ""
    parent = test_path.parent
    train_file = parent / f"train{suffix}.tsv"
    label_file = parent / f"labels{suffix}.tsv"

    if train_file.is_file():
        return train_file
    if label_file.is_file():
        return label_file
    raise ClinicaDLTSVError(
        f"There is no train{suffix}.tsv nor labels{suffix}.tsv in your folder {test_path.parents[0]} "
    )


def load_data_test_single(test_path: Path, diagnoses_list, baseline=True):
if test_path.suffix == ".tsv":
test_df = pd.read_csv(test_path, sep="\t")
Expand All @@ -91,29 +112,7 @@ def load_data_test_single(test_path: Path, diagnoses_list, baseline=True):
)
return test_df

test_df = pd.DataFrame()

if baseline:
if not (test_path.parent / "train_baseline.tsv").is_file():
if not (test_path.parent / "labels_baseline.tsv").is_file():
raise ClinicaDLTSVError(
f"There is no train_baseline.tsv nor labels_baseline.tsv in your folder {test_path.parents[0]} "
)
else:
test_path = test_path.parent / "labels_baseline.tsv"
else:
test_path = test_path.parent / "train_baseline.tsv"
else:
if not (test_path.parent / "train.tsv").is_file():
if not (test_path.parent / "labels.tsv").is_file():
raise ClinicaDLTSVError(
f"There is no train.tsv or labels.tsv in your folder {test_path.parent} "
)
else:
test_path = test_path.parent / "labels.tsv"
else:
test_path = test_path.parent / "train.tsv"

test_path = check_test_path(test_path=test_path, baseline=baseline)
test_df = pd.read_csv(test_path, sep="\t")
test_df = test_df[test_df.diagnosis.isin(diagnoses_list)]
test_df.reset_index(inplace=True, drop=True)
Expand Down
10 changes: 3 additions & 7 deletions clinicadl/caps_dataset/extraction/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from logging import getLogger
from pathlib import Path
from time import time
from typing import List, Optional, Tuple

from pydantic import BaseModel, ConfigDict, field_validator
from pydantic.types import NonNegativeInt

from clinicadl.utils.clinica_utils import FileType
from clinicadl.utils.enum import (
ExtractionMethod,
Preprocessing,
SliceDirection,
SliceMode,
)
Expand All @@ -18,14 +17,11 @@

class ExtractionConfig(BaseModel):
"""
Abstract config class for the validation procedure.
Abstract config class for the Extraction procedure.
"""

use_uncropped_image: bool = False
extract_method: ExtractionMethod
file_type: Optional[str] = None # Optional ??
file_type: Optional[FileType] = None
save_features: bool = False
extract_json: Optional[str] = None

Expand Down
Loading

0 comments on commit 2e5bfa0

Please sign in to comment.