Skip to content

Commit

Permalink
some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Jun 24, 2024
1 parent cfd2561 commit a957f8b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
6 changes: 6 additions & 0 deletions clinicadl/utils/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,12 @@ def compute_ap(y, y_pred, *args):

return average_precision_score(y, y_pred, *args)

@staticmethod
def compute_roc_auc(y_pred, y, *args):
from monai.metrics.rocauc import compute_roc_auc

return compute_roc_auc(y_pred, y, *args)


class RetainBest:
"""
Expand Down
1 change: 1 addition & 0 deletions clinicadl/utils/task_manager/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def evaluation_metrics(self):
"MK",
"LR_plus",
"LR_minus",
"roc_auc",
]

@property
Expand Down
6 changes: 4 additions & 2 deletions clinicadl/utils/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
# TODO: add function to check that the output size of the network corresponds to what is expected to
# perform the task
class TaskManager:
def __init__(self, mode: str, n_classes: int = None):
def __init__(self, mode: str, n_classes: int = 2):
self.mode = mode
self.metrics_module = MetricModule(self.evaluation_metrics, n_classes=n_classes)
self.metrics_module = MetricModule(
v=self.evaluation_metrics, n_classes=n_classes
)

@property
@abstractmethod
Expand Down

0 comments on commit a957f8b

Please sign in to comment.