New option called "best" for args.save_strategy. (huggingface#31817)
* Add _determine_best_metric and new saving logic.

1. The logic for determining the best metric was separated out of
`_save_checkpoint` into `_determine_best_metric`.
2. In `_maybe_log_save_evaluate`, whether a new best metric was achieved
is determined after each evaluation, and if the save strategy is
"best", the `TrainerControl` is updated accordingly (a minimal sketch
of the decision rule follows).
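
For illustration, a minimal, self-contained sketch of the decision rule; `is_new_best` is a hypothetical helper, while the real logic lives in `_determine_best_metric` in the diff below.

```python
# Hypothetical standalone helper illustrating the comparison performed by
# `_determine_best_metric`; it is not part of the actual Trainer API.
from typing import Optional

import numpy as np


def is_new_best(metric_value: float, best_metric: Optional[float], greater_is_better: bool) -> bool:
    # An unset best metric is treated as -inf (or +inf when lower is better),
    # so the first evaluation always counts as a new best.
    if best_metric is None:
        best_metric = float("-inf") if greater_is_better else float("inf")
    operator = np.greater if greater_is_better else np.less
    return bool(operator(metric_value, best_metric))


# e.g. with an accuracy-like metric (higher is better):
assert is_new_best(0.65, 0.60, greater_is_better=True)
assert not is_new_best(0.64, 0.65, greater_is_better=True)
```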

* Added SaveStrategy.

Same as `IntervalStrategy`, but with an additional member called `BEST` (illustrated below).
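
A small illustrative check, assuming `SaveStrategy` is a string-backed `ExplicitEnum` like `IntervalStrategy` (requires a transformers build that includes this change):

```python
# Assumes SaveStrategy is a str-backed ExplicitEnum, mirroring IntervalStrategy.
from transformers.trainer_utils import IntervalStrategy, SaveStrategy

assert SaveStrategy("best") is SaveStrategy.BEST                # plain strings coerce to members
assert SaveStrategy.EPOCH == "epoch" == IntervalStrategy.EPOCH  # existing string values stay compatible
```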

* IntervalStrategy -> SaveStrategy

* IntervalStrategy -> SaveStrategy for `save_strategy`.

* Interval -> Save in docstring.

* Updated docstring for save_strategy.

* Added SaveStrategy and made the corresponding changes.

`save_strategy` previously followed `IntervalStrategy` but now follows
`SaveStrategy`.

The code and the docstring were updated accordingly.

* Changes from `make fixup`.

* Removed redundant metrics argument.

* Added new test_save_best_checkpoint test.

1. Covers both the case where `metric_for_best_model` is explicitly
provided and the case where it is not.
2. The first case should save two checkpoints, whereas the second
should save three.

* Changed end-of-training saving logic.

By default, the Trainer saves a checkpoint at the end of training as
long as `save_strategy != SaveStrategy.NO`. This condition was modified
to also skip the forced final save for `SaveStrategy.BEST`, since it
would be counterintuitive to ask for only the best checkpoint and still
have the last one saved as well (a small sketch follows).
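
A hedged sketch of the resulting end-of-training decision; `should_save_at_training_end` is a hypothetical helper, while the real condition sits in `DefaultFlowCallback.on_step_end` (see the `trainer_callback.py` diff below).

```python
# Hypothetical helper mirroring the new condition; not part of the Trainer API.
from transformers.trainer_utils import SaveStrategy


def should_save_at_training_end(save_strategy: SaveStrategy) -> bool:
    # "no": never save; "best": only metric-improving checkpoints are kept,
    # so the forced final save is skipped as well.
    return save_strategy not in (SaveStrategy.NO, SaveStrategy.BEST)


assert should_save_at_training_end(SaveStrategy.EPOCH)
assert should_save_at_training_end(SaveStrategy.STEPS)
assert not should_save_at_training_end(SaveStrategy.BEST)
assert not should_save_at_training_end(SaveStrategy.NO)
```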

* `args.metric_for_best_model` defaults to loss.

* Undo metric_for_best_model update.

* Remove checking metric_for_best_model.

* Added test cases for loss and no metric.

* Added error for metric and changed default best_metric.

* Removed unused import.

* `new_best_metric` -> `is_new_best_metric`

Co-authored-by: Arthur <[email protected]>

* Applied `is_new_best_metric` to all.

Changes were made for consistency and also to fix a potential bug.
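
Taken together, a hedged usage sketch of the new option; the output directory and metric name are placeholders, and evaluation must run so there are metrics to compare.

```python
# Usage sketch only; `output_dir` and the metric name are placeholders.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                  # placeholder path
    eval_strategy="epoch",             # evaluate every epoch so "best" has metrics to compare
    save_strategy="best",              # new option added by this commit
    metric_for_best_model="accuracy",  # required: Trainer raises ValueError if this is unset
    greater_is_better=True,
)
# A checkpoint is then written only when an evaluation improves `eval_accuracy`,
# and no extra checkpoint is forced at the very end of training.
```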

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Zach Mueller <[email protected]>
3 people authored and BernardZach committed Dec 5, 2024
1 parent 14716d0 commit 96dcad5
Showing 6 changed files with 158 additions and 40 deletions.
84 changes: 55 additions & 29 deletions src/transformers/trainer.py
@@ -117,9 +117,9 @@
EvalPrediction,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
PredictionOutput,
RemoveColumnsCollator,
SaveStrategy,
TrainerMemoryTracker,
TrainOutput,
check_target_module_exists,
@@ -419,6 +419,12 @@ def __init__(
raise ValueError(
f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
)
if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
if args.metric_for_best_model is None:
raise ValueError(
"`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
)

self.args = args
self.compute_loss_func = compute_loss_func
# Seed must be set before instantiating the model when using model
@@ -2998,9 +3004,13 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
metrics = None
if self.control.should_evaluate:
metrics = self._evaluate(trial, ignore_keys_for_eval)
is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

if self.args.save_strategy == SaveStrategy.BEST:
self.control.should_save = is_new_best_metric

if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self._save_checkpoint(model, trial)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

def _load_rng_state(self, checkpoint):
@@ -3077,7 +3087,48 @@ def _load_rng_state(self, checkpoint):
"\nThis won't yield the same results as if the training had not been interrupted."
)

def _save_checkpoint(self, model, trial, metrics=None):
def _determine_best_metric(self, metrics, trial):
"""
Determine if the model should be saved based on the evaluation metrics.
If args.metric_for_best_model is not set, the loss is used.
Returns:
bool: True if a new best metric was found, else False
"""
is_new_best_metric = False

if self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model

if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"

try:
metric_value = metrics[metric_to_check]
except KeyError as exc:
raise KeyError(
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
) from exc

operator = np.greater if self.args.greater_is_better else np.less

if self.state.best_metric is None:
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")

if operator(metric_value, self.state.best_metric):
run_dir = self._get_output_dir(trial=trial)
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
output_dir = os.path.join(run_dir, checkpoint_folder)

self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir

is_new_best_metric = True

return is_new_best_metric

def _save_checkpoint(self, model, trial):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
@@ -3098,31 +3149,6 @@ def _save_checkpoint(self, model, trial, metrics=None):
# Save RNG state
self._save_rng_state(output_dir)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
try:
metric_value = metrics[metric_to_check]
except KeyError as exc:
raise KeyError(
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
f"which is not found in the evaluation metrics. "
f"The available evaluation metrics are: {list(metrics.keys())}. "
f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
f"consider changing the `metric_for_best_model` via the TrainingArguments."
) from exc

operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir

# Save the Trainer state
if self.args.should_save:
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
@@ -4543,7 +4569,7 @@ def _push_from_checkpoint(self, checkpoint_folder):
# Same for the training arguments
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

if self.args.save_strategy == IntervalStrategy.STEPS:
if self.args.save_strategy == SaveStrategy.STEPS:
commit_message = f"Training in progress, step {self.state.global_step}"
else:
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
8 changes: 4 additions & 4 deletions src/transformers/trainer_callback.py
@@ -24,7 +24,7 @@
import numpy as np
from tqdm.auto import tqdm

from .trainer_utils import IntervalStrategy, has_length
from .trainer_utils import IntervalStrategy, SaveStrategy, has_length
from .training_args import TrainingArguments
from .utils import logging

@@ -555,7 +555,7 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra

# Save
if (
args.save_strategy == IntervalStrategy.STEPS
args.save_strategy == SaveStrategy.STEPS
and state.save_steps > 0
and state.global_step % state.save_steps == 0
):
@@ -565,7 +565,7 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
if state.global_step >= state.max_steps:
control.should_training_stop = True
# Save the model at the end if we have a save strategy
if args.save_strategy != IntervalStrategy.NO:
if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]:
control.should_save = True

return control
@@ -580,7 +580,7 @@ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: Tr
control.should_evaluate = True

# Save
if args.save_strategy == IntervalStrategy.EPOCH:
if args.save_strategy == SaveStrategy.EPOCH:
control.should_save = True

return control
7 changes: 7 additions & 0 deletions src/transformers/trainer_utils.py
@@ -227,6 +227,13 @@ class IntervalStrategy(ExplicitEnum):
EPOCH = "epoch"


class SaveStrategy(ExplicitEnum):
NO = "no"
STEPS = "steps"
EPOCH = "epoch"
BEST = "best"


class EvaluationStrategy(ExplicitEnum):
NO = "no"
STEPS = "steps"
14 changes: 8 additions & 6 deletions src/transformers/training_args.py
@@ -33,6 +33,7 @@
FSDPOption,
HubStrategy,
IntervalStrategy,
SaveStrategy,
SchedulerType,
)
from .utils import (
@@ -349,12 +350,13 @@ class TrainingArguments:
</Tip>
save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
The checkpoint save strategy to adopt during training. Possible values are:
- `"no"`: No save is done during training.
- `"epoch"`: Save is done at the end of each epoch.
- `"steps"`: Save is done every `save_steps`.
- `"best"`: Save is done whenever a new `best_metric` is achieved.
If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
very end of training, always.
@@ -962,7 +964,7 @@ class TrainingArguments:
},
)
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
save_strategy: Union[IntervalStrategy, str] = field(
save_strategy: Union[SaveStrategy, str] = field(
default="steps",
metadata={"help": "The checkpoint save strategy to use."},
)
@@ -1580,7 +1582,7 @@ def __post_init__(self):

self.eval_strategy = IntervalStrategy(self.eval_strategy)
self.logging_strategy = IntervalStrategy(self.logging_strategy)
self.save_strategy = IntervalStrategy(self.save_strategy)
self.save_strategy = SaveStrategy(self.save_strategy)
self.hub_strategy = HubStrategy(self.hub_strategy)

self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
@@ -1616,7 +1618,7 @@ def __post_init__(self):
if self.eval_steps != int(self.eval_steps):
raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
self.eval_steps = int(self.eval_steps)
if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1:
if self.save_strategy == SaveStrategy.STEPS and self.save_steps > 1:
if self.save_steps != int(self.save_steps):
raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}")
self.save_steps = int(self.save_steps)
@@ -2750,8 +2752,8 @@ def set_save(
100
```
"""
self.save_strategy = IntervalStrategy(strategy)
if self.save_strategy == IntervalStrategy.STEPS and steps == 0:
self.save_strategy = SaveStrategy(strategy)
if self.save_strategy == SaveStrategy.STEPS and steps == 0:
raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
self.save_steps = steps
self.save_total_limit = total_limit
2 changes: 1 addition & 1 deletion src/transformers/training_args_tf.py
@@ -114,7 +114,7 @@ class TFTrainingArguments(TrainingArguments):
Whether to log and evaluate the first `global_step` or not.
logging_steps (`int`, *optional*, defaults to 500):
Number of update steps between two logs if `logging_strategy="steps"`.
save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
The checkpoint save strategy to adopt during training. Possible values are:
- `"no"`: No save is done during training.
83 changes: 83 additions & 0 deletions tests/trainer/test_trainer.py
@@ -4041,6 +4041,89 @@ def test_trainer_saves_processor(self):
reloaded_tokenizer(test_sentence, padding="max_length").input_ids,
)

def test_save_best_checkpoint(self):
freq = int(64 / self.batch_size)
total = int(self.n_epochs * 64 / self.batch_size)

# Case 1: args.metric_for_best_model == "accuracy".
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
eval_strategy="epoch",
save_strategy="best",
metric_for_best_model="accuracy",
compute_metrics=AlmostAccuracy(),
)
self.assertTrue(trainer.args.metric_for_best_model == "accuracy")

with patch.object(
trainer,
"_evaluate",
side_effect=[
{"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
{"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
{"eval_loss": 0.01, "eval_accuracy": 0.64, "epoch": 3.0},
],
):
trainer.train()

self.assertEqual(len(os.listdir(tmpdir)), 2)
self.check_saved_checkpoints(
output_dir=tmpdir,
freq=freq,
total=total,
)

# Case 2: args.metric_for_best_model == "loss".
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
eval_strategy="epoch",
save_strategy="best",
metric_for_best_model="loss",
compute_metrics=AlmostAccuracy(),
)
self.assertTrue(trainer.args.metric_for_best_model == "loss")

with patch.object(
trainer,
"_evaluate",
side_effect=[
{"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
{"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
{"eval_loss": 0.03, "eval_accuracy": 0.66, "epoch": 3.0},
],
):
trainer.train()

self.assertEqual(len(os.listdir(tmpdir)), 2)
self.check_saved_checkpoints(
output_dir=tmpdir,
freq=freq,
total=total,
)

# Case 3: Metric name not provided; throw error.
with tempfile.TemporaryDirectory() as tmpdir:
with self.assertRaises(ValueError) as context:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
eval_strategy="epoch",
save_strategy="best",
compute_metrics=AlmostAccuracy(),
)

self.assertIn("`args.metric_for_best_model` must be provided", str(context.exception))


@require_torch
@is_staging_test