diff --git a/src/eva/core/models/modules/module.py b/src/eva/core/models/modules/module.py index 10eddeae..c46688a6 100644 --- a/src/eva/core/models/modules/module.py +++ b/src/eva/core/models/modules/module.py @@ -50,8 +50,13 @@ def default_postprocess(self) -> batch_postprocess.BatchPostProcess: @property def metrics_device(self) -> torch.device: """Returns the device by which the metrics should be calculated.""" - device = os.getenv("METRICS_DEVICE", "cpu") - return self.device if device is None else torch.device(device) + device = os.getenv("METRICS_DEVICE", None) + if device is not None: + return torch.device(device) + elif self.device == torch.device("mps"): + # mps seems to have compatibility issues with segmentation metrics + return torch.device("cpu") + return self.device @override def on_fit_start(self) -> None: