Skip to content

Commit

Permalink
New metrics module based on MONAI (#654)
Browse files Browse the repository at this point in the history
* config classes for metrics

* factory for config classes

* factory to get metrics from monai

* unittests
  • Loading branch information
thibaultdvx authored Sep 24, 2024
1 parent 048ca39 commit 17ef236
Show file tree
Hide file tree
Showing 20 changed files with 1,155 additions and 0 deletions.
2 changes: 2 additions & 0 deletions clinicadl/monai_metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .config import ImplementedMetrics, MetricConfig, create_metric_config
from .factory import get_metric, loss_to_metric
3 changes: 3 additions & 0 deletions clinicadl/monai_metrics/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import MetricConfig
from .enum import ImplementedMetrics
from .factory import create_metric_config
40 changes: 40 additions & 0 deletions clinicadl/monai_metrics/config/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod
from typing import Union

from pydantic import (
BaseModel,
ConfigDict,
computed_field,
field_validator,
)

from clinicadl.utils.factories import DefaultFromLibrary

from .enum import Reduction


class MetricConfig(BaseModel, ABC):
"""Base config class to configure metrics."""

reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES
get_not_nans: bool = False
include_background: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
# pydantic config
model_config = ConfigDict(
validate_assignment=True,
use_enum_values=True,
validate_default=True,
)

@computed_field
@property
@abstractmethod
def metric(self) -> str:
"""The name of the metric."""

@field_validator("get_not_nans", mode="after")
@classmethod
def validator_get_not_nans(cls, v):
assert not v, "get_not_nans not supported in ClinicaDL. Please set to False."

return v
74 changes: 74 additions & 0 deletions clinicadl/monai_metrics/config/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Type, Union

from pydantic import computed_field

from clinicadl.utils.factories import DefaultFromLibrary

from .base import MetricConfig
from .enum import Average, ConfusionMatrixMetric

__all__ = [
"ROCAUCConfig",
"create_confusion_matrix_config",
]


# TODO : AP is missing
class ROCAUCConfig(MetricConfig):
"Config class for ROC AUC."

average: Union[Average, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def metric(self) -> str:
"""The name of the metric."""
return "ROCAUCMetric"


class ConfusionMatrixMetricConfig(MetricConfig, ABC):
"Config class for metrics derived from the confusion matrix."

compute_sample: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def metric(self) -> str:
"""The name of the metric."""
return "ConfusionMatrixMetric"

@computed_field
@property
@abstractmethod
def metric_name(self) -> str:
"""The name of the metric computed from the confusion matrix."""


def create_confusion_matrix_config(
metric_name: ConfusionMatrixMetric,
) -> Type[ConfusionMatrixMetricConfig]:
"""
Builds a config class for a specific metric computed from the confusion matrix."
Parameters
----------
metric_name : ConfusionMatrixMetric
The metric name (e.g. 'f1 score', 'accuracy', etc.).
Returns
-------
Type[ConfusionMatrixMetricConfig]
The config class.
"""

class ConfusionMatrixMetricSubConfig(ConfusionMatrixMetricConfig):
"A sub config class for a specific metric computed from the confusion matrix."

@computed_field
@property
def metric_name(self) -> str:
"""The name of the metric computed from the confusion matrix."""
return metric_name

return ConfusionMatrixMetricSubConfig
100 changes: 100 additions & 0 deletions clinicadl/monai_metrics/config/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from enum import Enum


class ImplementedMetrics(str, Enum):
"""Implemented metrics in ClinicaDL."""

LOSS = "Loss"

RECALL = "Recall"
SPECIFICITY = "Specificity"
PRECISION = "Precision"
NPV = "Negative Predictive Value"
F1 = "F1 score"
BALANCED_ACC = "Balanced Accuracy"
ACC = "Accuracy"
MARKEDNESS = "Markedness"
MCC = "Matthews Correlation Coefficient"
ROC_AUC = "ROCAUC"

MSE = "MSE"
MAE = "MAE"
RMSE = "RMSE"
PSNR = "PSNR"
SSIM = "SSIM"
MS_SSIM = "Multi-scale SSIM"

DICE = "Dice"
GENERALIZED_DICE = "Generalized Dice"
IOU = "IoU"
SURF_DIST = "Surface distance"
HAUSDORFF = "Hausdorff distance"
SURF_DICE = "Surface Dice"

MMD = "MMD"

@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not implemented. Implemented metrics are: "
+ ", ".join([repr(m.value) for m in cls])
)


class Reduction(str, Enum):
"""Supported reduction for the metrics."""

MEAN = "mean"
SUM = "sum"


class GeneralizedDiceScoreReduction(str, Enum):
"""Supported reduction for GeneralizedDiceScore."""

MEAN = "mean_batch"
SUM = "sum_batch"


class Average(str, Enum):
"""Supported averaging method for ROCAUCMetric."""

MACRO = "macro"
WEIGHTED = "weighted"
MICRO = "micro"


class ConfusionMatrixMetric(str, Enum):
"""Supported metrics related to confusion matrix (in the format accepted by MONAI)."""

RECALL = "recall"
SPECIFICITY = "specificity"
PRECISION = "precision"
NPV = "negative predictive value"
F1 = "f1 score"
BALANCED_ACC = "balanced accuracy"
ACC = "accuracy"
MARKEDNESS = "markedness"
MCC = "matthews correlation coefficient"


class DistanceMetric(str, Enum):
"Supported distances."

L2 = "euclidean"
L1 = "taxicab"
LINF = "chessboard"


class Kernel(str, Enum):
"Supported kernel for SSIMMetric."

GAUSSIAN = "gaussian"
UNIFORM = "uniform"


class WeightType(str, Enum):
"Supported weight types for GeneralizedDiceScore."

SQUARE = "square"
SIMPLE = "simple"
UNIFORM = "uniform"
77 changes: 77 additions & 0 deletions clinicadl/monai_metrics/config/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import Type, Union

from .base import MetricConfig
from .classification import *
from .enum import ConfusionMatrixMetric, ImplementedMetrics
from .generation import *
from .reconstruction import *
from .regression import *
from .segmentation import *


def create_metric_config(
metric: Union[str, ImplementedMetrics],
) -> Type[MetricConfig]:
"""
A factory function to create a config class suited for the metric.
Parameters
----------
metric : Union[str, ImplementedMetrics]
The name of the metric.
Returns
-------
Type[MetricConfig]
The config class.
Raises
------
ValueError
When `metric`does not correspond to any supported metric.
ValueError
When `metric` is `Loss`.
"""
metric = ImplementedMetrics(metric)
if metric == ImplementedMetrics.LOSS:
raise ValueError(
"To use the loss as a metric, please use directly clinicadl.metrics.loss_to_metric."
)

# special cases
if metric == ImplementedMetrics.MS_SSIM:
return MultiScaleSSIMConfig
if metric == ImplementedMetrics.MMD:
return MMDMetricConfig

try:
metric = ConfusionMatrixMetric(metric.lower())
return create_confusion_matrix_config(metric)
except ValueError:
pass

# "normal" cases:
try:
config = _get_config(metric)
except KeyError:
config = _get_config(metric.title().replace(" ", ""))

return config


def _get_config(name: str) -> Type[MetricConfig]:
"""
Tries to get a config class associated to the name.
Parameters
----------
name : str
The name of the metric.
Returns
-------
Type[MetricConfig]
The config class.
"""
config_name = "".join([name, "Config"])
return globals()[config_name]
17 changes: 17 additions & 0 deletions clinicadl/monai_metrics/config/generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pydantic import PositiveFloat, computed_field

from .base import MetricConfig

__all__ = ["MMDMetricConfig"]


class MMDMetricConfig(MetricConfig):
"Config class for MMD metric."

kernel_bandwidth: PositiveFloat = 1.0

@computed_field
@property
def metric(self) -> str:
"""The name of the metric."""
return "MMDMetric"
Loading

0 comments on commit 17ef236

Please sign in to comment.