
Commit

add sanity evaluation
SunMarc committed May 30, 2024
1 parent f5590de commit 7adc1c2
Showing 2 changed files with 24 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/transformers/trainer.py
@@ -2175,6 +2175,9 @@ def _inner_training_loop(
         grad_norm: Optional[float] = None
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
+        if args.sanity_evaluation:
+            self._evaluate(trial, ignore_keys_for_eval)
+
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
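In short, when `args.sanity_evaluation` is set, the trainer runs one full evaluation pass before the first training step, so a misconfigured evaluation pipeline fails fast instead of after an entire epoch. A minimal standalone sketch of the idea (hypothetical `train_step`/`eval_step` callables, not Trainer internals):

def train(model, train_step, eval_step, num_epochs, sanity_evaluation=False):
    # Hypothetical sketch: fail fast if evaluation is broken, before any
    # training compute is spent.
    if sanity_evaluation:
        eval_step(model)  # raises immediately on a misconfigured eval pipeline
    for _ in range(num_epochs):
        train_step(model)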
@@ -2723,6 +2726,18 @@ def _issue_warnings_after_load(self, load_result):
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
)

def _evaluate(self, trial, ignore_keys_for_eval):
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)

# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])
return metrics

def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_xla_available():
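Context for the `ReduceLROnPlateau` branch above: unlike other schedulers, its `step()` takes the monitored metric as an argument, which is why the new `_evaluate` helper can only step it after eval metrics exist. A self-contained toy illustration (toy model and loss values, unrelated to Trainer code):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=0)

for eval_loss in [1.0, 0.9, 0.95, 0.96]:  # pretend eval metrics
    scheduler.step(eval_loss)             # LR is cut once the metric stops improving
    print(optimizer.param_groups[0]["lr"])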
@@ -2749,15 +2764,7 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):

         metrics = None
         if self.control.should_evaluate:
-            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
-            self._report_to_hp_search(trial, self.state.global_step, metrics)
-
-            # Run delayed LR scheduler now that metrics are populated
-            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
-                metric_to_check = self.args.metric_for_best_model
-                if not metric_to_check.startswith("eval_"):
-                    metric_to_check = f"eval_{metric_to_check}"
-                self.lr_scheduler.step(metrics[metric_to_check])
+            metrics = self._evaluate(trial, ignore_keys_for_eval)

         if self.control.should_save:
             self._save_checkpoint(model, trial, metrics=metrics)
8 changes: 8 additions & 0 deletions src/transformers/training_args.py
@@ -771,6 +771,9 @@ class TrainingArguments:
             rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
             that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
             summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
+        sanity_evaluation (`bool`, *optional*, defaults to `False`):
+            Whether or not to perform a sanity check to ensure that the evaluation step works correctly. It will be
+            performed before training starts.
     """
 
     framework = "pt"
@@ -1454,6 +1457,11 @@ class TrainingArguments:
metadata={"help": "Break eval metrics calculation into batches to save memory."},
)

sanity_evaluation: bool = field(
default=False,
metadata={"help": "Sanity check for the evaluation step."},
)

def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS:
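A minimal sketch of how the new flag would be enabled from user code; `model`, `train_ds`, and `eval_ds` are placeholders for a real model and datasets:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    sanity_evaluation=True,  # run evaluation once before the first training step
)
trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
trainer.train()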
