metric module back
camillebrianceau committed Sep 3, 2024
1 parent 50efcbd commit 2dd1b09
Showing 1 changed file with 77 additions and 69 deletions.
146 changes: 77 additions & 69 deletions clinicadl/metrics/metric_module.py
@@ -1,8 +1,8 @@
 from logging import getLogger
-from typing import Callable, Dict, List, Union
+from typing import Dict, List

 import numpy as np
-from scipy.stats import bootstrap
+from sklearn.utils import resample

 metric_optimum = {
     "MAE": "min",
@@ -44,78 +44,86 @@ def __init__(self, metrics, n_classes=2):
                     f"The metric {metric} is not implemented in the module."
                 )

-    def apply(
-        self, y_: List[float], y_pred_: List[float], report_ci: bool
-    ) -> Dict[str, Union[float, List[float]]]:
-        """Calculate metrics based on the list of true and predicted labels."""
-        if not y_ or not y_pred_:
-            return {}
-
-        y = np.array(y_)
-        y_pred = np.array(y_pred_)
-        results = {}
-
-        for metric_name, metric_fn in self.metrics.items():
-            class_numbers = (
-                range(self.n_classes)
-                if "class_number" in metric_fn.__code__.co_varnames
-                else [0]
-            )
-            self._calculate_metric(
-                y, y_pred, metric_name, metric_fn, class_numbers, results, report_ci
-            )
-
-        return results
-
-    def _calculate_metric(
-        self,
-        y: np.ndarray,
-        y_pred: np.ndarray,
-        metric_name: str,
-        metric_fn: Callable,
-        class_numbers: range,
-        results: Dict[str, Union[float, List[float]]],
-        report_ci: bool,
-    ):
-        """Helper function to calculate metrics and optionally confidence intervals."""
-        for class_number in class_numbers:
-            metric_result = metric_fn(y, y_pred, class_number)
-            metric_key = (
-                f"{metric_name}-{class_number}"
-                if len(class_numbers) > 1
-                else metric_name
-            )
-            if report_ci and len(y) >= 2:
-                lower_ci, upper_ci, se = self._calculate_confidence_interval(
-                    y, y_pred, metric_fn, class_number
-                )
-                results.setdefault("Metric_names", []).append(metric_key)
-                results.setdefault("Metric_values", []).append(metric_result)
-                results.setdefault("Lower_CI", []).append(lower_ci)
-                results.setdefault("Upper_CI", []).append(upper_ci)
-                results.setdefault("SE", []).append(se)
-            else:
-                results[metric_key] = metric_result
-
-    @staticmethod
-    def _calculate_confidence_interval(
-        y: np.ndarray, y_pred: np.ndarray, metric_fn: Callable, class_number: int
-    ) -> tuple:
-        """Calculate confidence intervals and standard error for a metric."""
-        res = bootstrap(
-            (y, y_pred),
-            lambda y, y_pred: metric_fn(y, y_pred, class_number),
-            n_resamples=3000,
-            confidence_level=0.95,
-            method="percentile",
-            paired=True,
-        )
-        return (
-            res.confidence_interval.low,
-            res.confidence_interval.high,
-            res.standard_error,
-        )
+    def apply(self, y, y_pred, report_ci):
+        """
+        This is a function to calculate the different metrics based on the list of true label and predicted label
+        Args:
+            y (List): list of labels
+            y_pred (List): list of predictions
+            report_ci (bool) : If True confidence intervals are reported
+        Returns:
+            (Dict[str:float]) metrics results
+        """
+        if y is not None and y_pred is not None:
+            results = dict()
+            y = np.array(y)
+            y_pred = np.array(y_pred)
+
+            if report_ci:
+                from scipy.stats import bootstrap
+
+                metric_names = ["Metrics"]
+                metric_values = ["Values"]  # Collect metric values
+                lower_ci_values = ["Lower bound CI"]  # Collect lower CI values
+                upper_ci_values = ["Upper bound CI"]  # Collect upper CI values
+                se_values = ["SE"]  # Collect standard error values
+
+            for metric_key, metric_fn in self.metrics.items():
+                metric_args = list(metric_fn.__code__.co_varnames)
+
+                class_numbers = (
+                    range(self.n_classes)
+                    if "class_number" in metric_args and self.n_classes > 2
+                    else [0]
+                )
+
+                for class_number in class_numbers:
+                    metric_result = metric_fn(y, y_pred, class_number)
+
+                    # Compute confidence intervals only if there are at least two samples in the data.
+                    if report_ci and len(y) >= 2:
+                        res = bootstrap(
+                            (y, y_pred),
+                            lambda y, y_pred: metric_fn(y, y_pred, class_number),
+                            n_resamples=3000,
+                            confidence_level=0.95,
+                            method="percentile",
+                            paired=True,
+                        )
+
+                        lower_ci, upper_ci = res.confidence_interval
+                        standard_error = res.standard_error
+
+                        metric_values.append(metric_result)
+                        lower_ci_values.append(lower_ci)
+                        upper_ci_values.append(upper_ci)
+                        se_values.append(standard_error)
+                        metric_names.append(
+                            f"{metric_key}-{class_number}"
+                            if len(class_numbers) > 1
+                            else f"{metric_key}"
+                        )
+                    else:
+                        results[
+                            (
+                                f"{metric_key}-{class_number}"
+                                if len(class_numbers) > 1
+                                else f"{metric_key}"
+                            )
+                        ] = metric_result
+
+            if report_ci:
+                # Construct the final results dictionary
+                results["Metric_names"] = metric_names
+                results["Metric_values"] = metric_values
+                results["Lower_CI"] = lower_ci_values
+                results["Upper_CI"] = upper_ci_values
+                results["SE"] = se_values
+        else:
+            results = dict()
+
+        return results

     @staticmethod
     def compute_mae(y, y_pred, *args):
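
For context, here is a minimal usage sketch of the restored method. It assumes the class defined in metric_module.py is named MetricModule, that its constructor takes a list of metric names (such as "MAE") plus n_classes as shown in the hunk header, and that the import path mirrors the file location; none of these details are spelled out in the diff itself.

    # Hedged usage sketch: class name, constructor arguments and import path are
    # assumptions inferred from the file path and the __init__ signature above.
    from clinicadl.metrics.metric_module import MetricModule

    module = MetricModule(["MAE"], n_classes=2)

    y_true = [0, 1, 1, 0, 1, 0]
    y_pred = [0, 1, 0, 0, 1, 1]

    # report_ci=False: a flat {metric_name: value} dictionary.
    point_estimates = module.apply(y_true, y_pred, report_ci=False)

    # report_ci=True: parallel lists under "Metric_names", "Metric_values",
    # "Lower_CI", "Upper_CI" and "SE", each starting with a header entry
    # ("Metrics", "Values", ...), as built by the restored code above.
    with_ci = module.apply(y_true, y_pred, report_ci=True)

    print(point_estimates)
    print(with_ci)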
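The confidence intervals in the restored code come from scipy.stats.bootstrap with paired resampling and the percentile method. The standalone sketch below reproduces that call outside the class; the toy metric function and the data are invented for illustration and merely stand in for a MetricModule metric.

    import numpy as np
    from scipy.stats import bootstrap

    def toy_accuracy(y, y_pred):
        # Toy statistic standing in for a MetricModule metric function.
        return float(np.mean(y == y_pred))

    y = np.array([0, 0, 1, 1, 1, 0, 1, 0])
    y_pred = np.array([0, 1, 1, 1, 0, 0, 1, 0])

    # Same settings as the committed code: 3000 paired percentile resamples at 95%.
    res = bootstrap(
        (y, y_pred),
        toy_accuracy,
        n_resamples=3000,
        confidence_level=0.95,
        method="percentile",
        paired=True,
    )

    lower_ci, upper_ci = res.confidence_interval
    print(lower_ci, upper_ci, res.standard_error)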

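The metric_optimum dictionary visible at the top of the diff records, for each metric, whether lower or higher values are better. A short hedged sketch of how such a direction map can be used follows; only the "MAE": "min" entry appears in the diff, while the "accuracy" entry and the helper are illustrative assumptions.

    # Illustrative only: "MAE": "min" is taken from the diff; "accuracy": "max"
    # and the helper below are assumptions about how the direction map is used.
    metric_optimum = {"MAE": "min", "accuracy": "max"}

    def best_value(metric_name, history):
        # Pick the best value of a tracked metric according to its optimum direction.
        return min(history) if metric_optimum[metric_name] == "min" else max(history)

    print(best_value("MAE", [0.31, 0.27, 0.29]))       # 0.27
    print(best_value("accuracy", [0.71, 0.74, 0.73]))  # 0.74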