diff --git a/clinicadl/utils/metric_module.py b/clinicadl/utils/metric_module.py index c02d3fdfc..9a7de6ec2 100644 --- a/clinicadl/utils/metric_module.py +++ b/clinicadl/utils/metric_module.py @@ -480,6 +480,12 @@ def compute_ap(y, y_pred, *args): return average_precision_score(y, y_pred, *args) + @staticmethod + def compute_roc_auc(y_pred, y, *args): + from monai.metrics.rocauc import compute_roc_auc + + return compute_roc_auc(y_pred, y, *args) + class RetainBest: """ diff --git a/clinicadl/utils/task_manager/classification.py b/clinicadl/utils/task_manager/classification.py index a4bd1f564..51a35ccf0 100644 --- a/clinicadl/utils/task_manager/classification.py +++ b/clinicadl/utils/task_manager/classification.py @@ -53,6 +53,7 @@ def evaluation_metrics(self): "MK", "LR_plus", "LR_minus", + "roc_auc", ] @property diff --git a/clinicadl/utils/task_manager/task_manager.py b/clinicadl/utils/task_manager/task_manager.py index 0f220a3a7..340d4daff 100644 --- a/clinicadl/utils/task_manager/task_manager.py +++ b/clinicadl/utils/task_manager/task_manager.py @@ -18,9 +18,11 @@ # TODO: add function to check that the output size of the network corresponds to what is expected to # perform the task class TaskManager: - def __init__(self, mode: str, n_classes: int = None): + def __init__(self, mode: str, n_classes: int = 2): self.mode = mode - self.metrics_module = MetricModule(self.evaluation_metrics, n_classes=n_classes) + self.metrics_module = MetricModule( + v=self.evaluation_metrics, n_classes=n_classes + ) @property @abstractmethod