Skip to content

Commit

Permalink
add metrics module tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ioangatop committed Jan 17, 2024
1 parent bc2c2ec commit 314334a
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 11 deletions.
16 changes: 8 additions & 8 deletions src/eva/metrics/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def __init__(
"""Initializes the metrics for the Trainer.
Args:
train: The training metric collection to be tracked.
val: The validation metric collection to be tracked.
test: The test metric collection to be tracked.
train: The training metric collection.
val: The validation metric collection.
test: The test metric collection.
"""
super().__init__()

Expand All @@ -49,9 +49,9 @@ def from_metrics(
"""Initializes a metric module from a list of metrics.
Args:
train: A list of metrics for the training process.
val: A list of metrics for the val process.
test: A list of metrics for the test process.
train: Metrics for the training stage.
val: Metrics for the validation stage.
test: Metrics for the test stage.
separator: The separator between the group name of the metric
and the metric itself. Defaults to `"/"`.
"""
Expand Down Expand Up @@ -89,7 +89,7 @@ def training_metrics(self) -> collection.MetricCollection:

@property
def validation_metrics(self) -> collection.MetricCollection:
"""Returns the metrics of the val dataset."""
"""Returns the metrics of the validation dataset."""
return self._val

@property
Expand All @@ -104,7 +104,7 @@ def _create_collection_from_metrics(
"""Create a unique collection from metrics.
Args:
metrics: A list of metrics.
metrics: The desired metrics.
prefix: A prefix to added to the collection. Defaults to `None`.
Returns:
Expand Down
33 changes: 33 additions & 0 deletions tests/eva/metrics/core/test_metric_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""MetricModule tests."""
from typing import List

import pytest
import torchmetrics

from eva.metrics import core


@pytest.mark.parametrize(
    "schema, expected",
    [
        (core.MetricsSchema(train=torchmetrics.Dice()), [1, 0, 0]),
        (core.MetricsSchema(evaluation=torchmetrics.Dice()), [0, 1, 1]),
        (core.MetricsSchema(common=torchmetrics.Dice()), [1, 1, 1]),
        (core.MetricsSchema(train=torchmetrics.Dice(), evaluation=torchmetrics.Dice()), [1, 1, 1]),
    ],
)
def test_metric_module(metric_module: core.MetricModule, expected: List[int]) -> None:
    """Tests that MetricModule splits a schema into per-stage collections.

    Args:
        metric_module: Module built from the parametrized ``schema`` via the
            ``metric_module`` fixture below.
        expected: Expected number of metrics in the train / validation / test
            collections, in that order.
    """
    assert len(metric_module.training_metrics) == expected[0]
    assert len(metric_module.validation_metrics) == expected[1]
    assert len(metric_module.test_metrics) == expected[2]
    # The per-stage collections must be distinct objects (copies), otherwise
    # updating the metrics of one stage would leak state into another.
    # Use identity checks: `!=` would rely on MetricCollection's equality
    # semantics instead of asserting the objects are actually different.
    assert metric_module.training_metrics is not metric_module.validation_metrics
    assert metric_module.training_metrics is not metric_module.test_metrics
    assert metric_module.validation_metrics is not metric_module.test_metrics


@pytest.fixture(scope="function")
def metric_module(schema: core.MetricsSchema) -> core.MetricModule:
"""MetricModule fixture."""
return core.MetricModule.from_schema(schema=schema)
24 changes: 21 additions & 3 deletions tests/eva/metrics/core/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,27 @@
@pytest.mark.parametrize(
"common, train, evaluation, expected_train, expected_evaluation",
[
(torchmetrics.Accuracy("binary"), None, None, "BinaryAccuracy()", "BinaryAccuracy()"),
(None, torchmetrics.Accuracy("binary"), None, "BinaryAccuracy()", "None"),
(None, None, torchmetrics.Accuracy("binary"), "None", "BinaryAccuracy()"),
(
torchmetrics.Accuracy("binary"),
None,
None,
"BinaryAccuracy()",
"BinaryAccuracy()",
),
(
None,
torchmetrics.Accuracy("binary"),
None,
"BinaryAccuracy()",
"None",
),
(
None,
None,
torchmetrics.Accuracy("binary"),
"None",
"BinaryAccuracy()",
),
(
torchmetrics.Accuracy("binary"),
torchmetrics.Dice(),
Expand Down

0 comments on commit 314334a

Please sign in to comment.