Make it possible to only save best model (not last checkpoint) #35070

Closed
umarbutler opened this issue Dec 4, 2024 · 6 comments
Labels: Feature request

Comments

@umarbutler
Contributor

Feature request

If you set eval_strategy = 'steps', eval_steps = 1, save_strategy = 'steps', save_steps = 1, save_total_limit = 1 and load_best_model_at_end = True in your TrainingArguments, the Trainer saves a checkpoint at every save step regardless of whether it beats the current best checkpoint, rather than saving only when the metric improves.
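
For concreteness, here is a minimal sketch of the configuration in question (output_dir and metric_for_best_model are placeholders; everything else is the combination of arguments listed above):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                # placeholder
    eval_strategy="steps",
    eval_steps=1,
    save_strategy="steps",
    save_steps=1,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="loss",    # placeholder; any eval metric works
)
# With this setup, the Trainer writes a full checkpoint at every step,
# whether or not the eval metric improved.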

This can add up to a seriously large amount of wear on your SSD as shown below:
[screenshot: SSD write statistics]

I have written terabytes of data to my disk in a month because I need to checkpoint a large model at excessively short intervals to grab the most accurate checkpoint.

I would like to request a new flag, something like only_save_if_best = True, that saves a checkpoint only if it is the best one yet, instead of wasting disk writes.

Motivation

See above.

Your contribution

N/A

umarbutler added the Feature request label on Dec 4, 2024
@umarbutler
Contributor Author

After discovering how much wear was being done, I switched to writing my models to a drive that I had not been using previously. As you can see, 2 TB was written in a couple of days. The hard drive has not been used for anything else.
[screenshot: drive write statistics showing ~2 TB written]

@umarbutler
Contributor Author

umarbutler commented Dec 4, 2024

To make this change, you would add an only_save_best_model argument to TrainingArguments and change Trainer._save_checkpoint to this:

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # NOTE This stays above so we can set `self.state.best_model_checkpoint = output_dir`
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        
        # NOTE We moved this above the saving so that we can early exit if we don't want to save
        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            try:
                metric_value = metrics[metric_to_check]
            except KeyError as exc:
                raise KeyError(
                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
                    f"which is not found in the evaluation metrics. "
                    f"The available evaluation metrics are: {list(metrics.keys())}. "
                    f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
                    f"consider changing the `metric_for_best_model` via the TrainingArguments."
                ) from exc

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir
            
            # Exit if this isn't the best model
            elif self.args.load_best_model_at_end and self.args.save_total_limit == 1:
                return

        # Save model checkpoint
        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        self.save_model(output_dir, _internal_call=True)

        if not self.args.save_only_model:
            # Save optimizer and scheduler
            self._save_optimizer_and_scheduler(output_dir)
            # Save RNG state
            self._save_rng_state(output_dir)

        # Save the Trainer state
        if self.args.should_save:
            # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
            for cb in [
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]:
                cb_name = cb.__class__.__name__
                cb_state = cb.state()
                if isinstance(self.state.stateful_callbacks[cb_name], list):
                    self.state.stateful_callbacks[cb_name].append(cb_state)
                else:
                    self.state.stateful_callbacks[cb_name] = cb_state
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            # Solely rely on numerical checkpoint id for rotation.
            # mtime is not reliable especially on some fuse fs in cloud environments.
            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)

I haven't opened a PR because I'm not comfortable making design choices about the logical dependencies of only_save_best_model (e.g., load_best_model_at_end only works if other flags are set a certain way).

I've monkey-patched my build as a hotfix for now.

@Rocketknight1
Member

cc @muellerzr @SunMarc

@umarbutler
Contributor Author

Are there any updates on this?

@umarbutler
Contributor Author

This is my hotfix; it overrides Trainer._save_checkpoint().

import os

import numpy as np
import transformers.trainer
from transformers import Trainer

# Monkey patch `Trainer._save_checkpoint()` to only save a checkpoint when it is the best one yet,
# provided `self.args.load_best_model_at_end and self.args.save_total_limit == 1`.
def _save_checkpoint(self: Trainer, model, trial, metrics=None):
    # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
    # want to save except FullyShardedDDP.
    # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

    # NOTE This stays above so we can set `self.state.best_model_checkpoint = output_dir`
    checkpoint_folder = f"{transformers.trainer.PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(run_dir, checkpoint_folder)
    
    # NOTE We moved this above the saving so that we can early exit if we don't want to save
    # Determine the new best metric / best model checkpoint
    if metrics is not None and self.args.metric_for_best_model is not None:
        metric_to_check = self.args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        try:
            metric_value = metrics[metric_to_check]
        except KeyError as exc:
            raise KeyError(
                f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
                f"which is not found in the evaluation metrics. "
                f"The available evaluation metrics are: {list(metrics.keys())}. "
                f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
                f"consider changing the `metric_for_best_model` via the TrainingArguments."
            ) from exc

        operator = np.greater if self.args.greater_is_better else np.less
        if (
            self.state.best_metric is None
            or self.state.best_model_checkpoint is None
            or operator(metric_value, self.state.best_metric)
        ):
            self.state.best_metric = metric_value
            self.state.best_model_checkpoint = output_dir
        
        # Exit if this isn't the best model
        elif self.args.load_best_model_at_end and self.args.save_total_limit == 1:
            return

    # Save model checkpoint
    if self.hp_search_backend is None and trial is None:
        self.store_flos()

    self.save_model(output_dir, _internal_call=True)

    if not self.args.save_only_model:
        # Save optimizer and scheduler
        self._save_optimizer_and_scheduler(output_dir)
        # Save RNG state
        self._save_rng_state(output_dir)

    # Save the Trainer state
    if self.args.should_save:
        # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
        for cb in [
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, transformers.trainer.ExportableState)
        ]:
            cb_name = cb.__class__.__name__
            cb_state = cb.state()
            if isinstance(self.state.stateful_callbacks[cb_name], list):
                self.state.stateful_callbacks[cb_name].append(cb_state)
            else:
                self.state.stateful_callbacks[cb_name] = cb_state
        self.state.save_to_json(os.path.join(output_dir, transformers.trainer.TRAINER_STATE_NAME))

    if self.args.push_to_hub:
        self._push_from_checkpoint(output_dir)

    # Maybe delete some older checkpoints.
    if self.args.should_save:
        # Solely rely on numerical checkpoint id for rotation.
        # mtime is not reliable especially on some fuse fs in cloud environments.
        self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)

Trainer._save_checkpoint = _save_checkpoint

@umarbutler
Contributor Author

The above hotfix is broken on the main branch because _save_checkpoint() is no longer passed metrics.

However, it turns out that this problem can be fixed by setting save_strategy to best thanks to #31817!
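
For anyone finding this later, here is a minimal sketch of that setup (argument values other than save_strategy = 'best' are placeholders, and it requires a transformers release that includes #31817):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                # placeholder
    eval_strategy="steps",
    eval_steps=1,
    save_strategy="best",            # only write a checkpoint when the tracked metric improves
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="loss",    # "best" needs a metric to compare against
    greater_is_better=False,
)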
