diff --git a/src/eva/core/trainers/_recorder.py b/src/eva/core/trainers/_recorder.py index 5c4f1e12..8e85ed99 100644 --- a/src/eva/core/trainers/_recorder.py +++ b/src/eva/core/trainers/_recorder.py @@ -189,18 +189,18 @@ def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, datase title=f"\n{stage.capitalize()} Dataset {dataset_idx}", title_style="bold" ) metrics_table.add_column("Metric", style="cyan") - metrics_table.add_column("Mean", justify="right", style="magenta") - metrics_table.add_column("Stdev", justify="right", style="magenta") + metrics_table.add_column("Mean", style="magenta") + metrics_table.add_column("Stdev", style="magenta") + metrics_table.add_column("All", style="magenta") n_runs = len(metrics_dict[next(iter(metrics_dict))]["values"]) - for i in range(n_runs): - metrics_table.add_column(f"Run {i}", justify="right", style="magenta") - for metric_name, metric_dict in metrics_dict.items(): - row = [metric_name, metric_dict["mean"], metric_dict["stdev"]] + [ - metric_dict["values"][i] for i in range(n_runs) + row = [ + metric_name, + f'{metric_dict["mean"]:.3f}', + f'{metric_dict["stdev"]:.3f}', + ", ".join(f'{metric_dict["values"][i]:.3f}' for i in range(n_runs)), ] - row = [str(entry) for entry in row] metrics_table.add_row(*row) console = rich_console.Console()