-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New metrics module based on MONAI (#654)
* config classes for metrics * factory for config classes * factory to get metrics from monai * unittests
- Loading branch information
1 parent
048ca39
commit 17ef236
Showing
20 changed files
with
1,155 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .config import ImplementedMetrics, MetricConfig, create_metric_config | ||
from .factory import get_metric, loss_to_metric |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base import MetricConfig | ||
from .enum import ImplementedMetrics | ||
from .factory import create_metric_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Union | ||
|
||
from pydantic import ( | ||
BaseModel, | ||
ConfigDict, | ||
computed_field, | ||
field_validator, | ||
) | ||
|
||
from clinicadl.utils.factories import DefaultFromLibrary | ||
|
||
from .enum import Reduction | ||
|
||
|
||
class MetricConfig(BaseModel, ABC): | ||
"""Base config class to configure metrics.""" | ||
|
||
reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES | ||
get_not_nans: bool = False | ||
include_background: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES | ||
# pydantic config | ||
model_config = ConfigDict( | ||
validate_assignment=True, | ||
use_enum_values=True, | ||
validate_default=True, | ||
) | ||
|
||
@computed_field | ||
@property | ||
@abstractmethod | ||
def metric(self) -> str: | ||
"""The name of the metric.""" | ||
|
||
@field_validator("get_not_nans", mode="after") | ||
@classmethod | ||
def validator_get_not_nans(cls, v): | ||
assert not v, "get_not_nans not supported in ClinicaDL. Please set to False." | ||
|
||
return v |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Type, Union | ||
|
||
from pydantic import computed_field | ||
|
||
from clinicadl.utils.factories import DefaultFromLibrary | ||
|
||
from .base import MetricConfig | ||
from .enum import Average, ConfusionMatrixMetric | ||
|
||
__all__ = [ | ||
"ROCAUCConfig", | ||
"create_confusion_matrix_config", | ||
] | ||
|
||
|
||
# TODO : AP is missing | ||
class ROCAUCConfig(MetricConfig): | ||
"Config class for ROC AUC." | ||
|
||
average: Union[Average, DefaultFromLibrary] = DefaultFromLibrary.YES | ||
|
||
@computed_field | ||
@property | ||
def metric(self) -> str: | ||
"""The name of the metric.""" | ||
return "ROCAUCMetric" | ||
|
||
|
||
class ConfusionMatrixMetricConfig(MetricConfig, ABC): | ||
"Config class for metrics derived from the confusion matrix." | ||
|
||
compute_sample: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES | ||
|
||
@computed_field | ||
@property | ||
def metric(self) -> str: | ||
"""The name of the metric.""" | ||
return "ConfusionMatrixMetric" | ||
|
||
@computed_field | ||
@property | ||
@abstractmethod | ||
def metric_name(self) -> str: | ||
"""The name of the metric computed from the confusion matrix.""" | ||
|
||
|
||
def create_confusion_matrix_config( | ||
metric_name: ConfusionMatrixMetric, | ||
) -> Type[ConfusionMatrixMetricConfig]: | ||
""" | ||
Builds a config class for a specific metric computed from the confusion matrix." | ||
Parameters | ||
---------- | ||
metric_name : ConfusionMatrixMetric | ||
The metric name (e.g. 'f1 score', 'accuracy', etc.). | ||
Returns | ||
------- | ||
Type[ConfusionMatrixMetricConfig] | ||
The config class. | ||
""" | ||
|
||
class ConfusionMatrixMetricSubConfig(ConfusionMatrixMetricConfig): | ||
"A sub config class for a specific metric computed from the confusion matrix." | ||
|
||
@computed_field | ||
@property | ||
def metric_name(self) -> str: | ||
"""The name of the metric computed from the confusion matrix.""" | ||
return metric_name | ||
|
||
return ConfusionMatrixMetricSubConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from enum import Enum | ||
|
||
|
||
class ImplementedMetrics(str, Enum): | ||
"""Implemented metrics in ClinicaDL.""" | ||
|
||
LOSS = "Loss" | ||
|
||
RECALL = "Recall" | ||
SPECIFICITY = "Specificity" | ||
PRECISION = "Precision" | ||
NPV = "Negative Predictive Value" | ||
F1 = "F1 score" | ||
BALANCED_ACC = "Balanced Accuracy" | ||
ACC = "Accuracy" | ||
MARKEDNESS = "Markedness" | ||
MCC = "Matthews Correlation Coefficient" | ||
ROC_AUC = "ROCAUC" | ||
|
||
MSE = "MSE" | ||
MAE = "MAE" | ||
RMSE = "RMSE" | ||
PSNR = "PSNR" | ||
SSIM = "SSIM" | ||
MS_SSIM = "Multi-scale SSIM" | ||
|
||
DICE = "Dice" | ||
GENERALIZED_DICE = "Generalized Dice" | ||
IOU = "IoU" | ||
SURF_DIST = "Surface distance" | ||
HAUSDORFF = "Hausdorff distance" | ||
SURF_DICE = "Surface Dice" | ||
|
||
MMD = "MMD" | ||
|
||
@classmethod | ||
def _missing_(cls, value): | ||
raise ValueError( | ||
f"{value} is not implemented. Implemented metrics are: " | ||
+ ", ".join([repr(m.value) for m in cls]) | ||
) | ||
|
||
|
||
class Reduction(str, Enum): | ||
"""Supported reduction for the metrics.""" | ||
|
||
MEAN = "mean" | ||
SUM = "sum" | ||
|
||
|
||
class GeneralizedDiceScoreReduction(str, Enum): | ||
"""Supported reduction for GeneralizedDiceScore.""" | ||
|
||
MEAN = "mean_batch" | ||
SUM = "sum_batch" | ||
|
||
|
||
class Average(str, Enum): | ||
"""Supported averaging method for ROCAUCMetric.""" | ||
|
||
MACRO = "macro" | ||
WEIGHTED = "weighted" | ||
MICRO = "micro" | ||
|
||
|
||
class ConfusionMatrixMetric(str, Enum): | ||
"""Supported metrics related to confusion matrix (in the format accepted by MONAI).""" | ||
|
||
RECALL = "recall" | ||
SPECIFICITY = "specificity" | ||
PRECISION = "precision" | ||
NPV = "negative predictive value" | ||
F1 = "f1 score" | ||
BALANCED_ACC = "balanced accuracy" | ||
ACC = "accuracy" | ||
MARKEDNESS = "markedness" | ||
MCC = "matthews correlation coefficient" | ||
|
||
|
||
class DistanceMetric(str, Enum): | ||
"Supported distances." | ||
|
||
L2 = "euclidean" | ||
L1 = "taxicab" | ||
LINF = "chessboard" | ||
|
||
|
||
class Kernel(str, Enum): | ||
"Supported kernel for SSIMMetric." | ||
|
||
GAUSSIAN = "gaussian" | ||
UNIFORM = "uniform" | ||
|
||
|
||
class WeightType(str, Enum): | ||
"Supported weight types for GeneralizedDiceScore." | ||
|
||
SQUARE = "square" | ||
SIMPLE = "simple" | ||
UNIFORM = "uniform" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from typing import Type, Union | ||
|
||
from .base import MetricConfig | ||
from .classification import * | ||
from .enum import ConfusionMatrixMetric, ImplementedMetrics | ||
from .generation import * | ||
from .reconstruction import * | ||
from .regression import * | ||
from .segmentation import * | ||
|
||
|
||
def create_metric_config( | ||
metric: Union[str, ImplementedMetrics], | ||
) -> Type[MetricConfig]: | ||
""" | ||
A factory function to create a config class suited for the metric. | ||
Parameters | ||
---------- | ||
metric : Union[str, ImplementedMetrics] | ||
The name of the metric. | ||
Returns | ||
------- | ||
Type[MetricConfig] | ||
The config class. | ||
Raises | ||
------ | ||
ValueError | ||
When `metric`does not correspond to any supported metric. | ||
ValueError | ||
When `metric` is `Loss`. | ||
""" | ||
metric = ImplementedMetrics(metric) | ||
if metric == ImplementedMetrics.LOSS: | ||
raise ValueError( | ||
"To use the loss as a metric, please use directly clinicadl.metrics.loss_to_metric." | ||
) | ||
|
||
# special cases | ||
if metric == ImplementedMetrics.MS_SSIM: | ||
return MultiScaleSSIMConfig | ||
if metric == ImplementedMetrics.MMD: | ||
return MMDMetricConfig | ||
|
||
try: | ||
metric = ConfusionMatrixMetric(metric.lower()) | ||
return create_confusion_matrix_config(metric) | ||
except ValueError: | ||
pass | ||
|
||
# "normal" cases: | ||
try: | ||
config = _get_config(metric) | ||
except KeyError: | ||
config = _get_config(metric.title().replace(" ", "")) | ||
|
||
return config | ||
|
||
|
||
def _get_config(name: str) -> Type[MetricConfig]: | ||
""" | ||
Tries to get a config class associated to the name. | ||
Parameters | ||
---------- | ||
name : str | ||
The name of the metric. | ||
Returns | ||
------- | ||
Type[MetricConfig] | ||
The config class. | ||
""" | ||
config_name = "".join([name, "Config"]) | ||
return globals()[config_name] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from pydantic import PositiveFloat, computed_field | ||
|
||
from .base import MetricConfig | ||
|
||
__all__ = ["MMDMetricConfig"] | ||
|
||
|
||
class MMDMetricConfig(MetricConfig): | ||
"Config class for MMD metric." | ||
|
||
kernel_bandwidth: PositiveFloat = 1.0 | ||
|
||
@computed_field | ||
@property | ||
def metric(self) -> str: | ||
"""The name of the metric.""" | ||
return "MMDMetric" |
Oops, something went wrong.