Skip to content

Commit

Permalink
move lr scheduler config
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Sep 5, 2024
1 parent 7b97df5 commit 5a732f0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
3 changes: 2 additions & 1 deletion clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,8 @@ def _check_args(self, parameters):
self.n_classes = output_size(self.network_task, None, train_df, self.label)
else:
self.n_classes = None

print(self.network_task)
print(evaluation_metrics(network_task=self.network_task))
self.metrics_module = MetricModule(
evaluation_metrics(network_task=self.network_task),
n_classes=self.n_classes,
Expand Down
1 change: 1 addition & 0 deletions clinicadl/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, metrics, n_classes=2):
for method_name in dir(MetricModule)
if callable(getattr(MetricModule, method_name))
]
print(metrics)
self.metrics = dict()
for metric in metrics:
if f"compute_{metric.lower()}" in list_fn:
Expand Down
14 changes: 11 additions & 3 deletions clinicadl/trainer/tasks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,19 @@ def evaluation_metrics(network_task: Union[str, Task]):
"""
network_task = Task(network_task)
if network_task == Task.CLASSIFICATION:
return [e.value for e in ClassificationMetric].remove("loss")
x = [e.value for e in ClassificationMetric]
x.remove("loss")
return x
elif network_task == Task.REGRESSION:
return [e.value for e in RegressionMetric].remove("loss")
x = [e.value for e in RegressionMetric]
x.remove("loss")
return x
elif network_task == Task.RECONSTRUCTION:
return [e.value for e in ReconstructionMetric].remove("loss")
x = [e.value for e in ReconstructionMetric]
x.remove("loss")
return x
else:
raise ValueError("Unknown network task")


def test(
Expand Down

0 comments on commit 5a732f0

Please sign in to comment.