From e896c092db1d5b95e89b7ba56c50340406c0a575 Mon Sep 17 00:00:00 2001 From: Jeremy Dohmann Date: Fri, 29 Sep 2023 12:34:00 -0400 Subject: [PATCH] commit --- scripts/eval/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 2e94ca99d2..09db0a870a 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -145,7 +145,7 @@ def evaluate_model( if eval_gauntlet_df is None and eval_gauntlet_callback is not None: eval_gauntlet_df = pd.DataFrame( - columns=['model_name', 'average'] + + columns=['model_name'] + [avg for avg in eval_gauntlet_callback.averages] + [t.name for t in eval_gauntlet_callback.categories]) load_path = model_cfg.get('load_path', None) @@ -322,7 +322,7 @@ def main(cfg: DictConfig): print( eval_gauntlet_df.sort_values( - list(eval_gauntlet_df.columns)[-1], ascending=False).to_markdown(index=False)) + next(eval_gauntlet_callback.averages.keys()), ascending=False).to_markdown(index=False)) print(f'Printing complete results for all models') assert models_df is not None print(models_df.to_markdown(index=False))