From 7adc1c28bc2fa580b85ce5d689fc0ddf2dc13113 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 30 May 2024 17:06:57 +0200 Subject: [PATCH] add sanity evaluation --- src/transformers/trainer.py | 25 ++++++++++++++++--------- src/transformers/training_args.py | 8 ++++++++ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 49e780306611dd..342028a3a4f918 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2175,6 +2175,9 @@ def _inner_training_loop( grad_norm: Optional[float] = None self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + if args.sanity_evaluation: + self._evaluate(trial, ignore_keys_for_eval) + total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): epoch_iterator = train_dataloader @@ -2723,6 +2726,18 @@ def _issue_warnings_after_load(self, load_result): f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." ) + def _evaluate(self, trial, ignore_keys_for_eval): + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + # Run delayed LR scheduler now that metrics are populated + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + self.lr_scheduler.step(metrics[metric_to_check]) + return metrics + def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval): if self.control.should_log and self.state.global_step > self._globalstep_last_logged: if is_torch_xla_available(): @@ -2749,15 +2764,7 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno metrics = None if self.control.should_evaluate: - metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) - self._report_to_hp_search(trial, self.state.global_step, metrics) - - # Run delayed LR scheduler now that metrics are populated - if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - metric_to_check = self.args.metric_for_best_model - if not metric_to_check.startswith("eval_"): - metric_to_check = f"eval_{metric_to_check}" - self.lr_scheduler.step(metrics[metric_to_check]) + metrics = self._evaluate(self, trial, ignore_keys_for_eval) if self.control.should_save: self._save_checkpoint(model, trial, metrics=metrics) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index a97139a07ba938..6aa37d0ab1cdcf 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -771,6 +771,9 @@ class TrainingArguments: rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global summary statistics from the batch-level summary statistics you've accumulated over the evaluation set. + + sanity_evaluation(`bool`, *optional*, defaults to `False`): + Whether or not to perform a sanity check to ensure that the validation steps works correctly. It will be performed before the training. """ framework = "pt" @@ -1454,6 +1457,11 @@ class TrainingArguments: metadata={"help": "Break eval metrics calculation into batches to save memory."}, ) + sanity_evaluation: bool = field( + default=False, + metadata={"help": "Sanity check for the evaluation step."}, + ) + def __post_init__(self): # Parse in args that could be `dict` sent in from the CLI as a string for field in _VALID_DICT_FIELDS: