diff --git a/src/eva/core/models/modules/module.py b/src/eva/core/models/modules/module.py index c46688a6..55f3f279 100644 --- a/src/eva/core/models/modules/module.py +++ b/src/eva/core/models/modules/module.py @@ -53,7 +53,7 @@ def metrics_device(self) -> torch.device: device = os.getenv("METRICS_DEVICE", None) if device is not None: return torch.device(device) - elif self.device == torch.device("mps"): + elif self.device.type == "mps": # mps seems to have compatibility issues with segmentation metrics return torch.device("cpu") return self.device