Skip to content

Commit

Permalink
test for some change in splitter
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 10, 2024
1 parent 550ead2 commit feff9ec
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 135 deletions.
5 changes: 5 additions & 0 deletions clinicadl/predictor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,8 @@ def adapt_with_maps_manager_info(self, maps_manager: MapsManager):
size_reduction=maps_manager.size_reduction,
size_reduction_factor=maps_manager.size_reduction_factor,
)

if self.split.split is None and self.split.n_splits == 0:
from clinicadl.splitter.split_utils import find_splits

self.split.split = find_splits(self.maps_manager.maps_dir)
145 changes: 18 additions & 127 deletions clinicadl/predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,77 +51,30 @@ def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None:
from clinicadl.splitter.config import SplitterConfig
from clinicadl.splitter.splitter import Splitter

tmp = _config.data.model_dump(
self.maps_manager = MapsManager(_config.maps_manager.maps_dir)
self._config.adapt_with_maps_manager_info(self.maps_manager)

tmp = self._config.data.model_dump(
exclude=set(["preprocessing_dict", "mode", "caps_dict"])
)
tmp.update(_config.split.model_dump())
tmp.update(_config.validation.model_dump())
tmp.update(self._config.split.model_dump())
tmp.update(self._config.validation.model_dump())
self.splitter = Splitter(SplitterConfig(**tmp))
self.maps_manager = MapsManager(_config.maps_manager.maps_dir)
self._config.adapt_with_maps_manager_info(self.maps_manager)

def predict(
self,
label_code: Union[str, dict[str, int]] = "default",
):
"""Performs the prediction task on a subset of caps_directory defined in a TSV file.
Parameters
----------
data_group : str
name of the data group tested.
caps_directory : Path (optional, default=None)
path to the CAPS folder. For more information please refer to
[clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/).
Default will load the value of an existing data group
tsv_path : Path (optional, default=None)
path to a TSV file containing the list of participants and sessions to test.
Default will load the DataFrame of an existing data group
split_list : List[int] (optional, default=None)
list of splits to test. Default perform prediction on all splits available.
selection_metrics : List[str] (optional, default=None)
list of selection metrics to test.
Default performs the prediction on all selection metrics available.
multi_cohort : bool (optional, default=False)
If True considers that tsv_path is the path to a multi-cohort TSV.
diagnoses : List[str] (optional, default=())
List of diagnoses to load if tsv_path is a split_directory.
Default uses the same as in training step.
use_labels : bool (optional, default=True)
If True, the labels must exist in test meta-data and metrics are computed.
batch_size : int (optional, default=None)
If given, sets the value of batch_size, else use the same as in training step.
n_proc : int (optional, default=None)
If given, sets the value of num_workers, else use the same as in training step.
gpu : bool (optional, default=None)
If given, a new value for the device of the model will be computed.
amp : bool (optional, default=False)
If enabled, uses Automatic Mixed Precision (requires GPU usage).
overwrite : bool (optional, default=False)
If True erase the occurrences of data_group.
label : str (optional, default=None)
Target label used for training (if network_task in [`regression`, `classification`]).
label_code : Optional[Dict[str, int]] (optional, default="default")
dictionary linking the target values to a node number.
save_tensor : bool (optional, default=False)
If true, save the tensor predicted for reconstruction task
save_nifti : bool (optional, default=False)
If true, save the nifti associated to the prediction for reconstruction task.
save_latent_tensor : bool (optional, default=False)
If true, save the tensor from the latent space for reconstruction task.
skip_leak_check : bool (optional, default=False)
If true, skip the leak check (not recommended).
Examples
--------
>>> _input_
_output_
"""
"""Performs the prediction task on a subset of caps_directory defined in a TSV file."""

group_df = self._config.data.create_groupe_df()
self._check_data_group(group_df)
criterion = get_criterion(
self.maps_manager.network_task, self.maps_manager.loss
)

print(f"enter split iterato: {self.splitter.split_iterator()}")

for split in self.splitter.split_iterator():
logger.info(f"Prediction of split {split}")
group_df, group_parameters = self.get_group_info(
Expand All @@ -142,16 +95,19 @@ def predict(
)
else:
split_selection_metrics = self._config.validation.selection_metrics
print(f" split selection metrics : {split_selection_metrics}")
for selection in split_selection_metrics:
tsv_dir = (
self.maps_manager.maps_path
/ f"split-{split}"
/ f"best-{selection}"
/ self._config.maps_manager.data_group
)
print(f"tsv_dir: {tsv_dir}")
tsv_pattern = f"{self._config.maps_manager.data_group}*.tsv"
for tsv_file in tsv_dir.glob(tsv_pattern):
tsv_file.unlink()
print("out boucle")
self._config.data.check_label(self.maps_manager.label)
if self.maps_manager.multi_network:
for network in range(self.maps_manager.num_networks):
Expand Down Expand Up @@ -246,6 +202,7 @@ def _predict_single(
)
if self._config.maps_manager.save_tensor:
logger.debug("Saving tensors")
print("save_tensor")
self._compute_output_tensors(
maps_manager=self.maps_manager,
dataset=data_test,
Expand Down Expand Up @@ -424,59 +381,6 @@ def _compute_output_nifti(
def interpret(self):
"""Performs the interpretation task on a subset of caps_directory defined in a TSV file.
The mean interpretation is always saved, to save the individual interpretations set save_individual to True.
Parameters
----------
data_group : str
Name of the data group interpreted.
name : str
Name of the interpretation procedure.
method : str
Method used for extraction (ex: gradients, grad-cam...).
caps_directory : Path (optional, default=None)
Path to the CAPS folder. For more information please refer to
[clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/).
Default will load the value of an existing data group.
tsv_path : Path (optional, default=None)
Path to a TSV file containing the list of participants and sessions to test.
Default will load the DataFrame of an existing data group.
split_list : list[int] (optional, default=None)
List of splits to interpret. Default perform interpretation on all splits available.
selection_metrics : list[str] (optional, default=None)
List of selection metrics to interpret.
Default performs the interpretation on all selection metrics available.
multi_cohort : bool (optional, default=False)
If True considers that tsv_path is the path to a multi-cohort TSV.
diagnoses : list[str] (optional, default=())
List of diagnoses to load if tsv_path is a split_directory.
Default uses the same as in training step.
target_node : int (optional, default=0)
Node from which the interpretation is computed.
save_individual : bool (optional, default=False)
If True saves the individual map of each participant / session couple.
batch_size : int (optional, default=None)
If given, sets the value of batch_size, else use the same as in training step.
n_proc : int (optional, default=None)
If given, sets the value of num_workers, else use the same as in training step.
gpu : bool (optional, default=None)
If given, a new value for the device of the model will be computed.
amp : bool (optional, default=False)
If enabled, uses Automatic Mixed Precision (requires GPU usage).
overwrite : bool (optional, default=False)
If True erase the occurrences of data_group.
overwrite_name : bool (optional, default=False)
If True erase the occurrences of name.
level : int (optional, default=None)
Layer number in the convolutional part after which the feature map is chosen.
save_nifti : bool (optional, default=False)
If True, save the interpretation map in nifti format.
Raises
------
NotImplementedError
If the method is not implemented
NotImplementedError
If the interpretation of a multi-network model is requested
MAPSError
If the interpretation has already been determined.
"""
assert isinstance(self._config, InterpretConfig)

Expand Down Expand Up @@ -603,22 +507,6 @@ def _check_data_group(
Parameters
----------
data_group : str
name of the data group
caps_directory : str (optional, default=None)
input CAPS directory
df : pd.DataFrame (optional, default=None)
Table of participant_id / session_id of the data group
multi_cohort : bool (optional, default=False)
indicates if the input data comes from several CAPS
overwrite : bool (optional, default=False)
If True former definition of data group is erased
label : str (optional, default=None)
label name if applicable
split_list : list[int] (optional, default=None)
_description_
skip_leak_check : bool (optional, default=False)
_description_
Raises
------
Expand All @@ -636,13 +524,15 @@ def _check_data_group(
/ self._config.maps_manager.data_group
)
logger.debug(f"Group path {group_dir}")
print(f"group_dir: {group_dir}")
if group_dir.is_dir(): # Data group already exists
print("is dir")
if self._config.maps_manager.overwrite:
if self._config.maps_manager.data_group in ["train", "validation"]:
raise MAPSError("Cannot overwrite train or validation data group.")
else:
# if not split_list:
# split_list = self.maps_manager.find_splits()
if not self._config.split.split:
self._config.split.split = self.maps_manager.find_splits()
assert self._config.split
for split in self._config.split.split:
selection_metrics = find_selection_metrics(
Expand Down Expand Up @@ -1155,6 +1045,7 @@ def _test_loader(

if cluster.master:
# Replace here
print("before saving")
maps_manager._mode_level_to_tsv(
prediction_df,
metrics,
Expand Down
1 change: 1 addition & 0 deletions clinicadl/predictor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_prediction(
prediction_dir = (
maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group
)
print(prediction_dir)
if not prediction_dir.is_dir():
raise MAPSError(
f"No prediction corresponding to data group {data_group} was found."
Expand Down
15 changes: 9 additions & 6 deletions clinicadl/splitter/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Splitter:
def __init__(
self,
config: SplitterConfig,
split_list: Optional[List[int]] = None,
# split_list: Optional[List[int]] = None,
):
"""_summary_
Expand All @@ -29,7 +29,7 @@ def __init__(
"""
self.config = config
self.split_list = split_list
# self.config.split.split = split_list

self.caps_dict = self.config.data.caps_dict # TODO : check if useful ?

Expand All @@ -38,10 +38,10 @@ def max_length(self) -> int:
return self.config.split.n_splits

def __len__(self):
    """Number of splits this splitter will iterate over.

    Uses the explicit split list when one is configured, otherwise
    falls back to the configured ``n_splits``.
    """
    chosen_splits = self.config.split.split
    return len(chosen_splits) if chosen_splits else self.config.split.n_splits

@property
def allowed_splits_list(self):
Expand Down Expand Up @@ -203,10 +203,13 @@ def _get_tsv_paths(self, cohort_path, *args) -> Tuple[Path, Path]:

def split_iterator(self):
    """Return an iterable over the indices of all requested splits.

    Returns
    -------
    Iterable[int]
        The explicitly configured split list when one is set, otherwise
        ``range(n_splits)`` covering every available split.
    """
    # Fix: removed three leftover debug print() calls (one duplicated)
    # that were committed into library code.
    if not self.config.split.split:
        return range(self.config.split.n_splits)
    return self.config.split.split

def _check_item(self, item):
if item not in self.allowed_splits_list:
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/tmp_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def adapt_cross_val_with_maps_manager_info(
): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future
# TEMPORARY
if not self.split:
self.split = find_splits(maps_manager.maps_path, maps_manager.split_name)
self.split = find_splits(maps_manager.maps_path)
logger.debug(f"List of splits {self.split}")

def create_groupe_df(self):
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def resume(self, splits: List[int]) -> None:
# TODO : check these two lines. Why do we need a split_manager?

splitter_config = SplitterConfig(**self.config.get_dict())
split_manager = Splitter(splitter_config, split_list=splits)
split_manager = Splitter(splitter_config)

split_iterator = split_manager.split_iterator()
###
Expand Down

0 comments on commit feff9ec

Please sign in to comment.