-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1a7b08b
commit c73b884
Showing
63 changed files
with
272 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
JSON = ".json" | ||
LOG = ".log" | ||
PT = ".pt" | ||
PTH = ".pth" | ||
TAR = ".tar" | ||
NII = ".nii" | ||
GZ = ".gz" | ||
NII_GZ = NII + GZ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import json | ||
from pathlib import Path | ||
from typing import List, Optional | ||
|
||
import pandas as pd | ||
|
||
from clinicadl.data.datasets import CapsDataset | ||
from clinicadl.data.utils import df_to_tsv, tsv_to_df | ||
from clinicadl.dictionary.suffixes import JSON, TSV | ||
from clinicadl.dictionary.words import DATA, GROUPS, MAPS, SPLIT, TRAIN, VALIDATION | ||
from clinicadl.utils.config import ClinicaDLConfig | ||
|
||
TRAIN_VAL = [TRAIN, VALIDATION] | ||
|
||
|
||
class DataGroup(ClinicaDLConfig): | ||
maps_path: Path | ||
name: str | ||
split: Optional[int] | ||
|
||
@property | ||
def data_tsv(self) -> Path: | ||
return self.group_split_dir / (DATA + TSV) | ||
|
||
@property | ||
def maps_json(self) -> Path: | ||
return self.group_split_dir / (MAPS + JSON) | ||
|
||
@property | ||
def group_dir(self) -> Path: | ||
return self.maps_path / GROUPS / self.name | ||
|
||
@property | ||
def group_split_dir(self) -> Path: | ||
if self.name in TRAIN_VAL: | ||
return self.group_dir / (SPLIT + "-" + str(self.split)) | ||
else: | ||
return self.group_dir | ||
|
||
@property | ||
def df(self) -> pd.DataFrame: | ||
return tsv_to_df(self.data_tsv) | ||
|
||
@property | ||
def caps_dir(self) -> Path: | ||
dict_ = self._read_json() | ||
return Path(dict_["cap_directory"]) | ||
|
||
def exists(self) -> bool: | ||
return self.data_tsv.is_file() and self.maps_json.is_file() | ||
|
||
def create(self, caps_dataset: CapsDataset): | ||
df_to_tsv(self.data_tsv, caps_dataset.df) | ||
self._write_json(caps_dataset) | ||
|
||
def _write_json(self, caps_dataset: CapsDataset): | ||
dict_ = {} # TODO: to complete | ||
with self.maps_json.open(mode="w") as file: | ||
json.dump(dict_, file) | ||
|
||
def _read_json(self): | ||
if self.maps_json.is_file(): | ||
with self.maps_json.open(mode="r") as file: | ||
dict_ = json.load(file) | ||
return dict_ | ||
|
||
raise FileNotFoundError( | ||
f"Could not find the `maps.json` file for the data grou : {self.name} (in {self.maps_path})" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
|
||
import pandas as pd | ||
import torch | ||
|
||
from clinicadl.data.datasets import CapsDataset | ||
from clinicadl.data.utils import tsv_to_df | ||
from clinicadl.dictionary.suffixes import PTH, TAR | ||
from clinicadl.dictionary.words import BEST, GROUPS, PARTICIPANT_ID, SPLIT, TMP | ||
from clinicadl.experiment_manager.data_group import DataGroup | ||
from clinicadl.model import ClinicaDLModel | ||
from clinicadl.splitter.split import Split | ||
from clinicadl.utils.exceptions import ( | ||
ClinicaDLConfigurationError, | ||
ClinicaDLDataLeakageError, | ||
) | ||
from clinicadl.utils.typing import PathType | ||
|
||
|
||
class MapsReader: | ||
def __init__(self, maps_path: PathType) -> None: | ||
self.maps_path = Path(maps_path) | ||
|
||
def _create_data_group( | ||
self, name: str, caps_dataset: CapsDataset, split: Optional[int] = None | ||
) -> DataGroup: | ||
""" | ||
Check that a data_group is not already written and writes the characteristics of the data group | ||
(TSV file with a list of participant / session + JSON file containing the CAPS and the preprocessing). | ||
""" | ||
data_group = DataGroup(name=name, split=split, maps_path=self.maps_path) | ||
if data_group.exists(): | ||
raise ClinicaDLConfigurationError( | ||
f"Data group {data_group.name} already exists, please give another name to your data group" | ||
) | ||
|
||
data_group.create(caps_dataset) | ||
return data_group | ||
|
||
def _load_data_group(self, name: str, split: Optional[int] = None) -> DataGroup: | ||
"""creates a new data_group.""" | ||
data_group = DataGroup(name=name, split=split, maps_path=self.maps_path) | ||
if data_group.exists(): | ||
return data_group | ||
|
||
raise ClinicaDLConfigurationError( | ||
f"Could not find data group {data_group.name}" | ||
) | ||
|
||
def get_train_val_df(self): | ||
"""Loads the train and validation data groups.""" | ||
path = self.maps_path / GROUPS / "train+validation.tsv" | ||
return tsv_to_df(path) | ||
|
||
def get_model(self) -> ClinicaDLModel: | ||
return ClinicaDLModel() # type: ignore | ||
|
||
def _write_network_weights(self): | ||
"""TO COMPLETE""" | ||
pass | ||
|
||
def _write_optim_weights(self): | ||
"""TO COMPLETE""" | ||
pass | ||
|
||
def write_tensor(self): | ||
"""TO COMPLETE""" | ||
pass | ||
|
||
def split_path(self, split: int): | ||
return self.maps_path / (SPLIT + "-" + str(split)) | ||
|
||
def optimizer_path(self, split: int, resume: bool = False) -> Path: | ||
"""TO COMPLETE""" | ||
|
||
return self.split_path(split) / TMP / ("optimizer" + PTH + TAR) | ||
|
||
def checkpoint_path(self, split: int, resume: bool = False): | ||
return self.split_path(split) / TMP / ("checkpoint" + PTH + TAR) | ||
|
||
def model_path(self, split: int, metric: str): | ||
return self.split_path(split) / (BEST + "-" + metric) / ("model" + PTH + TAR) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .clinicadl_model import ClinicaDLModel |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,71 @@ | ||
from clinicadl.data.caps_dataset import CapsDataset | ||
from clinicadl.experiment_manager.experiment_manager import ExperimentManager | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
import pandas as pd | ||
import torch | ||
|
||
from clinicadl.data.datasets import CapsDataset | ||
from clinicadl.data.utils import tsv_to_df | ||
from clinicadl.dictionary.words import GROUPS, PARTICIPANT_ID | ||
from clinicadl.experiment_manager import ExperimentManager | ||
from clinicadl.experiment_manager.maps_reader import MapsReader | ||
from clinicadl.model import ClinicaDLModel | ||
from clinicadl.splitter.split import Split | ||
from clinicadl.utils.exceptions import ( | ||
ClinicaDLConfigurationError, | ||
ClinicaDLDataLeakageError, | ||
) | ||
|
||
|
||
class Predictor: | ||
def __init__(self, manager: ExperimentManager): | ||
def __init__(self, reader: MapsReader, model: ClinicaDLModel): | ||
"""TO COMPLETE""" | ||
pass | ||
self.reader = reader | ||
self.model = model | ||
|
||
def predict(self, dataset_test: CapsDataset, split: int): | ||
def predict(self, dataset_test: CapsDataset, split: Split): | ||
"""TO COMPLETE""" | ||
pass | ||
|
||
def _check_leakage(self, dataset_test: CapsDataset): | ||
"""Checks that no intersection exist between the participants used for training and those used for testing.""" | ||
|
||
df_train_val = self.reader.get_train_val_df() | ||
df_test = dataset_test.df | ||
|
||
participants_train = set(df_train_val[PARTICIPANT_ID].values) | ||
participants_test = set(df_test[PARTICIPANT_ID].values) | ||
intersection = participants_test & participants_train | ||
|
||
if len(intersection) > 0: | ||
raise ClinicaDLDataLeakageError( | ||
"Your evaluation set contains participants who were already seen during " | ||
"the training step. The list of common participants is the following: " | ||
f"{intersection}." | ||
) | ||
|
||
def test(self): | ||
"""Computes the predictions and evaluation metrics.""" | ||
pass | ||
|
||
def _test_loader(self): | ||
"""Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files.""" | ||
pass | ||
|
||
def _compute_latent_tensor(self): | ||
"""Compute the output tensors and saves them in the MAPS.""" | ||
pass | ||
|
||
@torch.no_grad() | ||
def _compute_output_nifti(self): | ||
"""omputes the output nifti images and saves them in the MAPS.""" | ||
pass | ||
|
||
@torch.no_grad() | ||
def _compute_output_tensors(self): | ||
"""Compute the output tensors and saves them in the MAPS.""" | ||
pass | ||
|
||
def _ensemble_prediction(self): | ||
"""Computes the results on the image-level.""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.