From f18d2c2f5f2ebcdb98346aa727c25f389d69e6b0 Mon Sep 17 00:00:00 2001 From: "Sean (Seok-Won) Yi" Date: Sat, 19 Oct 2024 13:06:20 +0900 Subject: [PATCH] Added error for metric and changed default best_metric. --- src/transformers/trainer.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 652061a37716ad..4b31b992022e3e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -419,6 +419,12 @@ def __init__( raise ValueError( f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. " ) + if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end: + if args.metric_for_best_model is None: + raise ValueError( + "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`." + ) + self.args = args self.compute_loss_func = compute_loss_func # Seed must be set before instantiating the model when using model @@ -3103,19 +3109,19 @@ def _determine_best_metric(self, metrics, trial): ) from exc operator = np.greater if self.args.greater_is_better else np.less - else: - metric_value = metrics["eval_loss"] - operator = np.less - if self.state.best_metric is None or operator(metric_value, self.state.best_metric): - run_dir = self._get_output_dir(trial=trial) - checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" - output_dir = os.path.join(run_dir, checkpoint_folder) + if self.state.best_metric is None: + self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf") + + if operator(metric_value, self.state.best_metric): + run_dir = self._get_output_dir(trial=trial) + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + output_dir = os.path.join(run_dir, checkpoint_folder) - self.state.best_metric = metric_value - self.state.best_model_checkpoint = output_dir + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir - new_best_metric = True + new_best_metric = True return new_best_metric