diff --git a/src/eva/metrics/core/module.py b/src/eva/metrics/core/module.py
index f08207c7..d9c69b6b 100644
--- a/src/eva/metrics/core/module.py
+++ b/src/eva/metrics/core/module.py
@@ -22,9 +22,9 @@ def __init__(
         """Initializes the metrics for the Trainer.
 
         Args:
-            train: The training metric collection to be tracked.
-            val: The validation metric collection to be tracked.
-            test: The test metric collection to be tracked.
+            train: The training metric collection.
+            val: The validation metric collection.
+            test: The test metric collection.
         """
         super().__init__()
 
@@ -49,9 +49,9 @@ def from_metrics(
         """Initializes a metric module from a list of metrics.
 
         Args:
-            train: A list of metrics for the training process.
-            val: A list of metrics for the val process.
-            test: A list of metrics for the test process.
+            train: Metrics for the training stage.
+            val: Metrics for the validation stage.
+            test: Metrics for the test stage.
             separator: The separator between the group name of the metric
                 and the metric itself. Defaults to `"/"`.
         """
@@ -89,7 +89,7 @@ def training_metrics(self) -> collection.MetricCollection:
 
     @property
     def validation_metrics(self) -> collection.MetricCollection:
-        """Returns the metrics of the val dataset."""
+        """Returns the metrics of the validation dataset."""
         return self._val
 
     @property
@@ -104,7 +104,7 @@ def _create_collection_from_metrics(
         """Create a unique collection from metrics.
 
         Args:
-            metrics: A list of metrics.
+            metrics: The desired metrics.
             prefix: A prefix to added to the collection. Defaults to `None`.
 
         Returns:
diff --git a/tests/eva/metrics/core/test_metric_module.py b/tests/eva/metrics/core/test_metric_module.py
new file mode 100644
index 00000000..fd0940a9
--- /dev/null
+++ b/tests/eva/metrics/core/test_metric_module.py
@@ -0,0 +1,33 @@
+"""MetricModule tests."""
+from typing import List
+
+import pytest
+import torchmetrics
+
+from eva.metrics import core
+
+
+@pytest.mark.parametrize(
+    "schema, expected",
+    [
+        (core.MetricsSchema(train=torchmetrics.Dice()), [1, 0, 0]),
+        (core.MetricsSchema(evaluation=torchmetrics.Dice()), [0, 1, 1]),
+        (core.MetricsSchema(common=torchmetrics.Dice()), [1, 1, 1]),
+        (core.MetricsSchema(train=torchmetrics.Dice(), evaluation=torchmetrics.Dice()), [1, 1, 1]),
+    ],
+)
+def test_metric_module(metric_module: core.MetricModule, expected: List[int]) -> None:
+    """Tests the MetricModule."""
+    assert len(metric_module.training_metrics) == expected[0]
+    assert len(metric_module.validation_metrics) == expected[1]
+    assert len(metric_module.test_metrics) == expected[2]
+    # test that the metrics are copied and are not the same object
+    assert metric_module.training_metrics != metric_module.validation_metrics
+    assert metric_module.training_metrics != metric_module.test_metrics
+    assert metric_module.validation_metrics != metric_module.test_metrics
+
+
+@pytest.fixture(scope="function")
+def metric_module(schema: core.MetricsSchema) -> core.MetricModule:
+    """MetricModule fixture."""
+    return core.MetricModule.from_schema(schema=schema)
diff --git a/tests/eva/metrics/core/test_schemas.py b/tests/eva/metrics/core/test_schemas.py
index e67fe913..b20ed0fc 100644
--- a/tests/eva/metrics/core/test_schemas.py
+++ b/tests/eva/metrics/core/test_schemas.py
@@ -9,9 +9,27 @@
 @pytest.mark.parametrize(
     "common, train, evaluation, expected_train, expected_evaluation",
     [
-        (torchmetrics.Accuracy("binary"), None, None, "BinaryAccuracy()", "BinaryAccuracy()"),
-        (None, torchmetrics.Accuracy("binary"), None, "BinaryAccuracy()", "None"),
-        (None, None, torchmetrics.Accuracy("binary"), "None", "BinaryAccuracy()"),
+        (
+            torchmetrics.Accuracy("binary"),
+            None,
+            None,
+            "BinaryAccuracy()",
+            "BinaryAccuracy()",
+        ),
+        (
+            None,
+            torchmetrics.Accuracy("binary"),
+            None,
+            "BinaryAccuracy()",
+            "None",
+        ),
+        (
+            None,
+            None,
+            torchmetrics.Accuracy("binary"),
+            "None",
+            "BinaryAccuracy()",
+        ),
         (
             torchmetrics.Accuracy("binary"),
             torchmetrics.Dice(),