first try after rebase and new config

camillebrianceau committed Jun 5, 2024
1 parent a66e380 commit 1a2135b

Showing 33 changed files with 423 additions and 756 deletions.
2 changes: 0 additions & 2 deletions clinicadl/__init__.py
@@ -1,7 +1,5 @@
from importlib.metadata import version

from .utils.maps_manager import MapsManager

__all__ = ["__version__", "MapsManager"]

__version__ = version("clinicadl")
62 changes: 19 additions & 43 deletions clinicadl/caps_dataset/caps_dataset_config.py
@@ -1,23 +1,20 @@
import abc
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Union

import pandas as pd
from pydantic import BaseModel, computed_field
from pydantic import BaseModel, ConfigDict

from clinicadl.caps_dataset.data_config import ConfigDict, DataConfig
from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv
from clinicadl.caps_dataset.data_config import DataConfig
from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
from clinicadl.config.config import modality
from clinicadl.config.config.modality import (
CustomModalityConfig,
DTIModalityConfig,
FlairModalityConfig,
ModalityConfig,
PETModalityConfig,
T1ModalityConfig,
)
from clinicadl.generate import generate_config as generate_type
from clinicadl.generate.generate_config import GenerateConfig
from clinicadl.preprocessing import config as preprocessing
from clinicadl.utils.enum import ExtractionMethod, GenerateType, Preprocessing
from clinicadl.utils.exceptions import (
ClinicaDLArgumentError,
ClinicaDLTSVError,
DownloadError,
)


def get_preprocessing(extract_method: ExtractionMethod):
@@ -38,15 +35,15 @@ def get_modality(preprocessing: Preprocessing):
preprocessing == Preprocessing.T1_EXTENSIVE
or preprocessing == Preprocessing.T1_LINEAR
):
return modality.T1ModalityConfig
return T1ModalityConfig
elif preprocessing == Preprocessing.PET_LINEAR:
return modality.PETModalityConfig
return PETModalityConfig
elif preprocessing == Preprocessing.FLAIR_LINEAR:
return modality.FlairModalityConfig
return FlairModalityConfig
elif preprocessing == Preprocessing.CUSTOM:
return modality.CustomModalityConfig
return CustomModalityConfig
elif preprocessing == Preprocessing.DWI_DTI:
return modality.DTIModalityConfig
return DTIModalityConfig
else:
raise ValueError(f"Preprocessing {preprocessing.value} is not implemented.")
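
A minimal sketch of how this dispatch can be used (names come from the imports in this file; constructing the config with no arguments is an assumption, since some modality configs may require extra fields):

    from clinicadl.utils.enum import Preprocessing

    config_cls = get_modality(Preprocessing.T1_LINEAR)  # returns the T1ModalityConfig class
    modality_config = config_cls()  # pydantic model; assumes no extra required fields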

@@ -69,7 +66,8 @@ def get_generate(generate: Union[str, GenerateType]):

class CapsDatasetBase(BaseModel):
data: DataConfig
modality: modality.ModalityConfig
dataloader: DataLoaderConfig
modality: ModalityConfig
preprocessing: preprocessing.PreprocessingConfig

# pydantic config
@@ -86,29 +84,7 @@ def from_preprocessing_and_extraction_method(
):
return cls(
data=DataConfig(**kwargs),
dataloader=DataLoaderConfig(**kwargs),
modality=get_modality(Preprocessing(preprocessing_type))(**kwargs),
preprocessing=get_preprocessing(ExtractionMethod(extraction))(**kwargs),
)
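
A minimal sketch of building a config through this classmethod (assuming the concrete CapsDatasetConfig subclass imported in caps_dataset_utils.py below is the class that carries it; caps_directory is a hypothetical pass-through keyword forwarded to DataConfig):

    from pathlib import Path

    from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig

    config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
        preprocessing_type="t1-linear",  # parsed into Preprocessing
        extraction="image",              # parsed into ExtractionMethod
        caps_directory=Path("/path/to/caps"),  # hypothetical; forwarded to DataConfig
    )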


# def create_caps_dataset_config(
# preprocessing: Union[str, Preprocessing], extract: Union[str, ExtractionMethod]
# ):
# try:
# preprocessing_type = Preprocessing(preprocessing)
# except ClinicaDLArgumentError:
# print("Invalid preprocessing configuration")

# try:
# extract_method = ExtractionMethod(extract)
# except ClinicaDLArgumentError:
# print("Invalid preprocessing configuration")

# class CapsDatasetConfig(CapsDatasetBase):
# modality: get_modality(preprocessing_type)
# preprocessing: get_preprocessing(extract_method)

# def __init__(self, **kwargs):
# super().__init__(data=kwargs, modality=kwargs, preprocessing=kwargs)

# return CapsDatasetConfig
68 changes: 68 additions & 0 deletions clinicadl/caps_dataset/caps_dataset_utils.py
@@ -0,0 +1,68 @@
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

from pydantic import BaseModel, ConfigDict

from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
from clinicadl.utils.enum import LinearModality, Preprocessing


def compute_folder_and_file_type(
config: CapsDatasetConfig, from_bids: Optional[Path] = None
) -> Tuple[str, Dict[str, str]]:
from clinicadl.utils.clinica_utils import (
bids_nii,
dwi_dti,
linear_nii,
pet_linear_nii,
)

preprocessing = Preprocessing(
config.preprocessing.preprocessing
) # replace("-", "_")
if from_bids is not None:
if preprocessing == Preprocessing.CUSTOM:
mod_subfolder = Preprocessing.CUSTOM.value
file_type = {
"pattern": f"*{config.modality.custom_suffix}",
"description": "Custom suffix",
}
else:
mod_subfolder = preprocessing
file_type = bids_nii(preprocessing)

elif preprocessing not in Preprocessing:
raise NotImplementedError(
f"Extraction of preprocessing {config.preprocessing.preprocessing.value} is not implemented from CAPS directory."
)
else:
mod_subfolder = preprocessing.value.replace("-", "_")
if preprocessing == Preprocessing.T1_LINEAR:
file_type = linear_nii(
LinearModality.T1W, config.preprocessing.use_uncropped_image
)

elif preprocessing == Preprocessing.FLAIR_LINEAR:
file_type = linear_nii(
LinearModality.FLAIR, config.preprocessing.use_uncropped_image
)

elif preprocessing == Preprocessing.PET_LINEAR:
file_type = pet_linear_nii(
config.modality.tracer,
config.modality.suvr_reference_region,
config.preprocessing.use_uncropped_image,
)
elif preprocessing == Preprocessing.DWI_DTI:
file_type = dwi_dti(
config.modality.dti_measure,
config.modality.dti_space,
)
elif preprocessing == Preprocessing.CUSTOM:
file_type = {
"pattern": f"*{config.modality.custom_suffix}",
"description": "Custom suffix",
}
# custom_suffix["use_uncropped_image"] = None

return mod_subfolder, file_type
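
A sketch of the expected call for a t1-linear config (the dict shape follows the "pattern"/"description" keys used in the custom branch above; the exact output of linear_nii is an assumption):

    mod_subfolder, file_type = compute_folder_and_file_type(config)
    # mod_subfolder == "t1_linear" (hyphens replaced by underscores)
    # file_type is a dict like {"pattern": ..., "description": ...}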
117 changes: 92 additions & 25 deletions clinicadl/caps_dataset/data.py
@@ -10,6 +10,7 @@
import torch
from torch.utils.data import Dataset

from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type
from clinicadl.prepare_data.prepare_data_config import (
PrepareDataConfig,
PrepareDataImageConfig,
@@ -19,7 +20,6 @@
)
from clinicadl.prepare_data.prepare_data_utils import (
compute_discarded_slices,
compute_folder_and_file_type,
extract_patch_path,
extract_patch_tensor,
extract_roi_path,
@@ -142,30 +142,6 @@ def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]:
def __len__(self) -> int:
return len(self.df) * self.elem_per_image

@staticmethod
def create_caps_dict(caps_directory: Path, multi_cohort: bool) -> Dict[str, Path]:
from clinicadl.utils.clinica_utils import check_caps_folder

if multi_cohort:
if not caps_directory.suffix == ".tsv":
raise ClinicaDLArgumentError(
"If multi_cohort is True, the CAPS_DIRECTORY argument should be a path to a TSV file."
)
else:
caps_df = pd.read_csv(caps_directory, sep="\t")
check_multi_cohort_tsv(caps_df, "CAPS")
caps_dict = dict()
for idx in range(len(caps_df)):
cohort = caps_df.loc[idx, "cohort"]
caps_path = Path(caps_df.loc[idx, "path"])
check_caps_folder(caps_path)
caps_dict[cohort] = caps_path
else:
check_caps_folder(caps_directory)
caps_dict = {"single": caps_directory}

return caps_dict

def _get_image_path(self, participant: str, session: str, cohort: str) -> Path:
"""
Gets the path to the tensor image (*.pt)
@@ -804,3 +780,94 @@ def num_elem_per_image(self):
- self.discarded_slices[0]
- self.discarded_slices[1]
)


def return_dataset(
input_dir: Path,
data_df: pd.DataFrame,
preprocessing_dict: Dict[str, Any],
all_transformations: Optional[Callable],
label: str = None,
label_code: Dict[str, int] = None,
train_transformations: Optional[Callable] = None,
cnn_index: int = None,
label_presence: bool = True,
multi_cohort: bool = False,
) -> CapsDataset:
"""
Return appropriate Dataset according to given options.
Args:
input_dir: path to a directory containing a CAPS structure.
data_df: List subjects, sessions and diagnoses.
preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data.
train_transformations: Optional transform to be applied during training only.
all_transformations: Optional transform to be applied during training and evaluation.
label: Name of the column in data_df containing the label.
label_code: label code that links the output node number to label value.
cnn_index: Index of the CNN in a multi-CNN paradigm (optional).
label_presence: If True the diagnosis will be extracted from the given DataFrame.
multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths.
Returns:
the corresponding dataset.
"""
if cnn_index is not None and preprocessing_dict["mode"] == "image":
raise NotImplementedError(
f"Multi-CNN is not implemented for {preprocessing_dict['mode']} mode."
)

if preprocessing_dict["mode"] == "image":
return CapsDatasetImage(
input_dir,
data_df,
preprocessing_dict,
train_transformations=train_transformations,
all_transformations=all_transformations,
label_presence=label_presence,
label=label,
label_code=label_code,
multi_cohort=multi_cohort,
)
elif preprocessing_dict["mode"] == "patch":
return CapsDatasetPatch(
input_dir,
data_df,
preprocessing_dict,
train_transformations=train_transformations,
all_transformations=all_transformations,
patch_index=cnn_index,
label_presence=label_presence,
label=label,
label_code=label_code,
multi_cohort=multi_cohort,
)
elif preprocessing_dict["mode"] == "roi":
return CapsDatasetRoi(
input_dir,
data_df,
preprocessing_dict,
train_transformations=train_transformations,
all_transformations=all_transformations,
roi_index=cnn_index,
label_presence=label_presence,
label=label,
label_code=label_code,
multi_cohort=multi_cohort,
)
elif preprocessing_dict["mode"] == "slice":
return CapsDatasetSlice(
input_dir,
data_df,
preprocessing_dict,
train_transformations=train_transformations,
all_transformations=all_transformations,
slice_index=cnn_index,
label_presence=label_presence,
label=label,
label_code=label_code,
multi_cohort=multi_cohort,
)
else:
raise NotImplementedError(
f"Mode {preprocessing_dict['mode']} is not implemented."
)
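
A hedged usage sketch (preprocessing_dict is normally the JSON produced by prepare_data and holds more keys than the single "mode" entry shown here; the DataFrame name is illustrative):

    from pathlib import Path

    dataset = return_dataset(
        input_dir=Path("/path/to/caps"),
        data_df=subjects_sessions_df,  # pandas DataFrame of participants/sessions
        preprocessing_dict={"mode": "image"},  # minimal sketch; real dicts hold more keys
        all_transformations=None,
    )
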
16 changes: 1 addition & 15 deletions clinicadl/caps_dataset/data_config.py
@@ -1,4 +1,3 @@
import tarfile
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
@@ -8,19 +7,11 @@

from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv, load_data_test
from clinicadl.preprocessing.preprocessing import read_preprocessing
from clinicadl.utils.clinica_utils import (
RemoteFileStructure,
clinicadl_file_reader,
fetch_file,
)
from clinicadl.utils.enum import MaskChecksum, Mode, Pathology
from clinicadl.utils.enum import Mode
from clinicadl.utils.exceptions import (
ClinicaDLArgumentError,
ClinicaDLTSVError,
DownloadError,
)
from clinicadl.utils.maps_manager.maps_manager import MapsManager
from clinicadl.utils.read_utils import get_mask_checksum_and_filename

logger = getLogger("clinicadl.data_config")

@@ -52,11 +43,6 @@ def validator_diagnoses(cls, v):
return tuple(v)
return v # TODO : check if columns are in tsv

def adapt_data_with_maps_manager_info(self, maps_manager: MapsManager):
# TEMPORARY
if self.diagnoses is None or len(self.diagnoses) == 0:
self.diagnoses = maps_manager.diagnoses

def create_groupe_df(self):
group_df = None
if self.data_tsv is not None and self.data_tsv.is_file():
