Skip to content

Commit

Permalink
Cleaning part 2 + DataGroup (#690)
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau authored Dec 19, 2024
1 parent 1a7b08b commit c73b884
Show file tree
Hide file tree
Showing 63 changed files with 272 additions and 72 deletions.
13 changes: 6 additions & 7 deletions clinicadl/data/datasets/caps_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from clinicadl.utils.exceptions import ClinicaDLCAPSError, ClinicaDLTSVError
from clinicadl.utils.iotools.clinica_utils import create_subs_sess_list
from clinicadl.utils.loading import nifti_to_tensor, pt_to_tensor
from clinicadl.utils.typing import DataType, PathType

logger = getLogger("clinicadl.caps_dataset")

Expand Down Expand Up @@ -68,10 +69,10 @@ class CapsDataset(Dataset):

def __init__(
self,
caps_directory: Union[str, Path],
caps_directory: PathType,
preprocessing: Preprocessing = PreprocessingT1(),
transforms: Transforms = Transforms(),
data: Optional[Union[pd.DataFrame, str, Path]] = None,
data: Optional[DataType] = None,
label: Optional[str] = None,
masks: Optional[list[str]] = None,
):
Expand Down Expand Up @@ -207,9 +208,7 @@ def describe(self):
"extraction": self.extraction.model_dump(),
}

def _get_df_from_input(
self, data: Optional[Union[pd.DataFrame, Path, str]]
) -> pd.DataFrame:
def _get_df_from_input(self, data: Optional[DataType]) -> pd.DataFrame:
"""
Generates or validates the DataFrame from the input data.
Expand Down Expand Up @@ -248,7 +247,7 @@ def _get_df_from_input(

return df

def _check_data_instance(self, data: Optional[Union[pd.DataFrame, Path, str]]):
def _check_data_instance(self, data: Optional[DataType]):
if isinstance(data, str):
data = Path(data)

Expand Down Expand Up @@ -550,7 +549,7 @@ def train(self):
"""
self.eval_mode = False

def subset(self, data: Optional[Union[pd.DataFrame, Path]] = None) -> CapsDataset:
def subset(self, data: Optional[DataType] = None) -> CapsDataset:
df = self._check_data_instance(data)

common_rows = pd.merge(df, self.df, how="inner")
Expand Down
6 changes: 2 additions & 4 deletions clinicadl/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
SESSION_ID = "session_id"


