From 5a732f0e05e7e9b0d7bb2716e9b3b9acfdb0061c Mon Sep 17 00:00:00 2001 From: camillebrianceau Date: Thu, 5 Sep 2024 16:57:08 +0200 Subject: [PATCH] move lr scheduler config --- clinicadl/maps_manager/maps_manager.py | 3 ++- clinicadl/metrics/metric_module.py | 1 + clinicadl/trainer/tasks_utils.py | 14 +++++++++++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py index db77681ff..69f38e380 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -445,7 +445,8 @@ def _check_args(self, parameters): self.n_classes = output_size(self.network_task, None, train_df, self.label) else: self.n_classes = None - + print(self.network_task) + print(evaluation_metrics(network_task=self.network_task)) self.metrics_module = MetricModule( evaluation_metrics(network_task=self.network_task), n_classes=self.n_classes, diff --git a/clinicadl/metrics/metric_module.py b/clinicadl/metrics/metric_module.py index 319b4a639..69493620f 100644 --- a/clinicadl/metrics/metric_module.py +++ b/clinicadl/metrics/metric_module.py @@ -33,6 +33,7 @@ def __init__(self, metrics, n_classes=2): for method_name in dir(MetricModule) if callable(getattr(MetricModule, method_name)) ] + print(metrics) self.metrics = dict() for metric in metrics: if f"compute_{metric.lower()}" in list_fn: diff --git a/clinicadl/trainer/tasks_utils.py b/clinicadl/trainer/tasks_utils.py index 9407fd9b2..16e070e61 100644 --- a/clinicadl/trainer/tasks_utils.py +++ b/clinicadl/trainer/tasks_utils.py @@ -192,11 +192,19 @@ def evaluation_metrics(network_task: Union[str, Task]): """ network_task = Task(network_task) if network_task == Task.CLASSIFICATION: - return [e.value for e in ClassificationMetric].remove("loss") + x = [e.value for e in ClassificationMetric] + x.remove("loss") + return x elif network_task == Task.REGRESSION: - return [e.value for e in RegressionMetric].remove("loss") + x = [e.value for e in RegressionMetric] + x.remove("loss") + return x elif network_task == Task.RECONSTRUCTION: - return [e.value for e in ReconstructionMetric].remove("loss") + x = [e.value for e in ReconstructionMetric] + x.remove("loss") + return x + else: + raise ValueError("Unknown network task") def test(