Skip to content

Commit

Permalink
Added error for metric and changed default best_metric.
Browse files Browse the repository at this point in the history
  • Loading branch information
seanswyi committed Oct 25, 2024
1 parent 5ee5e5a commit f18d2c2
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,12 @@ def __init__(
raise ValueError(
f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
)
if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
if args.metric_for_best_model is None:
raise ValueError(
"`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
)

self.args = args
self.compute_loss_func = compute_loss_func
# Seed must be set before instantiating the model when using model
Expand Down Expand Up @@ -3103,19 +3109,19 @@ def _determine_best_metric(self, metrics, trial):
) from exc

operator = np.greater if self.args.greater_is_better else np.less
else:
metric_value = metrics["eval_loss"]
operator = np.less

if self.state.best_metric is None or operator(metric_value, self.state.best_metric):
run_dir = self._get_output_dir(trial=trial)
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
output_dir = os.path.join(run_dir, checkpoint_folder)
if self.state.best_metric is None:
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")

if operator(metric_value, self.state.best_metric):
run_dir = self._get_output_dir(trial=trial)
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
output_dir = os.path.join(run_dir, checkpoint_folder)

self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir

new_best_metric = True
new_best_metric = True

return new_best_metric

Expand Down

0 comments on commit f18d2c2

Please sign in to comment.