DataGroup
camillebrianceau committed Dec 18, 2024
1 parent 497ba0a commit ddf5fdc
Showing 3 changed files with 82 additions and 45 deletions.
4 changes: 2 additions & 2 deletions clinicadl/data/utils.py
@@ -52,7 +52,7 @@ class CapsDatasetSample(BaseModel):
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)


def df_to_tsv(name: str, results_path: Path, df, baseline: bool = False) -> None:
def df_to_tsv(tsv_path: Path, df, baseline: bool = False) -> None:
"""
Write Dataframe into a TSV file and drop duplicates
@@ -77,7 +77,7 @@ def df_to_tsv(name: str, results_path: Path, df, baseline: bool = False) -> None
subset=["participant_id", "session_id"], keep="first", inplace=True
)
# df = df[["participant_id", "session_id"]]
df.to_csv(results_path / name, sep="\t", index=False)
df.to_csv(tsv_path, sep="\t", index=False)


def tsv_to_df(tsv_path: Path) -> pd.DataFrame:
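The change above replaces the separate name / results_path arguments with a single tsv_path. A minimal usage sketch of the new signature, with a made-up DataFrame containing the participant_id / session_id columns used by the duplicate drop:

from pathlib import Path

import pandas as pd

from clinicadl.data.utils import df_to_tsv

# Hypothetical participant/session table; in practice it comes from a CapsDataset.
df = pd.DataFrame(
    {
        "participant_id": ["sub-01", "sub-01", "sub-02"],
        "session_id": ["ses-M000", "ses-M000", "ses-M006"],
    }
)

# New signature: the full target path is passed directly instead of (name, results_path).
df_to_tsv(Path("data.tsv"), df)  # duplicated (participant_id, session_id) rows are dropped
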
56 changes: 41 additions & 15 deletions clinicadl/experiment_manager/data_group.py
@@ -1,18 +1,16 @@
import json
from pathlib import Path
from typing import List, Optional

import pandas as pd

from typing import List, Optional
from pathlib import Path

from clinicadl.data.utils import tsv_to_df
from clinicadl.data.datasets import CapsDataset
from clinicadl.data.utils import df_to_tsv, tsv_to_df
from clinicadl.dictionary.suffixes import JSON, TSV
from clinicadl.dictionary.words import DATA, GROUPS, MAPS, SPLIT, TRAIN, VALIDATION
from clinicadl.splitter.split import Split
from clinicadl.utils.config import ClinicaDLConfig

from clinicadl.dictionary.words import GROUPS, DATA, MAPS, TRAIN, VALIDATION, SPLIT
from clinicadl.dictionary.suffixes import TSV, JSON



TRAIN_VAL = [TRAIN, VALIDATION]


@@ -22,23 +20,51 @@ class DataGroup(ClinicaDLConfig):
split: Optional[int]

@property
def data_tsv(self)-> Path:
def data_tsv(self) -> Path:
return self.group_split_dir / (DATA + TSV)

@property
def maps_json(self)-> Path:
def maps_json(self) -> Path:
return self.group_split_dir / (MAPS + JSON)

@property
def group_dir(self) -> Path:
return self.maps_path / GROUPS / self.name

@property
def group_split_dir(self) -> Path:
if self.name in TRAIN_VAL:
return self.group_dir / (SPLIT + "-" + str(self.split))
else:
return self.group_dir

def get_data(self)-> pd.DataFrame:
@property
def df(self) -> pd.DataFrame:
return tsv_to_df(self.data_tsv)

@property
def caps_dir(self) -> Path:
dict_ = self._read_json()
return Path(dict_["cap_directory"])

def exists(self) -> bool:
return self.data_tsv.is_file() and self.maps_json.is_file()

def create(self, caps_dataset: CapsDataset):
df_to_tsv(self.data_tsv, caps_dataset.df)
self._write_json(caps_dataset)

def _write_json(self, caps_dataset: CapsDataset):
dict_ = {} # TODO: to complete
with self.maps_json.open(mode="w") as file:
json.dump(dict_, file)

def _read_json(self):
if self.maps_json.is_file():
with self.maps_json.open(mode="r") as file:
dict_ = json.load(file)
return dict_

raise FileNotFoundError(
f"Could not find the `maps.json` file for the data grou : {self.name} (in {self.maps_path})"
)
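For orientation, a hedged sketch of how the new DataGroup resolves its paths and is created or reloaded. The MAPS directory and group names below are made up, and the exact folder and file names ultimately come from the constants in clinicadl.dictionary:

from pathlib import Path

from clinicadl.experiment_manager.data_group import DataGroup

# Hypothetical MAPS directory; in practice it comes from the experiment manager.
maps_path = Path("results/maps")

# "train" and "validation" groups live in a per-split subfolder...
train_group = DataGroup(name="train", split=0, maps_path=maps_path)
print(train_group.group_split_dir)  # e.g. results/maps/groups/train/split-0
print(train_group.data_tsv)         # e.g. .../split-0/data.tsv
print(train_group.maps_json)        # e.g. .../split-0/maps.json

# ...while any other group sits directly under <maps>/groups/<name>.
test_group = DataGroup(name="test", split=None, maps_path=maps_path)
print(test_group.group_split_dir)   # e.g. results/maps/groups/test

# test_group.create(caps_dataset) would write data.tsv and maps.json;
# test_group.exists() and test_group.df then read them back.
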
67 changes: 39 additions & 28 deletions clinicadl/predictor/predictor.py
@@ -1,38 +1,55 @@
from pathlib import Path
from typing import Optional

import pandas as pd
import torch

from clinicadl.data.datasets import CapsDataset
from clinicadl.data.utils import tsv_to_df
from clinicadl.dictionary.words import GROUPS, PARTICIPANT_ID
from clinicadl.experiment_manager import ExperimentManager
from clinicadl.experiment_manager.data_group import DataGroup
from clinicadl.model import ClinicaDLModel
from clinicadl.splitter.split import Split
from clinicadl.utils.exceptions import ClinicaDLDataLeakageError
from clinicadl.utils.exceptions import (
ClinicaDLConfigurationError,
ClinicaDLDataLeakageError,
)


class MapsReader:
def _check_data_group(self, data_group: str) -> bool:
"""Check if a data group is already available if other arguments are None."""
return True
maps_path: Path

def _create_data_group(self, data_group: str):
"""creates a new data_group."""
pass
def _create_data_group(
self, name: str, caps_dataset: CapsDataset, split: Optional[int] = None
) -> DataGroup:
"""
Checks that a data group is not already written and writes the characteristics of the data group
(TSV file with a list of participant / session + JSON file containing the CAPS and the preprocessing).
"""
data_group = DataGroup(name=name, split=split, maps_path=self.maps_path)
if data_group.exists():
raise ClinicaDLConfigurationError(
f"Data group {data_group.name} already exists, please give another name to your data group"
)

def get_group_info(self, data_group: str, split: Optional[Split]):
"""Gets information from corresponding data group
(list of participant_id / session_id + configuration parameters).
split is only needed if data_group is train or validation."""
pass
data_group.create(caps_dataset)
return data_group

def get_group_df(
self, data_group: str, split: Optional[Split] = None
) -> pd.DataFrame:
"""Gets information from corresponding data group
(list of participant_id / session_id).
split is only needed if data_group is train or validation."""
def _load_data_group(self, name: str, split: Optional[int] = None) -> DataGroup:
"""creates a new data_group."""
data_group = DataGroup(name=name, split=split, maps_path=self.maps_path)
if data_group.exists():
return data_group

raise ClinicaDLConfigurationError(
f"Could not find data group {data_group.name}"
)

return pd.DataFrame()
def get_train_val_df(self):
"""Loads the train and validation data groups."""
path = self.maps_path / GROUPS / "train+validation.tsv"
return tsv_to_df(path)
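
The two private methods above are complementary: _create_data_group refuses to overwrite an existing group, while _load_data_group only returns groups already on disk. A hypothetical helper (not part of this commit) showing how they might be combined; the reader and dataset are assumed to be constructed elsewhere:

from clinicadl.utils.exceptions import ClinicaDLConfigurationError


def get_or_create_group(reader, name, caps_dataset, split=None):
    """Reuse a data group already written to disk, otherwise create it."""
    try:
        # Succeeds only if data.tsv and maps.json already exist for this group.
        return reader._load_data_group(name, split=split)
    except ClinicaDLConfigurationError:
        # Otherwise write the TSV + JSON files for a fresh group.
        return reader._create_data_group(name, caps_dataset, split=split)
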


class Predictor:
@@ -48,11 +65,11 @@ def predict(self, dataset_test: CapsDataset, split: Split):
def _check_leakage(self, dataset_test: CapsDataset):
"""Checks that no intersection exist between the participants used for training and those used for testing."""

df_train = self.reader.get_group_df("train+validation")
df_train_val = self.reader.get_train_val_df()
df_test = dataset_test.df

participants_train = set(df_train.participant_id.values)
participants_test = set(df_test.participant_id.values)
participants_train = set(df_train_val[PARTICIPANT_ID].values)
participants_test = set(df_test[PARTICIPANT_ID].values)
intersection = participants_test & participants_train

if len(intersection) > 0:
@@ -62,12 +79,6 @@ def _check_leakage(self, dataset_test: CapsDataset):
f"{intersection}."
)

def _write_data_group(self):
"""Check that a data_group is not already written and writes the characteristics of the data group
(TSV file with a list of participant / session + JSON file containing the CAPS and the preprocessing).
"""
pass

def test(self):
"""Computes the predictions and evaluation metrics."""
pass
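The reworked leakage check boils down to a set intersection on participant IDs read from the train+validation group and the test dataset. A standalone illustration of that pattern with made-up tables; the real code raises ClinicaDLDataLeakageError instead of printing:

import pandas as pd

# Hypothetical train+validation and test tables, mimicking the groups' data.tsv files.
df_train_val = pd.DataFrame(
    {"participant_id": ["sub-01", "sub-02"], "session_id": ["ses-M000", "ses-M000"]}
)
df_test = pd.DataFrame(
    {"participant_id": ["sub-02", "sub-03"], "session_id": ["ses-M000", "ses-M000"]}
)

participants_train = set(df_train_val["participant_id"].values)
participants_test = set(df_test["participant_id"].values)
intersection = participants_test & participants_train

if intersection:
    print(f"Data leakage between train and test sets: {intersection}")  # {'sub-02'}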
