Skip to content

Commit

Permalink
some fixes and enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasGensollen committed Jul 24, 2024
1 parent 4420d3c commit 6da3b92
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 120 deletions.
164 changes: 78 additions & 86 deletions clinica/iotools/utils/pipeline_handling.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Methods to find information in the different pipelines of Clinica."""
from enum import Enum
from functools import partial, reduce
from os import PathLike
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union

import pandas as pd

__all__ = [
"PipelineNameForMetricExtraction",
"pipeline_metric_extractor_factory",
]


class PipelineNameForMetricExtraction(str, Enum):
"""Pipelines for which a metric extractor has been implemented."""
Expand All @@ -19,22 +23,22 @@ class PipelineNameForMetricExtraction(str, Enum):


def _extract_metrics_from_pipeline(
caps_dir: PathLike,
caps_dir: Path,
df: pd.DataFrame,
metrics: List[str],
metrics: Iterable[str],
pipeline: PipelineNameForMetricExtraction,
atlas_selection: Optional[List[str]] = None,
group_selection: Optional[List[str]] = None,
atlas_selection: Optional[Iterable[str]] = None,
group_selection: Optional[Iterable[str]] = None,
pvc_restriction: Optional[bool] = None,
tracers_selection: Optional[List[str]] = None,
tracers_selection: Optional[Iterable[str]] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""Extract and merge the data of the provided pipeline into the
merged dataframe already containing the BIDS information.
Parameters
----------
caps_dir : PathLike
Path to the CAPS directory.
caps_dir : Path
The path to the CAPS directory.
df : pd.DataFrame
DataFrame containing the BIDS information.
Expand Down Expand Up @@ -72,7 +76,6 @@ def _extract_metrics_from_pipeline(
summary_df : pd.DataFrame
Summary DataFrame generated by function `generate_summary`.
"""
caps_dir = Path(caps_dir)
if df.index.names != ["participant_id", "session_id"]:
try:
df.set_index(
Expand Down Expand Up @@ -111,35 +114,35 @@ def _extract_metrics_from_pipeline(
return final_df, summary_df


extract_metrics_from_dwi_dti = partial(
_extract_metrics_from_dwi_dti = partial(
_extract_metrics_from_pipeline,
metrics=["FA_statistics", "MD_statistics", "RD_statistics", "AD_statistics"],
pipeline=PipelineNameForMetricExtraction.DWI_DTI,
group_selection=[""],
)


extract_metrics_from_t1_freesurfer_longitudinal = partial(
_extract_metrics_from_t1_freesurfer_longitudinal = partial(
_extract_metrics_from_pipeline,
metrics=["volume", "thickness", "segmentationVolumes"],
pipeline=PipelineNameForMetricExtraction.T1_FREESURFER_LONGI,
group_selection=[""],
)

extract_metrics_from_t1_freesurfer = partial(
_extract_metrics_from_t1_freesurfer = partial(
_extract_metrics_from_pipeline,
metrics=["thickness", "segmentationVolumes"],
pipeline=PipelineNameForMetricExtraction.T1_FREESURFER,
group_selection=[""],
)

extract_metrics_from_t1_volume = partial(
_extract_metrics_from_t1_volume = partial(
_extract_metrics_from_pipeline,
metrics=["statistics"],
pipeline=PipelineNameForMetricExtraction.T1_VOLUME,
)

extract_metrics_from_pet_volume = partial(
_extract_metrics_from_pet_volume = partial(
_extract_metrics_from_pipeline,
metrics=["statistics"],
pipeline=PipelineNameForMetricExtraction.PET_VOLUME,
Expand All @@ -150,23 +153,24 @@ def pipeline_metric_extractor_factory(
name: Union[str, PipelineNameForMetricExtraction],
) -> Callable:
"""Factory returning a metric extractor given its name."""
if isinstance(name, str):
name = PipelineNameForMetricExtraction(name)
name = PipelineNameForMetricExtraction(name)
if name == PipelineNameForMetricExtraction.T1_VOLUME:
return extract_metrics_from_t1_volume
return _extract_metrics_from_t1_volume
if name == PipelineNameForMetricExtraction.PET_VOLUME:
return extract_metrics_from_pet_volume
return _extract_metrics_from_pet_volume
if name == PipelineNameForMetricExtraction.T1_FREESURFER:
return extract_metrics_from_t1_freesurfer
return _extract_metrics_from_t1_freesurfer
if name == PipelineNameForMetricExtraction.T1_FREESURFER_LONGI:
return extract_metrics_from_t1_freesurfer_longitudinal
return _extract_metrics_from_t1_freesurfer_longitudinal
if name == PipelineNameForMetricExtraction.DWI_DTI:
return extract_metrics_from_dwi_dti
return _extract_metrics_from_dwi_dti


def _check_group_selection(
caps_dir: Path, group_selection: Optional[List[str]] = None
) -> List[str]:
caps_dir: Path, group_selection: Optional[Iterable[str]] = None
) -> Iterable[str]:
if group_selection == [""]:
return group_selection
if group_selection is None:
return [f.name for f in (caps_dir / "groups").iterdir()]
return [f"group-{group}" for group in group_selection]
Expand All @@ -176,46 +180,45 @@ def _get_records(
caps_dir: Path,
df: pd.DataFrame,
pipeline: PipelineNameForMetricExtraction,
metrics: List[str],
group_selection: List[str],
atlas_selection: Optional[List[str]] = None,
metrics: Iterable[str],
group_selection: Iterable[str],
atlas_selection: Optional[Iterable[str]] = None,
pvc_restriction: Optional[bool] = None,
tracers_selection: Optional[List[str]] = None,
tracers_selection: Optional[Iterable[str]] = None,
) -> List[dict]:
"""Returns a list of dictionaries corresponding to the dataframe rows of the pipeline dataframe."""
from clinica.utils.stream import cprint

subjects_dir = caps_dir / "subjects"
records = []
subjects_dir = caps_dir / "subjects"
for participant_id, session_id in df.index.values:
mod_path = _get_modality_path(
subjects_dir / participant_id / session_id, pipeline
)
if mod_path is None:
if (
mod_path := _get_modality_path(
subjects_dir / participant_id / session_id, pipeline
)
) is None or not mod_path.exists():
cprint(
f"Could not find a longitudinal dataset for participant {participant_id} {session_id}",
lvl="warning",
)
continue
if not mod_path.exists():
continue
group_paths = [
for group_path in (
mod_path / group for group in group_selection if (mod_path / group).exists()
]
for group_path in group_paths:
for metric in metrics:
records.append(
_get_single_record(
group_path,
metric,
participant_id,
session_id,
pipeline,
atlas_selection,
pvc_restriction,
tracers_selection,
)
):
records_of_group = [
_get_single_record(
group_path,
metric,
participant_id,
session_id,
pipeline,
atlas_selection,
pvc_restriction,
tracers_selection,
)
for metric in metrics
]
records.append(reduce(lambda x, y: {**x, **y}, records_of_group))
return records


Expand All @@ -227,8 +230,7 @@ def _get_modality_path(
return ses_path / "dwi" / "dti_based_processing" / "atlas_statistics"
if pipeline == PipelineNameForMetricExtraction.T1_FREESURFER_LONGI:
mod_path = ses_path / "t1"
long_ids = list(mod_path.glob("long*"))
if len(long_ids) == 0:
if len(long_ids := list(mod_path.glob("long*"))) == 0:
return None
return (
mod_path
Expand All @@ -250,15 +252,16 @@ def _get_single_record(
participant_id: str,
session_id: str,
pipeline: PipelineNameForMetricExtraction,
atlas_selection: Optional[List[str]] = None,
atlas_selection: Optional[Iterable[str]] = None,
pvc_restriction: Optional[bool] = None,
tracers_selection: Optional[List[str]] = None,
tracers_selection: Optional[Iterable[str]] = None,
) -> dict:
"""Get a single record (dataframe row) for a given participant, session, group, and metric."""
atlas_paths = _get_atlas_paths(group_path, participant_id, session_id, metric)
atlases = [
atlas_path
for atlas_path in atlas_paths
for atlas_path in _get_atlas_paths(
group_path, participant_id, session_id, metric
)
if not _skip_atlas(
atlas_path, pipeline, pvc_restriction, tracers_selection, atlas_selection
)
Expand Down Expand Up @@ -301,8 +304,8 @@ def _skip_atlas(
atlas_path: Path,
pipeline: PipelineNameForMetricExtraction,
pvc_restriction: Optional[bool] = None,
tracers_selection: Optional[List[str]] = None,
atlas_selection: Optional[List[str]] = None,
tracers_selection: Optional[Iterable[str]] = None,
atlas_selection: Optional[Iterable[str]] = None,
) -> bool:
"""Returns whether the atlas provided through its path should be skipped or not."""
if not atlas_path.exists():
Expand All @@ -318,7 +321,7 @@ def _skip_atlas_based_on_pipeline(
atlas_path: Path,
pipeline: PipelineNameForMetricExtraction,
pvc_restriction: Optional[bool] = None,
tracers_selection: Optional[List[str]] = None,
tracers_selection: Optional[Iterable[str]] = None,
) -> bool:
"""Returns True if the atlas provided through its path should be skipped based on the pipeline name."""
if pipeline == PipelineNameForMetricExtraction.T1_FREESURFER_LONGI:
Expand All @@ -344,14 +347,13 @@ def _skip_atlas_based_on_pipeline(
def _skip_atlas_based_on_selection(
atlas_path: Path,
pipeline: PipelineNameForMetricExtraction,
atlas_selection: Optional[List[str]] = None,
atlas_selection: Optional[Iterable[str]] = None,
) -> bool:
"""Returns True if the atlas provided through its path should be skipped based on the user-provided selection."""
# try:
atlas_name = _get_atlas_name(atlas_path, pipeline)
# except ValueError:
# return True
return atlas_selection is not None and atlas_name not in atlas_selection
return (
atlas_selection is not None
and _get_atlas_name(atlas_path, pipeline) not in atlas_selection
)


def _get_records_for_atlas(
Expand All @@ -362,7 +364,7 @@ def _get_records_for_atlas(
) -> dict:
atlas_df = pd.read_csv(atlas_path, sep="\t")
label_list = _get_label_list(atlas_path, metric, pipeline, group)
key = "label_value" if "freesurfer" in pipeline else "mean_scalar"
key = "label_value" if "freesurfer" in pipeline.value else "mean_scalar"
values = atlas_df[key].to_numpy()
return {label: value for label, value in zip(label_list, values)}

Expand All @@ -373,15 +375,13 @@ def _get_label_list(
"""Returns the list of labels to use in the session df depending on the
pipeline, the atlas, and the metric considered.
"""
from clinica.iotools.converter_utils import (
replace_sequence_chars,
)
from clinica.iotools.converter_utils import replace_sequence_chars

atlas_name = _get_atlas_name(atlas_path, pipeline)
atlas_df = pd.read_csv(atlas_path, sep="\t")
atlas_name = _get_atlas_name(atlas_path, pipeline)
if pipeline == PipelineNameForMetricExtraction.T1_FREESURFER:
return [
f"t1-freesurfer_atlas-{atlas_name}_ROI-{replace_sequence_chars(roi_name)}_thickness"
f"t1-freesurfer_{atlas_name}_ROI-{replace_sequence_chars(roi_name)}_{'volume' if metric == 'segmentationVolumes' else metric}"
for roi_name in atlas_df.label_name.values
]
if pipeline in (
Expand All @@ -395,18 +395,15 @@ def _get_label_list(
if "pvc-rbv" in str(atlas_path):
additional_desc += f"_pvc-rbv"
return [
f"{pipeline}_{group}_atlas-{atlas_name}{additional_desc}_ROI-{replace_sequence_chars(roi_name)}_intensity"
f"{pipeline}_{group}_{atlas_name}{additional_desc}_ROI-{replace_sequence_chars(roi_name)}_intensity"
for roi_name in atlas_df.label_name.values
]
if pipeline == PipelineNameForMetricExtraction.DWI_DTI:
prefix = "dwi-dti_"
metric = metric.rstrip("_statistics")
else:
prefix = "t1-fs-long_"
return [
prefix + metric + "_atlas-" + atlas_name + "_" + x
for x in atlas_df.label_name.values
]
return [prefix + metric + atlas_name + "_" + x for x in atlas_df.label_name.values]


def _get_atlas_name(atlas_path: Path, pipeline: PipelineNameForMetricExtraction) -> str:
Expand All @@ -428,9 +425,12 @@ def _get_atlas_name(atlas_path: Path, pipeline: PipelineNameForMetricExtraction)
def _infer_atlas_name(splitter: str, atlas_path: Path) -> str:
try:
assert splitter in atlas_path.stem
return atlas_path.stem.split(splitter)[-1].split("_")[0]
return f"atlas-{atlas_path.stem.split(splitter)[-1].split('_')[0]}"
except Exception:
raise ValueError(f"Unable to infer the atlas name from {atlas_path}.")
if "segmentationVolumes" in atlas_path.stem:
return "segmentation-volumes"
else:
raise ValueError(f"Unable to infer the atlas name from {atlas_path}.")


def _generate_summary(
Expand Down Expand Up @@ -515,11 +515,3 @@ def _generate_summary(

summary_df = summary_df.replace("_", "n/a")
return summary_df


class DatasetError(Exception):
def __init__(self, name):
self.name = name

def __str__(self):
return repr("Bad format for the sessions: " + self.name)
Loading

0 comments on commit 6da3b92

Please sign in to comment.