Fix _save_checkpoint for online methods #2288

Merged · 5 commits · Oct 31, 2024

Changes from 4 commits
.github/workflows/tests-main.yml (2 changes: 1 addition & 1 deletion)

@@ -30,7 +30,7 @@ jobs:
python -m pip install --upgrade pip
# install PEFT & transformers from source
pip install -U git+https://github.com/huggingface/peft.git
- pip install -U git+https://github.com/huggingface/transformers.git
+ pip install -U git+https://github.com/huggingface/transformers.git@f339042b0b8bdc0b57a70d37f67cafbea960a2ab
Member Author:

There is another recent breaking change. I pinned this commit so that the CI can run anyway; I will fix the new breaking change in a follow-up PR.

Collaborator:

Not for this PR, but do you think we should add a CI workflow that tests against our lowest supported version of transformers? That might have caught the issue with the examples being broken on main when processing_class was added.

Member Author:

Good point. For reference, #2298 aims to make the tests clearer and cleaner. I'll extend that work with another PR that implements your suggestion.

Member Author:

pip doesn't yet have an option to prefer the minimum version (pypa/pip#8085), so let's do it manually.
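For illustration only, here is one way "doing it manually" could look: read the lower bound that trl declares for transformers from the installed package metadata, so a CI step could then `pip install "transformers==<that version>"`. This sketch is not part of the PR; the helper name is hypothetical, and it assumes trl is installed and declares a ">=" (or "==") bound on transformers.

```python
# Sketch (hypothetical helper, not from this PR): resolve the minimum
# transformers version declared in trl's installed metadata.
from importlib.metadata import requires
from typing import Optional

from packaging.requirements import Requirement
from packaging.version import Version


def min_transformers_version(dist: str = "trl") -> Optional[str]:
    # requires() returns the raw requirement strings from the package metadata.
    for raw in requires(dist) or []:
        req = Requirement(raw)
        if req.name != "transformers":
            continue
        # Keep only lower-bound specifiers such as ">=4.46.0" (or exact pins).
        bounds = [spec.version for spec in req.specifier if spec.operator in (">=", "==")]
        if bounds:
            return min(bounds, key=Version)
    return None


if __name__ == "__main__":
    # e.g. feed the result into: pip install "transformers==<version>"
    print(min_transformers_version() or "no lower bound found")
```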

Member Author:

> There is another recent breaking change. I pinned this commit so that the CI can run anyway; I will fix the new breaking change in a follow-up PR.

This other breaking change is solved here: #2302.

Member Author (@qgallouedec, Oct 31, 2024):

> add a CI workflow that tests against our lowest supported version of transformers

Done in #2303.

# cpu version of pytorch
pip install ".[test, diffusers]"
- name: Test with pytest
.github/workflows/tests.yml (4 changes: 2 additions & 2 deletions)

@@ -58,7 +58,7 @@ jobs:
python -m pip install --upgrade pip
# install PEFT & transformers from source
pip install -U git+https://github.com/huggingface/peft.git
- pip install -U git+https://github.com/huggingface/transformers.git
+ pip install -U git+https://github.com/huggingface/transformers.git@f339042b0b8bdc0b57a70d37f67cafbea960a2ab
# cpu version of pytorch
pip install ".[test, diffusers]"
- name: Test with pytest
@@ -82,7 +82,7 @@
run: |
python -m pip install --upgrade pip
# install transformers from source
- pip install -U git+https://github.com/huggingface/transformers.git
+ pip install -U git+https://github.com/huggingface/transformers.git@f339042b0b8bdc0b57a70d37f67cafbea960a2ab
# cpu version of pytorch
pip install .[test]
- name: Test with pytest
trl/trainer/online_dpo_trainer.py (51 changes: 49 additions & 2 deletions)

@@ -20,6 +20,7 @@

import datasets
import jinja2
+ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -40,7 +41,7 @@
is_apex_available,
is_wandb_available,
)
- from transformers.trainer_utils import EvalPrediction, seed_worker
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging

@@ -614,11 +615,57 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
        metrics = None
        if self.control.should_evaluate:
            metrics = self._evaluate(trial, ignore_keys_for_eval)
+           is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+           if self.args.save_strategy == "best":
+               self.control.should_save = is_new_best_metric

        if self.control.should_save:
-           self._save_checkpoint(model, trial, metrics=metrics)
+           self._save_checkpoint(model, trial)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
    # This can be removed once the minimum transformers version is updated to 4.47.
    # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
    def _determine_best_metric(self, metrics, trial):
        """
        Determine if the model should be saved based on the evaluation metrics.
        If args.metric_for_best_model is not set, the loss is used.

        Returns:
            bool: True if a new best metric was found, else False
        """
        is_new_best_metric = False

        if self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model

            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"

            try:
                metric_value = metrics[metric_to_check]
            except KeyError as exc:
                raise KeyError(
                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
                ) from exc

            operator = np.greater if self.args.greater_is_better else np.less

            if self.state.best_metric is None:
                self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")

            if operator(metric_value, self.state.best_metric):
                run_dir = self._get_output_dir(trial=trial)
                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
                output_dir = os.path.join(run_dir, checkpoint_folder)
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

                is_new_best_metric = True

        return is_new_best_metric

    def create_model_card(
        self,
        model_name: Optional[str] = None,
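As a side note, the copied method above is only needed until the minimum supported transformers version reaches 4.47, so the shim could in principle be gated on the installed version. The snippet below is a minimal sketch of that idea, not what this PR does; the 4.47 cutoff and the flag name are taken from, respectively, the comment in the diff above and my own illustration.

```python
# Sketch only: decide whether the copy-pasted _determine_best_metric shim is still
# needed. Assumes (per the comment in the diff above) that transformers >= 4.47
# ships Trainer._determine_best_metric natively. NEEDS_BEST_METRIC_SHIM is a
# hypothetical name, not part of the PR.
from packaging import version

import transformers

NEEDS_BEST_METRIC_SHIM = version.parse(transformers.__version__) < version.parse("4.47.0")

if NEEDS_BEST_METRIC_SHIM:
    print("older transformers: keep the copied _determine_best_metric")
else:
    print("transformers >= 4.47: the compatibility shim can be removed")
```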
trl/trainer/ppo_trainer.py (2 changes: 1 addition & 1 deletion)

@@ -566,7 +566,7 @@ def repeat_generator():
            self.lr_scheduler.step()
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)
            if self.control.should_save:
-               self._save_checkpoint(model, trial=None, metrics=metrics)
+               self._save_checkpoint(model, trial=None)
                self.control = self.callback_handler.on_save(self.args, self.state, self.control)
            del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
            torch.cuda.empty_cache()
torch.cuda.empty_cache()
Expand Down
trl/trainer/rloo_trainer.py (2 changes: 1 addition & 1 deletion)

@@ -482,7 +482,7 @@ def repeat_generator():
            self.lr_scheduler.step()
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)
            if self.control.should_save:
-               self._save_checkpoint(model, trial=None, metrics=metrics)
+               self._save_checkpoint(model, trial=None)
                self.control = self.callback_handler.on_save(self.args, self.state, self.control)
            torch.cuda.empty_cache()
            gc.collect()
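For context on why `metrics=metrics` is dropped in these trainers: the updated transformers Trainer appears to move best-metric handling out of `_save_checkpoint` (into `_determine_best_metric`, copied above), so the newer `_save_checkpoint` signature no longer accepts a `metrics` argument. The toy snippet below is my own illustration of that failure mode, not code from transformers or from this PR.

```python
# Toy illustration: a stand-in with a _save_checkpoint-style signature that,
# like the newer Trainer method, no longer takes `metrics`.
def _save_checkpoint(model, trial):
    return f"checkpoint saved (trial={trial})"


try:
    # Old call pattern used by the online trainers before this PR.
    _save_checkpoint("model", trial=None, metrics={"eval_loss": 0.1})
except TypeError as err:
    print(f"old call fails: {err}")

# New call pattern after this PR.
print(_save_checkpoint("model", trial=None))
```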