def df_to_tsv(
name: str, results_path: Path, df: pd.DataFrame, baseline: bool = False
) -> None:
def df_to_tsv(tsv_path: Path, df: pd.DataFrame, baseline: bool = False) -> None:
"""
Write Dataframe into a TSV file and drop duplicates
Expand All @@ -46,7 +44,7 @@ def df_to_tsv(
subset=["participant_id", "session_id"], keep="first", inplace=True
)
# df = df[["participant_id", "session_id"]]
df.to_csv(results_path / name, sep="\t", index=False)
df.to_csv(tsv_path, sep="\t", index=False)


def tsv_to_df(tsv_path: Path) -> pd.DataFrame:
Expand Down
2 changes: 2 additions & 0 deletions clinicadl/dictionary/suffixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
JSON = ".json"
LOG = ".log"
PT = ".pt"
PTH = ".pth"
TAR = ".tar"
NII = ".nii"
GZ = ".gz"
NII_GZ = NII + GZ
3 changes: 3 additions & 0 deletions clinicadl/dictionary/words.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
CONFIG = "config"
COUNT = "count"
CUDA = "cuda"
DATA = "data"
DESCRIPTION = "description"
FOLD = "fold"
GROUPS = "groups"
ID = "id"
IMAGE = "image"
KFOLD = "k" + FOLD
LABEL = "label"
MAPS = "maps"
MEAN = "mean"
OBJECT = "object"
PARTICIPANT = "participant"
Expand Down
69 changes: 69 additions & 0 deletions clinicadl/experiment_manager/data_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import json
from pathlib import Path
from typing import List, Optional

import pandas as pd

from clinicadl.data.datasets import CapsDataset
from clinicadl.data.utils import df_to_tsv, tsv_to_df
from clinicadl.dictionary.suffixes import JSON, TSV
from clinicadl.dictionary.words import DATA, GROUPS, MAPS, SPLIT, TRAIN, VALIDATION
from clinicadl.utils.config import ClinicaDLConfig

TRAIN_VAL = [TRAIN, VALIDATION]


class DataGroup(ClinicaDLConfig):
    """
    Location and content of one data group inside a MAPS directory.

    Every path is derived from ``maps_path``, ``name`` and ``split``:

    - ``group_dir``       = ``maps_path / GROUPS / name``
    - ``group_split_dir`` = ``group_dir / (SPLIT + "-" + str(split))`` when
      ``name`` is one of ``TRAIN_VAL``, otherwise ``group_dir`` itself
    - ``data_tsv``        = ``group_split_dir / (DATA + TSV)``
    - ``maps_json``       = ``group_split_dir / (MAPS + JSON)``
    """

    # Root directory of the MAPS.
    maps_path: Path
    # Name of the data group (names listed in TRAIN_VAL get a per-split sub-directory).
    name: str
    # Split index; only used to build the path of the groups listed in TRAIN_VAL.
    split: Optional[int]

    @property
    def data_tsv(self) -> Path:
        """Path of the TSV file associated with this data group."""
        return self.group_split_dir / (DATA + TSV)

    @property
    def maps_json(self) -> Path:
        """Path of the JSON file associated with this data group."""
        return self.group_split_dir / (MAPS + JSON)

    @property
    def group_dir(self) -> Path:
        """Directory of the data group, regardless of the split."""
        return self.maps_path / GROUPS / self.name

    @property
    def group_split_dir(self) -> Path:
        """
        Directory in which the files of this data group are stored.

        Groups whose name is in ``TRAIN_VAL`` are stored in a per-split
        sub-directory; every other group is stored directly in ``group_dir``.
        """
        # NOTE(review): if ``split`` is None for a group in TRAIN_VAL, the
        # directory is literally named ``SPLIT + "-None"`` -- confirm whether
        # that case should be rejected upstream.
        if self.name in TRAIN_VAL:
            return self.group_dir / (SPLIT + "-" + str(self.split))
        return self.group_dir

    @property
    def df(self) -> pd.DataFrame:
        """DataFrame read from ``data_tsv`` (re-read from disk on each access)."""
        return tsv_to_df(self.data_tsv)

    @property
    def caps_dir(self) -> Path:
        """
        CAPS directory recorded in ``maps_json``.

        Raises:
            FileNotFoundError: if ``maps_json`` does not exist.
            KeyError: if the JSON has no ``"cap_directory"`` entry.
        """
        dict_ = self._read_json()
        # NOTE(review): ``_write_json`` currently dumps an empty dict, so this
        # lookup raises KeyError on any file written by this class. Also confirm
        # that the key is really meant to be "cap_directory".
        return Path(dict_["cap_directory"])

    def exists(self) -> bool:
        """Return True when both ``data_tsv`` and ``maps_json`` are existing files."""
        return self.data_tsv.is_file() and self.maps_json.is_file()

    def create(self, caps_dataset: CapsDataset):
        """
        Write the files describing this data group.

        Args:
            caps_dataset: dataset whose ``df`` attribute is written to ``data_tsv``.
        """
        # Neither ``DataFrame.to_csv`` nor ``Path.open("w")`` creates missing
        # parent directories, so make sure the target directory exists first.
        self.group_split_dir.mkdir(parents=True, exist_ok=True)
        df_to_tsv(self.data_tsv, caps_dataset.df)
        self._write_json(caps_dataset)

    def _write_json(self, caps_dataset: CapsDataset):
        """Write ``maps_json``; the content is currently an empty JSON object."""
        dict_ = {}  # TODO: to complete
        with self.maps_json.open(mode="w") as file:
            json.dump(dict_, file)

    def _read_json(self):
        """
        Load and return the content of ``maps_json``.

        Raises:
            FileNotFoundError: if ``maps_json`` is not an existing file.
        """
        if not self.maps_json.is_file():
            raise FileNotFoundError(
                f"Could not find the `maps.json` file for the data group: {self.name} (in {self.maps_path})"
            )

        with self.maps_json.open(mode="r") as file:
            return json.load(file)
83 changes: 83 additions & 0 deletions clinicadl/experiment_manager/maps_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from pathlib import Path
from typing import Optional, Union

import pandas as pd
import torch

from clinicadl.data.datasets import CapsDataset
from clinicadl.data.utils import tsv_to_df
from clinicadl.dictionary.suffixes import PTH, TAR
from clinicadl.dictionary.words import BEST, GROUPS, PARTICIPANT_ID, SPLIT, TMP
from clinicadl.experiment_manager.data_group import DataGroup
from clinicadl.model import ClinicaDLModel
from clinicadl.splitter.split import Split
from clinicadl.utils.exceptions import (
ClinicaDLConfigurationError,
ClinicaDLDataLeakageError,
)
from clinicadl.utils.typing import PathType


class MapsReader:
    """Helper that builds paths and data groups relative to a MAPS directory."""

    def __init__(self, maps_path: PathType) -> None:
        """
        Args:
            maps_path: root directory of the MAPS; it is converted to a ``Path``.
        """
        self.maps_path = Path(maps_path)

    def _create_data_group(
        self, name: str, caps_dataset: CapsDataset, split: Optional[int] = None
    ) -> DataGroup:
        """
        Check that a data group is not already written, then write its files
        through ``DataGroup.create``.

        Args:
            name: name of the data group.
            caps_dataset: dataset forwarded to ``DataGroup.create``.
            split: split index, if any.

        Returns:
            The newly created data group.

        Raises:
            ClinicaDLConfigurationError: if the data group already exists.
        """
        data_group = DataGroup(name=name, split=split, maps_path=self.maps_path)
        if data_group.exists():
            raise ClinicaDLConfigurationError(
                f"Data group {data_group.name} already exists, please give another name to your data group"
            )

        data_group.create(caps_dataset)
        return data_group

    def _load_data_group(self, name: str, split: Optional[int] = None) -> DataGroup:
        """
        Load an existing data group (nothing is written).

        Args:
            name: name of the data group.
            split: split index, if any.

        Returns:
            The data group, once its files have been found.

        Raises:
            ClinicaDLConfigurationError: if the data group does not exist.
        """
        data_group = DataGroup(name=name, split=split, maps_path=self.maps_path)
        if not data_group.exists():
            raise ClinicaDLConfigurationError(
                f"Could not find data group {data_group.name}"
            )

        return data_group

    def get_train_val_df(self) -> pd.DataFrame:
        """Read ``<maps_path>/GROUPS/train+validation.tsv`` into a DataFrame."""
        path = self.maps_path / GROUPS / "train+validation.tsv"
        return tsv_to_df(path)

    def get_model(self) -> ClinicaDLModel:
        """Return a ``ClinicaDLModel`` built without arguments (placeholder)."""
        return ClinicaDLModel()  # type: ignore

    def _write_network_weights(self):
        """TO COMPLETE"""
        pass

    def _write_optim_weights(self):
        """TO COMPLETE"""
        pass

    def write_tensor(self):
        """TO COMPLETE"""
        pass

    def split_path(self, split: int) -> Path:
        """Return the directory of a given split: ``maps_path / (SPLIT + "-" + split)``."""
        return self.maps_path / (SPLIT + "-" + str(split))

    def optimizer_path(self, split: int, resume: bool = False) -> Path:
        """
        Return the path of the optimizer file inside the ``TMP`` folder of a split.

        NOTE(review): ``resume`` is accepted but not used yet.
        """
        return self.split_path(split) / TMP / ("optimizer" + PTH + TAR)

    def checkpoint_path(self, split: int, resume: bool = False) -> Path:
        """
        Return the path of the checkpoint file inside the ``TMP`` folder of a split.

        NOTE(review): ``resume`` is accepted but not used yet.
        """
        return self.split_path(split) / TMP / ("checkpoint" + PTH + TAR)

    def model_path(self, split: int, metric: str) -> Path:
        """Return the path of the model file stored under ``BEST + "-" + metric`` for a split."""
        return self.split_path(split) / (BEST + "-" + metric) / ("model" + PTH + TAR)
1 change: 1 addition & 0 deletions clinicadl/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .clinicadl_model import ClinicaDLModel
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class PredictConfig(BaseModel):
data: DataConfig
validation: ValidationConfig
computational: ComputationalConfig
dataloader: DataLoaderConfig
split: SplitConfig
transforms: TransformsConfig

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ def _compute_output_tensors(
Compute the output tensors and saves them in the MAPS.
Args:
dataset (clinicadl.data.datasets.caps_dataset.CapsDataset): wrapper of the data set.
dataset (clinicadl.data.datasets.CapsDataset): wrapper of the data set.
data_group (str): name of the data group used for the task.
split (int): split number.
selection_metrics (list[str]): metrics used for model selection.
Expand Down
File renamed without changes.
File renamed without changes.
69 changes: 64 additions & 5 deletions clinicadl/predictor/predictor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,71 @@
from clinicadl.data.caps_dataset import CapsDataset
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from pathlib import Path
from typing import Optional

import pandas as pd
import torch

from clinicadl.data.datasets import CapsDataset
from clinicadl.data.utils import tsv_to_df
from clinicadl.dictionary.words import GROUPS, PARTICIPANT_ID
from clinicadl.experiment_manager import ExperimentManager
from clinicadl.experiment_manager.maps_reader import MapsReader
from clinicadl.model import ClinicaDLModel
from clinicadl.splitter.split import Split
from clinicadl.utils.exceptions import (
ClinicaDLConfigurationError,
ClinicaDLDataLeakageError,
)


class Predictor:
def __init__(self, manager: ExperimentManager):
def __init__(self, reader: MapsReader, model: ClinicaDLModel):
"""TO COMPLETE"""
pass
self.reader = reader
self.model = model

def predict(self, dataset_test: CapsDataset, split: int):
def predict(self, dataset_test: CapsDataset, split: Split):
"""TO COMPLETE"""
pass

def _check_leakage(self, dataset_test: CapsDataset):
    """
    Make sure the test set shares no participant with the train/validation data.

    The participants of the DataFrame returned by ``self.reader.get_train_val_df()``
    are compared with those of ``dataset_test.df``.

    Args:
        dataset_test: dataset to evaluate; its ``df`` must hold a ``PARTICIPANT_ID`` column.

    Raises:
        ClinicaDLDataLeakageError: if at least one participant appears in both sets.
    """
    train_val_ids = self.reader.get_train_val_df()[PARTICIPANT_ID].values
    test_ids = dataset_test.df[PARTICIPANT_ID].values

    intersection = set(test_ids) & set(train_val_ids)
    if not intersection:
        return

    raise ClinicaDLDataLeakageError(
        "Your evaluation set contains participants who were already seen during "
        "the training step. The list of common participants is the following: "
        f"{intersection}."
    )

def test(self):
    """
    Computes the predictions and evaluation metrics.

    Placeholder: the body is empty for now and nothing is computed.
    """
    pass

def _test_loader(self):
    """
    Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files.

    Placeholder: the body is empty for now and nothing is written.
    """
    pass

def _compute_latent_tensor(self):
    """
    Computes the latent tensors and saves them in the MAPS.

    Placeholder: the body is empty for now and nothing is saved.
    """
    pass

@torch.no_grad()
def _compute_output_nifti(self):
    """
    Computes the output nifti images and saves them in the MAPS.

    Placeholder: the body is empty for now and nothing is saved.
    """
    pass

@torch.no_grad()
def _compute_output_tensors(self):
    """
    Computes the output tensors and saves them in the MAPS.

    Placeholder: the body is empty for now and nothing is saved.
    """
    pass

def _ensemble_prediction(self):
    """
    Computes the results on the image-level.

    Placeholder: the body is empty for now and nothing is computed.
    """
    pass
5 changes: 3 additions & 2 deletions clinicadl/splitter/make_splits/kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from clinicadl.splitter.splitter.kfold import KFoldConfig
from clinicadl.tsvtools.tsvtools_utils import extract_baseline
from clinicadl.utils.exceptions import ClinicaDLConfigurationError, ClinicaDLTSVError
from clinicadl.utils.typing import DataType, PathType


def _validate_stratification(
Expand Down Expand Up @@ -98,8 +99,8 @@ def preprocess_stratification(


def make_kfold(
data: Union[pd.DataFrame, Path, str],
output_dir: Optional[Union[Path, str]] = None,
data: DataType,
output_dir: Optional[PathType] = None,
subset_name: str = "validation",
valid_longitudinal: bool = False,
n_splits: PositiveInt = 5,
Expand Down
5 changes: 3 additions & 2 deletions clinicadl/splitter/make_splits/single_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from clinicadl.splitter.splitter.single_split import SingleSplitConfig
from clinicadl.tsvtools.tsvtools_utils import extract_baseline
from clinicadl.utils.exceptions import ClinicaDLConfigurationError, ClinicaDLTSVError
from clinicadl.utils.typing import DataType, PathType

logger = getLogger("clinicadl.splitter.single_split")

Expand Down Expand Up @@ -128,8 +129,8 @@ def _chi2_test(x_test: List[int], x_train: List[int]) -> float:


def make_split(
data: Union[pd.DataFrame, Path, str],
output_dir: Optional[Union[Path, str]] = None,
data: DataType,
output_dir: Optional[PathType] = None,
n_test: PositiveFloat = 100,
subset_name: str = "test",
p_categorical_threshold: float = 0.50,
Expand Down
Loading

0 comments on commit c73b884

Please sign in to comment.