diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 64cb5c6bd4ddbb..4315e54a42fc2e 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -117,9 +117,9 @@
     EvalPrediction,
     HPSearchBackend,
     HubStrategy,
-    IntervalStrategy,
     PredictionOutput,
     RemoveColumnsCollator,
+    SaveStrategy,
     TrainerMemoryTracker,
     TrainOutput,
     check_target_module_exists,
@@ -419,6 +419,12 @@ def __init__(
             raise ValueError(
                 f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
             )
+        if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
+            if args.metric_for_best_model is None:
+                raise ValueError(
+                    "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
+                )
+
         self.args = args
         self.compute_loss_func = compute_loss_func
         # Seed must be set before instantiating the model when using model
@@ -2998,9 +3004,13 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
         metrics = None
         if self.control.should_evaluate:
             metrics = self._evaluate(trial, ignore_keys_for_eval)
+            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+            if self.args.save_strategy == SaveStrategy.BEST:
+                self.control.should_save = is_new_best_metric
 
         if self.control.should_save:
-            self._save_checkpoint(model, trial, metrics=metrics)
+            self._save_checkpoint(model, trial)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
     def _load_rng_state(self, checkpoint):
@@ -3077,7 +3087,48 @@ def _load_rng_state(self, checkpoint):
                     "\nThis won't yield the same results as if the training had not been interrupted."
                 )
 
-    def _save_checkpoint(self, model, trial, metrics=None):
+    def _determine_best_metric(self, metrics, trial):
+        """
+        Determine if the model should be saved based on the evaluation metrics.
+        If `args.metric_for_best_model` is not set, no comparison is made and False is returned.
+
+        Returns:
+            bool: True if a new best metric was found, else False
+        """
+        is_new_best_metric = False
+
+        if self.args.metric_for_best_model is not None:
+            metric_to_check = self.args.metric_for_best_model
+
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+
+            try:
+                metric_value = metrics[metric_to_check]
+            except KeyError as exc:
+                raise KeyError(
+                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+                ) from exc
+
+            operator = np.greater if self.args.greater_is_better else np.less
+
+            if self.state.best_metric is None:
+                self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
+
+            if operator(metric_value, self.state.best_metric):
+                run_dir = self._get_output_dir(trial=trial)
+                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+                output_dir = os.path.join(run_dir, checkpoint_folder)
+
+                self.state.best_metric = metric_value
+                self.state.best_model_checkpoint = output_dir
+
+                is_new_best_metric = True
+
+        return is_new_best_metric
+
+    def _save_checkpoint(self, model, trial):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
         # want to save except FullyShardedDDP.
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
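
Not part of the patch, but useful context for the hunk above: a minimal standalone sketch of the comparison `_determine_best_metric` performs. The helper name is invented for illustration; the sentinel seeding is what makes the first evaluation always register as an improvement.

```python
import numpy as np

def is_improvement(metric_value, best_metric, greater_is_better=True):
    # Mirrors _determine_best_metric: choose the comparison direction, then
    # seed the running best with the worst possible value so that the first
    # evaluation always wins.
    operator = np.greater if greater_is_better else np.less
    if best_metric is None:
        best_metric = float("-inf") if greater_is_better else float("inf")
    return bool(operator(metric_value, best_metric))

assert is_improvement(0.65, 0.60)                           # accuracy: higher is better
assert is_improvement(0.02, 0.03, greater_is_better=False)  # loss: lower is better
assert is_improvement(0.60, None)                           # first eval always improves
```
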
@@ -3098,31 +3149,6 @@ def _save_checkpoint(self, model, trial, metrics=None):
         # Save RNG state
         self._save_rng_state(output_dir)
 
-        # Determine the new best metric / best model checkpoint
-        if metrics is not None and self.args.metric_for_best_model is not None:
-            metric_to_check = self.args.metric_for_best_model
-            if not metric_to_check.startswith("eval_"):
-                metric_to_check = f"eval_{metric_to_check}"
-            try:
-                metric_value = metrics[metric_to_check]
-            except KeyError as exc:
-                raise KeyError(
-                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
-                    f"which is not found in the evaluation metrics. "
-                    f"The available evaluation metrics are: {list(metrics.keys())}. "
-                    f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
-                    f"consider changing the `metric_for_best_model` via the TrainingArguments."
-                ) from exc
-
-            operator = np.greater if self.args.greater_is_better else np.less
-            if (
-                self.state.best_metric is None
-                or self.state.best_model_checkpoint is None
-                or operator(metric_value, self.state.best_metric)
-            ):
-                self.state.best_metric = metric_value
-                self.state.best_model_checkpoint = output_dir
-
         # Save the Trainer state
         if self.args.should_save:
             # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
@@ -4543,7 +4569,7 @@ def _push_from_checkpoint(self, checkpoint_folder):
         # Same for the training arguments
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
-        if self.args.save_strategy == IntervalStrategy.STEPS:
+        if self.args.save_strategy == SaveStrategy.STEPS:
             commit_message = f"Training in progress, step {self.state.global_step}"
         else:
             commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 405874acf8f4c4..ce9f2a26732c2e 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -24,7 +24,7 @@
 import numpy as np
 from tqdm.auto import tqdm
 
-from .trainer_utils import IntervalStrategy, has_length
+from .trainer_utils import IntervalStrategy, SaveStrategy, has_length
 from .training_args import TrainingArguments
 from .utils import logging
 
@@ -555,7 +555,7 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
 
         # Save
         if (
-            args.save_strategy == IntervalStrategy.STEPS
+            args.save_strategy == SaveStrategy.STEPS
             and state.save_steps > 0
             and state.global_step % state.save_steps == 0
         ):
@@ -565,7 +565,7 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
         if state.global_step >= state.max_steps:
             control.should_training_stop = True
             # Save the model at the end if we have a save strategy
-            if args.save_strategy != IntervalStrategy.NO:
+            if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]:
                 control.should_save = True
         return control
 
@@ -580,7 +580,7 @@ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: Tr
             control.should_evaluate = True
 
         # Save
-        if args.save_strategy == IntervalStrategy.EPOCH:
+        if args.save_strategy == SaveStrategy.EPOCH:
             control.should_save = True
         return control
 
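
A behavioral consequence of the `max_steps` branch above: with `save_strategy="best"` the forced final save is skipped, so training can end without writing a checkpoint unless the last evaluation set a new best. Also worth flagging for the membership test (a hedged aside, not part of the patch): `ExplicitEnum` in this codebase mixes in `str`, so save-strategy comparisons behave the same whether the value arrives as an enum member or as its raw string. A simplified stand-in:

```python
from enum import Enum

class ExplicitEnum(str, Enum):
    # Simplified stand-in for transformers' ExplicitEnum (a str-valued Enum).
    pass

class SaveStrategy(ExplicitEnum):
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
    BEST = "best"

assert SaveStrategy.BEST == "best"                   # str mixin: compares equal to its value
assert "no" in [SaveStrategy.NO, SaveStrategy.BEST]  # membership tests work on raw strings
assert SaveStrategy("best") is SaveStrategy.BEST     # round-trip used by __post_init__
```
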
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 02c298cf7d2e65..42088cd730628d 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -227,6 +227,13 @@ class IntervalStrategy(ExplicitEnum):
     EPOCH = "epoch"
 
 
+class SaveStrategy(ExplicitEnum):
+    NO = "no"
+    STEPS = "steps"
+    EPOCH = "epoch"
+    BEST = "best"
+
+
 class EvaluationStrategy(ExplicitEnum):
     NO = "no"
     STEPS = "steps"
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 485610dd9baa28..c98e8bc41b924d 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -33,6 +33,7 @@
     FSDPOption,
     HubStrategy,
     IntervalStrategy,
+    SaveStrategy,
     SchedulerType,
 )
 from .utils import (
@@ -349,12 +350,13 @@ class TrainingArguments:
-        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+        save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
             The checkpoint save strategy to adopt during training. Possible values are:
 
                 - `"no"`: No save is done during training.
                 - `"epoch"`: Save is done at the end of each epoch.
                 - `"steps"`: Save is done every `save_steps`.
+                - `"best"`: Save is done whenever a new `best_metric` is achieved.
 
             If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
             very end of training, always.
@@ -962,7 +964,7 @@ class TrainingArguments:
         },
     )
     logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
-    save_strategy: Union[IntervalStrategy, str] = field(
+    save_strategy: Union[SaveStrategy, str] = field(
         default="steps",
         metadata={"help": "The checkpoint save strategy to use."},
     )
@@ -1580,7 +1582,7 @@ def __post_init__(self):
 
         self.eval_strategy = IntervalStrategy(self.eval_strategy)
         self.logging_strategy = IntervalStrategy(self.logging_strategy)
-        self.save_strategy = IntervalStrategy(self.save_strategy)
+        self.save_strategy = SaveStrategy(self.save_strategy)
         self.hub_strategy = HubStrategy(self.hub_strategy)
         self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
 
@@ -1616,7 +1618,7 @@ def __post_init__(self):
             if self.eval_steps != int(self.eval_steps):
                 raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
             self.eval_steps = int(self.eval_steps)
-        if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1:
+        if self.save_strategy == SaveStrategy.STEPS and self.save_steps > 1:
             if self.save_steps != int(self.save_steps):
                 raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}")
             self.save_steps = int(self.save_steps)
@@ -2750,8 +2752,8 @@ def set_save(
         100
         ```
         """
-        self.save_strategy = IntervalStrategy(strategy)
-        if self.save_strategy == IntervalStrategy.STEPS and steps == 0:
+        self.save_strategy = SaveStrategy(strategy)
+        if self.save_strategy == SaveStrategy.STEPS and steps == 0:
             raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
         self.save_steps = steps
         self.save_total_limit = total_limit
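
A consequence of the `set_save` change just above: the fluent setter now accepts the new strategy string directly. A hedged sketch, assuming an otherwise default `TrainingArguments`:

```python
from transformers import TrainingArguments

args = TrainingArguments(output_dir="working_dir", metric_for_best_model="accuracy")
args = args.set_save(strategy="best", total_limit=2)
# Coerced to SaveStrategy.BEST; the str mixin keeps string comparisons working.
assert args.save_strategy == "best"
```
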
diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py
index 9df53c3f1d6161..3716a78879d501 100644
--- a/src/transformers/training_args_tf.py
+++ b/src/transformers/training_args_tf.py
@@ -114,7 +114,7 @@ class TFTrainingArguments(TrainingArguments):
             Whether to log and evaluate the first `global_step` or not.
         logging_steps (`int`, *optional*, defaults to 500):
             Number of update steps between two logs if `logging_strategy="steps"`.
-        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+        save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
             The checkpoint save strategy to adopt during training. Possible values are:
 
                 - `"no"`: No save is done during training.
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 5c03355785d2b5..b6fe807fa4961a 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -4041,6 +4041,89 @@ def test_trainer_saves_processor(self):
                 reloaded_tokenizer(test_sentence, padding="max_length").input_ids,
             )
 
+    def test_save_best_checkpoint(self):
+        freq = int(64 / self.batch_size)
+        total = int(self.n_epochs * 64 / self.batch_size)
+
+        # Case 1: args.metric_for_best_model == "accuracy".
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                a=1.5,
+                b=2.5,
+                output_dir=tmpdir,
+                learning_rate=0.1,
+                eval_strategy="epoch",
+                save_strategy="best",
+                metric_for_best_model="accuracy",
+                compute_metrics=AlmostAccuracy(),
+            )
+            self.assertTrue(trainer.args.metric_for_best_model == "accuracy")
+
+            with patch.object(
+                trainer,
+                "_evaluate",
+                side_effect=[
+                    {"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
+                    {"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
+                    {"eval_loss": 0.01, "eval_accuracy": 0.64, "epoch": 3.0},
+                ],
+            ):
+                trainer.train()
+
+            self.assertEqual(len(os.listdir(tmpdir)), 2)
+            self.check_saved_checkpoints(
+                output_dir=tmpdir,
+                freq=freq,
+                total=total,
+            )
+
+        # Case 2: args.metric_for_best_model == "loss".
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                a=1.5,
+                b=2.5,
+                output_dir=tmpdir,
+                learning_rate=0.1,
+                eval_strategy="epoch",
+                save_strategy="best",
+                metric_for_best_model="loss",
+                compute_metrics=AlmostAccuracy(),
+            )
+            self.assertTrue(trainer.args.metric_for_best_model == "loss")
+
+            with patch.object(
+                trainer,
+                "_evaluate",
+                side_effect=[
+                    {"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
+                    {"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
+                    {"eval_loss": 0.03, "eval_accuracy": 0.66, "epoch": 3.0},
+                ],
+            ):
+                trainer.train()
+
+            self.assertEqual(len(os.listdir(tmpdir)), 2)
+            self.check_saved_checkpoints(
+                output_dir=tmpdir,
+                freq=freq,
+                total=total,
+            )
+
+        # Case 3: Metric name not provided; throw error.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with self.assertRaises(ValueError) as context:
+                trainer = get_regression_trainer(
+                    a=1.5,
+                    b=2.5,
+                    output_dir=tmpdir,
+                    learning_rate=0.1,
+                    eval_strategy="epoch",
+                    save_strategy="best",
+                    compute_metrics=AlmostAccuracy(),
+                )
+
+            self.assertIn("`args.metric_for_best_model` must be provided", str(context.exception))
+
 
 @require_torch
 @is_staging_test
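
End to end, the patch enables the user-facing workflow below. This is a hedged sketch rather than code from the diff: `model`, `train_dataset`, and `eval_dataset` are placeholders for any `PreTrainedModel` and datasets, and the metric function is a toy example.

```python
import numpy as np
from transformers import Trainer, TrainingArguments

def compute_metrics(eval_pred):
    # Toy metric: exact-match accuracy over argmax predictions.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}

args = TrainingArguments(
    output_dir="out",
    eval_strategy="epoch",             # "best" needs evaluations to compare against
    save_strategy="best",              # checkpoint only when the tracked metric improves
    metric_for_best_model="accuracy",  # required, or Trainer.__init__ raises ValueError
    greater_is_better=True,
)

trainer = Trainer(
    model=model,                       # placeholder: any PreTrainedModel
    args=args,
    train_dataset=train_dataset,       # placeholder dataset
    eval_dataset=eval_dataset,         # placeholder dataset
    compute_metrics=compute_metrics,
)
trainer.train()  # writes checkpoint-<global_step> only on improving evaluations
```
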