diff --git a/clinicadl/monai_metrics/__init__.py b/clinicadl/monai_metrics/__init__.py
new file mode 100644
index 000000000..abfdac61a
--- /dev/null
+++ b/clinicadl/monai_metrics/__init__.py
@@ -0,0 +1,2 @@
+from .config import ImplementedMetrics, MetricConfig, create_metric_config
+from .factory import get_metric, loss_to_metric
diff --git a/clinicadl/monai_metrics/config/__init__.py b/clinicadl/monai_metrics/config/__init__.py
new file mode 100644
index 000000000..8e2303238
--- /dev/null
+++ b/clinicadl/monai_metrics/config/__init__.py
@@ -0,0 +1,3 @@
+from .base import MetricConfig
+from .enum import ImplementedMetrics
+from .factory import create_metric_config
diff --git a/clinicadl/monai_metrics/config/base.py b/clinicadl/monai_metrics/config/base.py
new file mode 100644
index 000000000..47c537663
--- /dev/null
+++ b/clinicadl/monai_metrics/config/base.py
@@ -0,0 +1,40 @@
+from abc import ABC, abstractmethod
+from typing import Union
+
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    computed_field,
+    field_validator,
+)
+
+from clinicadl.utils.factories import DefaultFromLibrary
+
+from .enum import Reduction
+
+
+class MetricConfig(BaseModel, ABC):
+    """Base config class to configure metrics."""
+
+    reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES
+    get_not_nans: bool = False
+    include_background: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
+    # pydantic config
+    model_config = ConfigDict(
+        validate_assignment=True,
+        use_enum_values=True,
+        validate_default=True,
+    )
+
+    @computed_field
+    @property
+    @abstractmethod
+    def metric(self) -> str:
+        """The name of the metric."""
+
+    @field_validator("get_not_nans", mode="after")
+    @classmethod
+    def validator_get_not_nans(cls, v):
+        assert not v, "get_not_nans not supported in ClinicaDL. Please set to False."
+
+        return v
diff --git a/clinicadl/monai_metrics/config/classification.py b/clinicadl/monai_metrics/config/classification.py
new file mode 100644
index 000000000..8a6cb75cc
--- /dev/null
+++ b/clinicadl/monai_metrics/config/classification.py
@@ -0,0 +1,74 @@
+from abc import ABC, abstractmethod
+from typing import Type, Union
+
+from pydantic import computed_field
+
+from clinicadl.utils.factories import DefaultFromLibrary
+
+from .base import MetricConfig
+from .enum import Average, ConfusionMatrixMetric
+
+__all__ = [
+    "ROCAUCConfig",
+    "create_confusion_matrix_config",
+]
+
+
+# TODO : AP is missing
+class ROCAUCConfig(MetricConfig):
+    "Config class for ROC AUC."
+
+    average: Union[Average, DefaultFromLibrary] = DefaultFromLibrary.YES
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "ROCAUCMetric"
+
+
+class ConfusionMatrixMetricConfig(MetricConfig, ABC):
+    "Config class for metrics derived from the confusion matrix."
+
+    compute_sample: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "ConfusionMatrixMetric"
+
+    @computed_field
+    @property
+    @abstractmethod
+    def metric_name(self) -> str:
+        """The name of the metric computed from the confusion matrix."""
+
+
+def create_confusion_matrix_config(
+    metric_name: ConfusionMatrixMetric,
+) -> Type[ConfusionMatrixMetricConfig]:
+    """
+    Builds a config class for a specific metric computed from the confusion matrix.
+
+    Parameters
+    ----------
+    metric_name : ConfusionMatrixMetric
+        The metric name (e.g. 'f1 score', 'accuracy', etc.).
+
+    Returns
+    -------
+    Type[ConfusionMatrixMetricConfig]
+        The config class.
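+
+    Examples
+    --------
+    A sketch of the intended use (the printed value is an expectation based on
+    the enum values defined in enum.py, not a tested output):
+
+    >>> config_class = create_confusion_matrix_config("f1 score")
+    >>> config = config_class(reduction="mean")
+    >>> config.metric_name
+    'f1 score'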
+    """
+
+    class ConfusionMatrixMetricSubConfig(ConfusionMatrixMetricConfig):
+        "A sub config class for a specific metric computed from the confusion matrix."
+
+        @computed_field
+        @property
+        def metric_name(self) -> str:
+            """The name of the metric computed from the confusion matrix."""
+            return metric_name
+
+    return ConfusionMatrixMetricSubConfig
diff --git a/clinicadl/monai_metrics/config/enum.py b/clinicadl/monai_metrics/config/enum.py
new file mode 100644
index 000000000..c0153ad99
--- /dev/null
+++ b/clinicadl/monai_metrics/config/enum.py
@@ -0,0 +1,100 @@
+from enum import Enum
+
+
+class ImplementedMetrics(str, Enum):
+    """Implemented metrics in ClinicaDL."""
+
+    LOSS = "Loss"
+
+    RECALL = "Recall"
+    SPECIFICITY = "Specificity"
+    PRECISION = "Precision"
+    NPV = "Negative Predictive Value"
+    F1 = "F1 score"
+    BALANCED_ACC = "Balanced Accuracy"
+    ACC = "Accuracy"
+    MARKEDNESS = "Markedness"
+    MCC = "Matthews Correlation Coefficient"
+    ROC_AUC = "ROCAUC"
+
+    MSE = "MSE"
+    MAE = "MAE"
+    RMSE = "RMSE"
+    PSNR = "PSNR"
+    SSIM = "SSIM"
+    MS_SSIM = "Multi-scale SSIM"
+
+    DICE = "Dice"
+    GENERALIZED_DICE = "Generalized Dice"
+    IOU = "IoU"
+    SURF_DIST = "Surface distance"
+    HAUSDORFF = "Hausdorff distance"
+    SURF_DICE = "Surface Dice"
+
+    MMD = "MMD"
+
+    @classmethod
+    def _missing_(cls, value):
+        raise ValueError(
+            f"{value} is not implemented. Implemented metrics are: "
+            + ", ".join([repr(m.value) for m in cls])
+        )
+
+
+class Reduction(str, Enum):
+    """Supported reduction for the metrics."""
+
+    MEAN = "mean"
+    SUM = "sum"
+
+
+class GeneralizedDiceScoreReduction(str, Enum):
+    """Supported reduction for GeneralizedDiceScore."""
+
+    MEAN = "mean_batch"
+    SUM = "sum_batch"
+
+
+class Average(str, Enum):
+    """Supported averaging method for ROCAUCMetric."""
+
+    MACRO = "macro"
+    WEIGHTED = "weighted"
+    MICRO = "micro"
+
+
+class ConfusionMatrixMetric(str, Enum):
+    """Supported metrics related to confusion matrix (in the format accepted by MONAI)."""
+
+    RECALL = "recall"
+    SPECIFICITY = "specificity"
+    PRECISION = "precision"
+    NPV = "negative predictive value"
+    F1 = "f1 score"
+    BALANCED_ACC = "balanced accuracy"
+    ACC = "accuracy"
+    MARKEDNESS = "markedness"
+    MCC = "matthews correlation coefficient"
+
+
+class DistanceMetric(str, Enum):
+    "Supported distances."
+
+    L2 = "euclidean"
+    L1 = "taxicab"
+    LINF = "chessboard"
+
+
+class Kernel(str, Enum):
+    "Supported kernel for SSIMMetric."
+
+    GAUSSIAN = "gaussian"
+    UNIFORM = "uniform"
+
+
+class WeightType(str, Enum):
+    "Supported weight types for GeneralizedDiceScore."
+
+    SQUARE = "square"
+    SIMPLE = "simple"
+    UNIFORM = "uniform"
diff --git a/clinicadl/monai_metrics/config/factory.py b/clinicadl/monai_metrics/config/factory.py
new file mode 100644
index 000000000..235e40d0b
--- /dev/null
+++ b/clinicadl/monai_metrics/config/factory.py
@@ -0,0 +1,77 @@
+from typing import Type, Union
+
+from .base import MetricConfig
+from .classification import *
+from .enum import ConfusionMatrixMetric, ImplementedMetrics
+from .generation import *
+from .reconstruction import *
+from .regression import *
+from .segmentation import *
+
+
+def create_metric_config(
+    metric: Union[str, ImplementedMetrics],
+) -> Type[MetricConfig]:
+    """
+    A factory function to create a config class suited for the metric.
+
+    Parameters
+    ----------
+    metric : Union[str, ImplementedMetrics]
+        The name of the metric.
+
+    Returns
+    -------
+    Type[MetricConfig]
+        The config class.
+
+    Raises
+    ------
+    ValueError
+        When `metric` does not correspond to any supported metric.
+    ValueError
+        When `metric` is `Loss`.
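+
+    Examples
+    --------
+    A sketch of the expected lookups (class names as defined in this module's
+    imports; the second class is the one built dynamically in classification.py):
+
+    >>> create_metric_config("PSNR").__name__
+    'PSNRConfig'
+    >>> create_metric_config("F1 score").__name__
+    'ConfusionMatrixMetricSubConfig'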
+    """
+    metric = ImplementedMetrics(metric)
+    if metric == ImplementedMetrics.LOSS:
+        raise ValueError(
+            "To use the loss as a metric, please directly use clinicadl.monai_metrics.loss_to_metric."
+        )
+
+    # special cases
+    if metric == ImplementedMetrics.MS_SSIM:
+        return MultiScaleSSIMConfig
+    if metric == ImplementedMetrics.MMD:
+        return MMDMetricConfig
+
+    try:
+        metric = ConfusionMatrixMetric(metric.lower())
+        return create_confusion_matrix_config(metric)
+    except ValueError:
+        pass
+
+    # "normal" cases:
+    try:
+        config = _get_config(metric)
+    except KeyError:
+        config = _get_config(metric.title().replace(" ", ""))
+
+    return config
+
+
+def _get_config(name: str) -> Type[MetricConfig]:
+    """
+    Tries to get a config class associated with the name.
+
+    Parameters
+    ----------
+    name : str
+        The name of the metric.
+
+    Returns
+    -------
+    Type[MetricConfig]
+        The config class.
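+
+    Examples
+    --------
+    A sketch of the naming convention ('<name>Config' is looked up in this
+    module's namespace, populated by the star imports above):
+
+    >>> _get_config("MSE").__name__
+    'MSEConfig'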
+    """
+    config_name = "".join([name, "Config"])
+    return globals()[config_name]
diff --git a/clinicadl/monai_metrics/config/generation.py b/clinicadl/monai_metrics/config/generation.py
new file mode 100644
index 000000000..d3767bbe1
--- /dev/null
+++ b/clinicadl/monai_metrics/config/generation.py
@@ -0,0 +1,17 @@
+from pydantic import PositiveFloat, computed_field
+
+from .base import MetricConfig
+
+__all__ = ["MMDMetricConfig"]
+
+
+class MMDMetricConfig(MetricConfig):
+    "Config class for MMD metric."
+
+    kernel_bandwidth: PositiveFloat = 1.0
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "MMDMetric"
diff --git a/clinicadl/monai_metrics/config/reconstruction.py b/clinicadl/monai_metrics/config/reconstruction.py
new file mode 100644
index 000000000..1e46e0976
--- /dev/null
+++ b/clinicadl/monai_metrics/config/reconstruction.py
@@ -0,0 +1,92 @@
+from typing import Tuple, Union
+
+from pydantic import (
+    NonNegativeFloat,
+    PositiveFloat,
+    PositiveInt,
+    computed_field,
+    field_validator,
+    model_validator,
+)
+
+from clinicadl.utils.factories import DefaultFromLibrary
+
+from .base import MetricConfig
+from .enum import Kernel
+
+__all__ = [
+    "PSNRConfig",
+    "SSIMConfig",
+    "MultiScaleSSIMConfig",
+]
+
+
+class PSNRConfig(MetricConfig):
+    "Config class for PSNR."
+
+    max_val: PositiveFloat
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "PSNRMetric"
+
+
+class SSIMConfig(MetricConfig):
+    "Config class for SSIM."
+
+    spatial_dims: PositiveInt
+    data_range: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
+    kernel_type: Union[Kernel, DefaultFromLibrary] = DefaultFromLibrary.YES
+    win_size: Union[
+        PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+    kernel_sigma: Union[
+        PositiveFloat, Tuple[PositiveFloat, ...], DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+    k1: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
+    k2: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "SSIMMetric"
+
+    @field_validator("spatial_dims", mode="after")
+    @classmethod
+    def validator_spatial_dims(cls, v):
+        assert v == 2 or v == 3, f"spatial_dims must be 2 or 3. You passed: {v}."
+
+        return v
+
+    @model_validator(mode="after")
+    def dimension_validator(self):
+        """Checks coherence between fields."""
+        self._check_spatial_dim("win_size")
+        self._check_spatial_dim("kernel_sigma")
+
+        return self
+
+    def _check_spatial_dim(self, attribute: str) -> None:
+        """Checks that the dimensionality of an attribute is consistent with self.spatial_dims."""
+        value = getattr(self, attribute)
+        if isinstance(value, tuple):
+            assert (
+                len(value) == self.spatial_dims
+            ), f"If you pass a sequence for {attribute}, it must be of size {self.spatial_dims}. You passed: {value}."
+
+
+class MultiScaleSSIMConfig(SSIMConfig):
+    "Config class for multi-scale SSIM."
+
+    weights: Union[
+        Tuple[PositiveFloat, ...], DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "MultiScaleSSIMMetric"
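+
+
+# A usage sketch of the dimension check performed by `_check_spatial_dim`
+# (the AssertionError is surfaced by pydantic as a ValidationError):
+#
+#     SSIMConfig(spatial_dims=2, kernel_sigma=(1.5, 1.5))       # consistent -> OK
+#     SSIMConfig(spatial_dims=2, kernel_sigma=(1.5, 1.5, 1.5))  # raises ValidationError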
diff --git a/clinicadl/monai_metrics/config/regression.py b/clinicadl/monai_metrics/config/regression.py
new file mode 100644
index 000000000..039c3e888
--- /dev/null
+++ b/clinicadl/monai_metrics/config/regression.py
@@ -0,0 +1,40 @@
+from pydantic import computed_field
+
+from .base import MetricConfig
+
+__all__ = [
+    "MSEConfig",
+    "MAEConfig",
+    "RMSEConfig",
+]
+
+
+# TODO : R2 missing
+class MSEConfig(MetricConfig):
+    "Config class for MSE."
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "MSEMetric"
+
+
+class MAEConfig(MetricConfig):
+    "Config class for MAE."
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "MAEMetric"
+
+
+class RMSEConfig(MetricConfig):
+    "Config class for RMSE."
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "RMSEMetric"
diff --git a/clinicadl/monai_metrics/config/segmentation.py b/clinicadl/monai_metrics/config/segmentation.py
new file mode 100644
index 000000000..90acf7cab
--- /dev/null
+++ b/clinicadl/monai_metrics/config/segmentation.py
@@ -0,0 +1,123 @@
+from typing import Optional, Tuple, Union
+
+from pydantic import NonNegativeFloat, PositiveInt, computed_field, field_validator
+
+from clinicadl.utils.factories import DefaultFromLibrary
+
+from .base import MetricConfig
+from .enum import DistanceMetric, GeneralizedDiceScoreReduction, WeightType
+
+__all__ = [
+    "DiceConfig",
+    "IoUConfig",
+    "GeneralizedDiceConfig",
+    "SurfaceDistanceConfig",
+    "HausdorffDistanceConfig",
+    "SurfaceDiceConfig",
+]
+
+
+class SegmentationMetricConfig(MetricConfig):
+    """Base config class for segmentation metrics."""
+
+    ignore_empty: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
+
+
+class DiceConfig(SegmentationMetricConfig):
+    """Config class for Dice score."""
+
+    num_classes: Union[
+        Optional[PositiveInt], DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+    return_with_label: bool = False
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "DiceMetric"
+
+    @field_validator("return_with_label", mode="after")
+    @classmethod
+    def validator_return_with_label(cls, v):
+        assert (
+            not v
+        ), "return_with_label not supported in ClinicaDL. Please set to False."
+
+        return v
+
+
+class IoUConfig(SegmentationMetricConfig):
+    """Config class for IoU metric."""
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "MeanIoU"
+
+
+class GeneralizedDiceConfig(MetricConfig):
+    """Config class for generalized Dice score."""
+
+    reduction: Union[
+        GeneralizedDiceScoreReduction, DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+    weight_type: Union[WeightType, DefaultFromLibrary] = DefaultFromLibrary.YES
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "GeneralizedDiceScore"
+
+
+class SurfaceDistanceConfig(MetricConfig):
+    """Config class for Surface Distance metric."""
+
+    distance_metric: Union[DistanceMetric, DefaultFromLibrary] = DefaultFromLibrary.YES
+    symmetric: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "SurfaceDistanceMetric"
+
+
+class HausdorffDistanceConfig(SurfaceDistanceConfig):
+    """Config class for Hausdorff distance."""
+
+    percentile: Union[
+        Optional[NonNegativeFloat], DefaultFromLibrary
+    ] = DefaultFromLibrary.YES
+    directed: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "HausdorffDistanceMetric"
+
+    @field_validator("percentile", mode="after")
+    @classmethod
+    def validator_percentile(cls, v):
+        if isinstance(v, float):
+            assert (
+                0 <= v <= 100
+            ), f"percentile must be between 0 and 100. You passed: {v}."
+
+        return v
+
+
+class SurfaceDiceConfig(SurfaceDistanceConfig):
+    """Config class for (normalized) surface Dice score."""
+
+    class_thresholds: Tuple[NonNegativeFloat, ...]
+    use_subvoxels: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
+
+    @computed_field
+    @property
+    def metric(self) -> str:
+        """The name of the metric."""
+        return "SurfaceDiceMetric"
diff --git a/clinicadl/monai_metrics/factory.py b/clinicadl/monai_metrics/factory.py
new file mode 100644
index 000000000..a5c8d5eda
--- /dev/null
+++ b/clinicadl/monai_metrics/factory.py
@@ -0,0 +1,77 @@
+from typing import Optional, Tuple, Union
+
+import monai.metrics as metrics
+
+from clinicadl.losses.utils import Loss
+from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults
+
+from .config.base import MetricConfig
+from .config.enum import Reduction
+
+
+def get_metric(config: MetricConfig) -> Tuple[metrics.Metric, MetricConfig]:
+    """
+    Factory function to get a metric from MONAI.
+
+    Parameters
+    ----------
+    config : MetricConfig
+        The config class with the parameters of the metric.
+
+    Returns
+    -------
+    metrics.Metric
+        The Metric object.
+    MetricConfig
+        The updated config class: the arguments set to default will be updated
+        with their effective values (the default values from the library).
+        Useful for reproducibility.
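+
+    Examples
+    --------
+    A sketch, assuming MONAI's own defaults fill the unset arguments:
+
+    >>> from clinicadl.monai_metrics.config import create_metric_config
+    >>> config = create_metric_config("MAE")(reduction="sum")
+    >>> metric, updated_config = get_metric(config)
+    >>> type(metric).__name__
+    'MAEMetric'
+    >>> updated_config.reduction
+    'sum'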
+    """
+    metric_class = getattr(metrics, config.metric)
+    expected_args, config_dict = get_args_and_defaults(metric_class.__init__)
+    for arg, value in config.model_dump().items():
+        if arg in expected_args and value != DefaultFromLibrary.YES:
+            config_dict[arg] = value
+
+    metric = metric_class(**config_dict)
+    updated_config = config.model_copy(update=config_dict)
+
+    return metric, updated_config
+
+
+def loss_to_metric(
+    loss_fn: Loss,
+    reduction: Optional[Union[str, Reduction]] = None,
+) -> metrics.LossMetric:
+    """
+    Converts a loss function to a metric object.
+
+    Parameters
+    ----------
+    loss_fn : Loss
+        A callable function that takes y_pred and optionally y as input (in the "batch-first" format), and returns a 1-item tensor.
+        loss_fn can also be a PyTorch loss object.
+    reduction : Optional[Union[str, Reduction]] (optional, default=None)
+        Defines the mode of reduction. If not passed, the reduction method of the loss function will be used (if it exists).
+
+    Returns
+    -------
+    metrics.LossMetric
+        The loss function wrapped in a metric object.
+
+    Raises
+    ------
+    ValueError
+        If the user didn't pass a reduction method and the loss function doesn't have an attribute 'reduction'.
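+
+    Examples
+    --------
+    A sketch with a standard PyTorch loss, whose 'reduction' attribute is
+    reused when no reduction is passed (behavior mirrored in the unit tests):
+
+    >>> from torch.nn import L1Loss
+    >>> metric = loss_to_metric(L1Loss(reduction="sum"))
+    >>> metric.reduction == "sum"
+    True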
+    """
+    if reduction is None:
+        try:
+            checked_reduction = loss_fn.reduction
+        except AttributeError as exc:
+            raise ValueError(
+                "If the loss function doesn't have an attribute 'reduction', you must pass a reduction method."
+            ) from exc
+    else:
+        checked_reduction = Reduction(reduction)
+
+    return metrics.LossMetric(loss_fn, reduction=checked_reduction)
diff --git a/tests/unittests/monai_metrics/__init__.py b/tests/unittests/monai_metrics/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unittests/monai_metrics/config/__init__.py b/tests/unittests/monai_metrics/config/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unittests/monai_metrics/config/test_classification.py b/tests/unittests/monai_metrics/config/test_classification.py
new file mode 100644
index 000000000..e4192a254
--- /dev/null
+++ b/tests/unittests/monai_metrics/config/test_classification.py
@@ -0,0 +1,60 @@
+import pytest
+from pydantic import ValidationError
+
+from clinicadl.monai_metrics.config.classification import (
+    ROCAUCConfig,
+    create_confusion_matrix_config,
+)
+from clinicadl.monai_metrics.config.enum import ConfusionMatrixMetric
+
+
+# ROCAUC
+def test_fails_validations_rocauc():
+    with pytest.raises(ValidationError):
+        ROCAUCConfig(average="abc")
+
+
+def test_ROCAUCConfig():
+    config = ROCAUCConfig(
+        average="macro",
+    )
+    assert config.metric == "ROCAUCMetric"
+    assert config.average == "macro"
+
+
+# Confusion Matrix
+@pytest.mark.parametrize(
+    "bad_inputs",
+    [
+        {"reduction": "abc"},
+        {"get_not_nans": True},
+    ],
+)
+def test_fails_validations_cmatrix(bad_inputs):
+    for m in ConfusionMatrixMetric:
+        config_class = create_confusion_matrix_config(m.value)
+        with pytest.raises(ValidationError):
+            config_class(**bad_inputs)
+
+
+def test_passes_validations_cmatrix():
+    for m in ConfusionMatrixMetric:
+        config_class = create_confusion_matrix_config(m.value)
+        config_class(
+            reduction="mean",
+            get_not_nans=False,
+            compute_sample=False,
+        )
+
+
+def test_ConfusionMatrixMetricConfig():
+    for m in ConfusionMatrixMetric:
+        config_class = create_confusion_matrix_config(m.value)
+        config = config_class(
+            reduction="sum",
+        )
+        assert config.metric == "ConfusionMatrixMetric"
+        assert config.reduction == "sum"
+        assert config.metric_name == m.value
+        assert config.include_background == "DefaultFromLibrary"
+        assert not config.get_not_nans
diff --git a/tests/unittests/monai_metrics/config/test_factory.py b/tests/unittests/monai_metrics/config/test_factory.py
new file mode 100644
index 000000000..5eed6f459
--- /dev/null
+++ b/tests/unittests/monai_metrics/config/test_factory.py
@@ -0,0 +1,37 @@
+import pytest
+
+from clinicadl.monai_metrics.config import ImplementedMetrics, create_metric_config
+
+
+def test_create_metric_config():
+    for metric in [e.value for e in ImplementedMetrics]:
+        if metric == "Loss":
+            with pytest.raises(ValueError):
+                create_metric_config(metric)
+        else:
+            create_metric_config(metric)
+
+    config_class = create_metric_config("Hausdorff distance")
+    config = config_class(
+        include_background=True,
+        distance_metric="taxicab",
+        reduction="sum",
+        percentile=50,
+    )
+    assert config.metric == "HausdorffDistanceMetric"
+    assert config.include_background
+    assert config.distance_metric == "taxicab"
+    assert config.reduction == "sum"
+    assert config.percentile == 50
+    assert config.directed == "DefaultFromLibrary"
+    assert not config.get_not_nans
+
+    config_class = create_metric_config("F1 score")
+    config = config_class(
+        include_background=True,
+        compute_sample=True,
+    )
+    assert config.metric == "ConfusionMatrixMetric"
+    assert config.include_background
+    assert config.compute_sample
+    assert config.metric_name == "f1 score"
diff --git a/tests/unittests/monai_metrics/config/test_generation.py b/tests/unittests/monai_metrics/config/test_generation.py
new file mode 100644
index 000000000..4e1691567
--- /dev/null
+++ b/tests/unittests/monai_metrics/config/test_generation.py
@@ -0,0 +1,17 @@
+import pytest
+from pydantic import ValidationError
+
+from clinicadl.monai_metrics.config.generation import MMDMetricConfig
+
+
+def test_fails_validation():
+    with pytest.raises(ValidationError):
+        MMDMetricConfig(kernel_bandwidth=0)
+
+
+def test_MMDMetricConfig():
+    config = MMDMetricConfig(
+        kernel_bandwidth=2.0,
+    )
+    assert config.metric == "MMDMetric"
+    assert config.kernel_bandwidth == 2.0
diff --git a/tests/unittests/monai_metrics/config/test_reconstruction.py b/tests/unittests/monai_metrics/config/test_reconstruction.py
new file mode 100644
index 000000000..521c2717e
--- /dev/null
+++ b/tests/unittests/monai_metrics/config/test_reconstruction.py
@@ -0,0 +1,121 @@
+import pytest
+from pydantic import ValidationError
+
+from clinicadl.monai_metrics.config.reconstruction import (
+    MultiScaleSSIMConfig,
+    PSNRConfig,
+    SSIMConfig,
+)
+
+
+# PSNR #
+@pytest.mark.parametrize(
+    "bad_inputs",
+    [
+        {"max_val": 255, "reduction": "abc"},
+        {"max_val": 255, "get_not_nans": True},
+        {"max_val": 0},
+    ],
+)
+def test_fails_validation_psnr(bad_inputs):
+    with pytest.raises(ValidationError):
+        PSNRConfig(**bad_inputs)
+
+
+@pytest.mark.parametrize(
+    "good_inputs",
+    [
+        {"max_val": 255, "reduction": "sum"},
+        {"max_val": 255, "reduction": "mean"},
+        {"max_val": 255, "get_not_nans": False},
+    ],
+)
+def test_passes_validations_psnr(good_inputs):
+    PSNRConfig(**good_inputs)
+
+
+def test_PSNRConfig():
+    config = PSNRConfig(
+        max_val=7,
+        reduction="sum",
+    )
+    assert config.metric == "PSNRMetric"
+    assert config.max_val == 7
+    assert config.reduction == "sum"
+    assert not config.get_not_nans
+
+
+# SSIM #
+@pytest.mark.parametrize(
+    "bad_inputs",
+    [
+        {"spatial_dims": 1},
+        {"spatial_dims": 2, "data_range": 0},
+        {"spatial_dims": 2, "kernel_type": "abc"},
+        {"spatial_dims": 2, "win_size": 0},
+        {"spatial_dims": 2, "win_size": (1, 2, 3)},
+        {"spatial_dims": 2, "kernel_sigma": 0},
+        {"spatial_dims": 2, "kernel_sigma": (1.0, 2.0, 3.0)},
+        {"spatial_dims": 2, "k1": -1.0},
+        {"spatial_dims": 2, "k2": -0.01},
+    ],
+)
+def test_fails_validations(bad_inputs):
+    with pytest.raises(ValidationError):
+        SSIMConfig(**bad_inputs)
+    with pytest.raises(ValidationError):
+        MultiScaleSSIMConfig(**bad_inputs)
+
+
+def test_fails_validation_msssim():
+    with pytest.raises(ValidationError):
+        MultiScaleSSIMConfig(spatial_dims=2, weights=(0.0, 1.0))
+    with pytest.raises(ValidationError):
+        MultiScaleSSIMConfig(spatial_dims=2, weights=1.0)
+
+
+@pytest.mark.parametrize(
+    "good_inputs",
+    [
+        {
+            "spatial_dims": 2,
+            "data_range": 1,
+            "kernel_type": "gaussian",
+            "win_size": 10,
+            "kernel_sigma": 1.0,
+            "k1": 1.0,
+            "k2": 1.0,
+            "weights": [1.0, 2.0],
+        },
+        {"spatial_dims": 2, "win_size": (1, 2), "kernel_sigma": (1.0, 2.0)},
+    ],
+)
+def test_passes_validations(good_inputs):
+    MultiScaleSSIMConfig(**good_inputs)
+    SSIMConfig(**good_inputs)
+
+
+def test_SSIMConfig():
+    config = SSIMConfig(
+        spatial_dims=2,
+        reduction="sum",
+        k1=1.0,
+    )
+    assert config.metric == "SSIMMetric"
+    assert config.reduction == "sum"
+    assert config.spatial_dims == 2
+    assert config.k1 == 1.0
+    assert config.k2 == "DefaultFromLibrary"
+
+
+def test_MultiScaleSSIMConfig():
+    config = MultiScaleSSIMConfig(
+        spatial_dims=2, reduction="sum", k1=1.0, weights=[1.0], win_size=10
+    )
+    assert config.metric == "MultiScaleSSIMMetric"
+    assert config.reduction == "sum"
+    assert config.spatial_dims == 2
+    assert config.win_size == 10
+    assert config.k1 == 1.0
+    assert config.k2 == "DefaultFromLibrary"
+    assert config.weights == (1.0,)
diff --git a/tests/unittests/monai_metrics/config/test_regression.py b/tests/unittests/monai_metrics/config/test_regression.py
new file mode 100644
index 000000000..f95f20b6a
--- /dev/null
+++ b/tests/unittests/monai_metrics/config/test_regression.py
@@ -0,0 +1,65 @@
+import pytest
+from pydantic import ValidationError
+
+from clinicadl.monai_metrics.config.regression import (
+    MAEConfig,
+    MSEConfig,
+    RMSEConfig,
+)
+
+
+@pytest.mark.parametrize(
+    "bad_inputs",
+    [
+        {"reduction": "abc"},
+        {"get_not_nans": True},
+    ],
+)
+def test_fails_validations(bad_inputs):
+    with pytest.raises(ValidationError):
+        MAEConfig(**bad_inputs)
+    with pytest.raises(ValidationError):
+        MSEConfig(**bad_inputs)
+    with pytest.raises(ValidationError):
+        RMSEConfig(**bad_inputs)
+
+
+@pytest.mark.parametrize(
+    "good_inputs",
+    [
+        {"reduction": "sum"},
+        {"reduction": "mean"},
+        {"get_not_nans": False},
+    ],
+)
+def test_passes_validations(good_inputs):
+    MAEConfig(**good_inputs)
+    MSEConfig(**good_inputs)
+    RMSEConfig(**good_inputs)
+
+
+def test_MAEConfig():
+    config = MAEConfig(
+        reduction="sum",
+    )
+    assert config.metric == "MAEMetric"
+    assert config.reduction == "sum"
+    assert not config.get_not_nans
+
+
+def test_MSEConfig():
+    config = MSEConfig(
+        reduction="sum",
+    )
+    assert config.metric == "MSEMetric"
+    assert config.reduction == "sum"
+    assert not config.get_not_nans
+
+
+def test_RMSEConfig():
+    config = RMSEConfig(
+        reduction="sum",
+    )
+    assert config.metric == "RMSEMetric"
+    assert config.reduction == "sum"
+    assert not config.get_not_nans
diff --git a/tests/unittests/monai_metrics/config/test_segmentation.py b/tests/unittests/monai_metrics/config/test_segmentation.py
new file mode 100644
index 000000000..537f289c9
--- /dev/null
+++ b/tests/unittests/monai_metrics/config/test_segmentation.py
@@ -0,0 +1,132 @@
+import pytest
+from pydantic import ValidationError
+
+from clinicadl.monai_metrics.config.segmentation import (
+    DiceConfig,
+    GeneralizedDiceConfig,
+    HausdorffDistanceConfig,
+    IoUConfig,
+    SurfaceDiceConfig,
+    SurfaceDistanceConfig,
+)
+
+
+@pytest.mark.parametrize(
+    "bad_inputs",
+    [
+        {"class_thresholds": [0.1], "reduction": "abc"},
+        {"class_thresholds": [0.1], "get_not_nans": True},
+    ],
+)
+def test_fails_validation(bad_inputs):
+    with pytest.raises(ValidationError):
+        DiceConfig(**bad_inputs)
+    with pytest.raises(ValidationError):
+        IoUConfig(**bad_inputs)
+    with pytest.raises(ValidationError):
+        SurfaceDistanceConfig(**bad_inputs)
+
+
+def test_fails_validation_dice():
+    with pytest.raises(ValidationError):
+        DiceConfig(return_with_label=True)
+    with pytest.raises(ValidationError):
+        DiceConfig(num_classes=0)
+
+
+def test_fails_validation_gen_dice():
+    with pytest.raises(ValidationError):
+        GeneralizedDiceConfig(reduction="mean")
+    with pytest.raises(ValidationError):
+        GeneralizedDiceConfig(weight_type="abc")
+
+
+def test_fails_validation_surface_dist():
+    with pytest.raises(ValidationError):
+        SurfaceDistanceConfig(distance_metric="abc")
+
+
+def test_fails_validation_hausdorff():
+    with pytest.raises(ValidationError):
+        HausdorffDistanceConfig(percentile=-1)
+
+
+def test_fails_validation_surface_dice():
+    with pytest.raises(ValidationError):
+        SurfaceDiceConfig(class_thresholds=0.1)
+
+
+def test_DiceConfig():
+    config = DiceConfig(
+        num_classes=3,
+        include_background=False,
+        reduction="mean",
+    )
+    assert config.metric == "DiceMetric"
+    assert config.num_classes == 3
+    assert not config.include_background
+    assert config.reduction == "mean"
+    assert config.ignore_empty == "DefaultFromLibrary"
+    assert not config.get_not_nans
+    assert not config.return_with_label
+
+
+def test_IoUConfig():
+    config = IoUConfig(
+        num_classes=3,
+        include_background=False,
+        reduction="mean",
+    )
+    assert config.metric == "MeanIoU"
+    assert not config.include_background
+    assert config.reduction == "mean"
+    assert config.ignore_empty == "DefaultFromLibrary"
+    assert not config.get_not_nans
+
+
+def test_GeneralizedDiceConfig():
+    config = GeneralizedDiceConfig(
+        weight_type="square",
+        reduction="mean_batch",
+    )
+    assert config.metric == "GeneralizedDiceScore"
+    assert config.weight_type == "square"
+    assert config.include_background == "DefaultFromLibrary"
+    assert config.reduction == "mean_batch"
+
+
+def test_SurfaceDistanceConfig():
+    config = SurfaceDistanceConfig(
+        symmetric=True,
+        distance_metric="taxicab",
+    )
+    assert config.metric == "SurfaceDistanceMetric"
+    assert config.symmetric
+    assert config.distance_metric == "taxicab"
+    assert config.reduction == "DefaultFromLibrary"
+    assert config.include_background == "DefaultFromLibrary"
+
+
+def test_HausdorffDistanceConfig():
+    config = HausdorffDistanceConfig(
+        percentile=50,
+        directed=True,
+    )
+    assert config.metric == "HausdorffDistanceMetric"
+    assert config.percentile == 50
+    assert config.directed
+    assert config.distance_metric == "DefaultFromLibrary"
+    assert config.include_background == "DefaultFromLibrary"
+    assert not config.get_not_nans
+
+
+def test_SurfaceDiceConfig():
+    config = SurfaceDiceConfig(
+        use_subvoxels=True, class_thresholds=[0.1, 100], distance_metric="chessboard"
+    )
+    assert config.metric == "SurfaceDiceMetric"
+    assert config.class_thresholds == (0.1, 100)
+    assert config.use_subvoxels
+    assert config.distance_metric == "chessboard"
+    assert config.include_background == "DefaultFromLibrary"
+    assert not config.get_not_nans
diff --git a/tests/unittests/monai_metrics/test_factory.py b/tests/unittests/monai_metrics/test_factory.py
new file mode 100644
index 000000000..5d265e416
--- /dev/null
+++ b/tests/unittests/monai_metrics/test_factory.py
@@ -0,0 +1,78 @@
+import pytest
+from torch import Size, Tensor
+from torch.nn import MSELoss
+
+
+def test_get_metric():
+    from monai.metrics import SSIMMetric
+
+    from clinicadl.monai_metrics import get_metric
+    from clinicadl.monai_metrics.config import ImplementedMetrics, create_metric_config
+
+    for metric_name in [e.value for e in ImplementedMetrics if e != "Loss"]:
+        if (
+            metric_name == ImplementedMetrics.SSIM
+            or metric_name == ImplementedMetrics.MS_SSIM
+        ):
+            params = {"spatial_dims": 3, "kernel_sigma": 13.0}
+        elif metric_name == ImplementedMetrics.PSNR:
+            params = {"max_val": 3}
+        elif metric_name == ImplementedMetrics.SURF_DICE:
+            params = {"class_thresholds": [0.1, 0.2]}
+        else:
+            params = {}
+        config = create_metric_config(metric_name)(**params)
+
+        metric, updated_config = get_metric(config)
+
+        if metric_name == "SSIM":
+            assert isinstance(metric, SSIMMetric)
+            assert metric.spatial_dims == 3
+            assert metric.data_range == 1.0
+            assert metric.kernel_type == "gaussian"
+            assert metric.kernel_sigma == (13.0, 13.0, 13.0)
+
+            assert updated_config.metric == "SSIMMetric"
+            assert updated_config.spatial_dims == 3
+            assert updated_config.data_range == 1.0
+            assert updated_config.kernel_type == "gaussian"
+            assert updated_config.kernel_sigma == 13.0
+            assert updated_config.k1 == 0.01
+
+
+@pytest.mark.skip()
+def loss_fn(y_pred: Tensor, y_true: Tensor) -> Tensor:
+    return ((y_pred - y_true) ** 2).sum()
+
+
+@pytest.mark.skip()
+def loss_fn_bis(y_pred: Tensor) -> Tensor:
+    return (y_pred**2).sum()
+
+
+def test_loss_to_metric():
+    from torch import randn
+
+    from clinicadl.monai_metrics import loss_to_metric
+
+    y_pred = randn(10, 5, 5)
+    y_true = randn(10, 5, 5)
+
+    with pytest.raises(ValueError):
+        loss_to_metric(loss_fn)
+
+    metric = loss_to_metric(MSELoss(reduction="sum"), reduction="mean")
+    assert metric.reduction == "mean"
+    assert metric(y_pred, y_true).shape == Size((1, 1))
+
+    metric = loss_to_metric(MSELoss(reduction="sum"))
+    assert metric.reduction == "sum"
+    assert metric(y_pred, y_true).shape == Size((1, 1))
+
+    metric = loss_to_metric(loss_fn, reduction="sum")
+    assert metric.reduction == "sum"
+    assert metric(y_pred, y_true).shape == Size((1, 1))
+
+    metric = loss_to_metric(loss_fn_bis, reduction="sum")
+    assert metric.reduction == "sum"
+    assert metric(y_pred).shape == Size((1, 1))