
[integration] Update Ray Tune integration for Ray 2.7 #26499

Merged 14 commits on Dec 9, 2023
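At a glance, this PR moves the Hugging Face Trainer's Ray Tune backend off the legacy `ray.tune` reporting and checkpointing calls (`tune.report`, `tune.checkpoint_dir`, `tune.get_trial_id`) and onto the `ray.train` session API required as of Ray 2.7 (`ray.train.report`, `ray.train.get_checkpoint`, `ray.train.Checkpoint.from_directory`). A minimal sketch of the new pattern, using a hypothetical standalone trainable rather than the actual transformers code:

```python
# Hypothetical trainable illustrating the Ray >= 2.7 pattern this PR adopts;
# names like `trainable` and `step.txt` are illustrative only.
import os
import tempfile

import ray.train
from ray import tune


def trainable(config: dict):
    # On trial resume, Ray hands back the last reported checkpoint.
    start = 0
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "step.txt")) as f:
                start = int(f.read())

    for step in range(start, 5):
        metrics = {"objective": step * config["lr"]}
        # Replaces the pre-2.7 `tune.checkpoint_dir(...)` / `tune.report(...)`
        # pair: write a checkpoint to a temp dir, wrap it, report both together.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            with open(os.path.join(temp_checkpoint_dir, "step.txt"), "w") as f:
                f.write(str(step + 1))
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            ray.train.report(metrics, checkpoint=checkpoint)


tune.run(trainable, config={"lr": tune.grid_search([0.1, 0.01])})
```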
2 changes: 1 addition & 1 deletion setup.py
@@ -149,7 +149,7 @@
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
"ray[tune]",
"ray[tune]>=2.7.0",
"regex!=2019.12.17",
"requests",
"rhoknp>=1.1.0,<1.3.1",
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -55,7 +55,7 @@
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
"ray[tune]": "ray[tune]",
"ray[tune]": "ray[tune]>=2.7.0",
"regex": "regex!=2019.12.17",
"requests": "requests",
"rhoknp": "rhoknp>=1.1.0,<1.3.1",
4 changes: 2 additions & 2 deletions src/transformers/hyperparameter_search.py
@@ -15,7 +15,7 @@

from .integrations import (
is_optuna_available,
-is_ray_available,
+is_ray_tune_available,
is_sigopt_available,
is_wandb_available,
run_hp_search_optuna,
@@ -81,7 +81,7 @@ class RayTuneBackend(HyperParamSearchBackendBase):

@staticmethod
def is_available():
-return is_ray_available()
+return is_ray_tune_available()

def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
54 changes: 26 additions & 28 deletions src/transformers/integrations/integration_utils.py
@@ -236,8 +236,9 @@ def _objective(trial, checkpoint_dir=None):

def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
import ray
+import ray.train

-def _objective(trial, local_trainer, checkpoint_dir=None):
+def _objective(trial: dict, local_trainer):
try:
from transformers.utils.notebook import NotebookProgressCallback

@@ -246,19 +247,34 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
except ModuleNotFoundError:
pass

-checkpoint = None
-if checkpoint_dir:
-for subdir in os.listdir(checkpoint_dir):
-if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-checkpoint = os.path.join(checkpoint_dir, subdir)
local_trainer.objective = None
-local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)

+checkpoint = ray.train.get_checkpoint()
+if checkpoint:
+# Upon trial resume, the local_trainer's objective gets reset to None.
+# If `local_trainer.train` is a noop (training has already reached
+# the target number of epochs/steps), then this would
+# trigger an unnecessary extra checkpoint at the end of training.
+# -> Set the objective to a dummy value upon resume as a workaround.
+local_trainer.objective = "objective"
+
+with checkpoint.as_directory() as checkpoint_dir:
+checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
+local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
+else:
+local_trainer.train(trial=trial)

# If there hasn't been any evaluation during the training loop.
if getattr(local_trainer, "objective", None) is None:
metrics = local_trainer.evaluate()
local_trainer.objective = local_trainer.compute_objective(metrics)
-local_trainer._tune_save_checkpoint()
-ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

+metrics.update({"objective": local_trainer.objective, "done": True})

+with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+ray.train.report(metrics, checkpoint=checkpoint)

if not trainer._memory_tracker.skip_memory_metrics:
from ..trainer_utils import TrainerMemoryTracker
@@ -296,28 +312,10 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
from ray.tune import CLIReporter

kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
# `keep_checkpoints_num=0` would disabled checkpointing
trainer.use_tune_checkpoints = True
if kwargs["keep_checkpoints_num"] > 1:
logger.warning(
f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
"Checkpoints are usually huge, "
"consider setting `keep_checkpoints_num=1`."
)
Comment on lines -299 to -307 (Contributor Author):

Previously, you needed to set a # of checkpoints to keep in order to do checkpointing at all -- now this is just an optional arg.
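A hedged sketch of what opting into checkpoint retention might look like under Ray >= 2.7: `keep_checkpoints_num` is superseded by `ray.train.CheckpointConfig`, which `Trainer.hyperparameter_search` would forward through `**kwargs` to `ray.tune.run`. The `trainer` object and the exact kwarg plumbing are assumptions for illustration, not part of this diff.

```python
import ray.train

# Assumption: `trainer` is a transformers.Trainer configured with model_init,
# and `checkpoint_config` is forwarded untouched to ray.tune.run by the backend.
best_run = trainer.hyperparameter_search(
    backend="ray",
    n_trials=4,
    # Ray >= 2.7: keep only the latest checkpoint per trial (was
    # `keep_checkpoints_num=1` in older Ray versions).
    checkpoint_config=ray.train.CheckpointConfig(num_to_keep=1),
)
```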


if "scheduler" in kwargs:
from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

-# Check if checkpointing is enabled for PopulationBasedTraining
-if isinstance(kwargs["scheduler"], PopulationBasedTraining):
-if not trainer.use_tune_checkpoints:
-logger.warning(
-"You are using PopulationBasedTraining but you haven't enabled checkpointing. "
-"This means your trials will train from scratch everytime they are exploiting "
-"new configurations. Consider enabling checkpointing by passing "
-"`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
-)

# Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
if isinstance(
kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
44 changes: 22 additions & 22 deletions src/transformers/trainer.py
@@ -28,6 +28,7 @@
import re
import shutil
import sys
+import tempfile
import time
import warnings
from collections.abc import Mapping
@@ -595,7 +596,6 @@ def __init__(
# returned to 0 every time flos need to be logged
self.current_flos = 0
self.hp_search_backend = None
-self.use_tune_checkpoints = False
default_label_names = find_labels(self.model.__class__)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.can_return_loss = can_return_loss(self.model.__class__)
@@ -1201,7 +1201,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
if self.hp_search_backend is None or trial is None:
return
-self.objective = self.compute_objective(metrics.copy())
+metrics = metrics.copy()
+self.objective = self.compute_objective(metrics)
Comment on lines +1204 to +1205 (Collaborator):

Not sure I understand this change.

Reply (Contributor Author):

Oh, this is because the metrics get modified later with some extra keys, so I want to keep the original one intact: https://github.com/huggingface/transformers/pull/26499/files/9fbd0ec915455293eb4a1730e6a473d5c11b1151#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR1199
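A minimal illustration of the aliasing issue the copy avoids (the values are made up):

```python
# Without a copy, the caller's dict and the reported dict are the same object.
metrics = {"eval_loss": 0.25}

reported = metrics  # no copy: both names point at one dict
reported["objective"] = 0.25  # ...so the mutation leaks into `metrics`
assert "objective" in metrics

metrics = {"eval_loss": 0.25}
reported = metrics.copy()  # shallow copy: later mutation stays local
reported["objective"] = 0.25
assert "objective" not in metrics
```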

if self.hp_search_backend == HPSearchBackend.OPTUNA:
import optuna

@@ -1211,24 +1212,23 @@ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
self.callback_handler.on_train_end(self.args, self.state, self.control)
raise optuna.TrialPruned()
elif self.hp_search_backend == HPSearchBackend.RAY:
-from ray import tune

-if self.control.should_save:
-self._tune_save_checkpoint()
-tune.report(objective=self.objective, **metrics)

-def _tune_save_checkpoint(self):
-from ray import tune

-if not self.use_tune_checkpoints:
-return
-with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
-output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
-self.save_model(output_dir, _internal_call=True)
-if self.args.should_save:
-self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
-torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
-torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+import ray.train

+with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+checkpoint = None
+if self.control.should_save:
+self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+metrics["objective"] = self.objective
+ray.train.report(metrics, checkpoint=checkpoint)

+def _tune_save_checkpoint(self, checkpoint_dir: str):
+output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+self.save_model(output_dir, _internal_call=True)
+if self.args.should_save:
+self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

def call_model_init(self, trial=None):
model_init_argcount = number_of_arguments(self.model_init)
@@ -1997,9 +1997,9 @@ def _get_output_dir(self, trial):
if self.hp_search_backend == HPSearchBackend.OPTUNA:
run_id = trial.number
elif self.hp_search_backend == HPSearchBackend.RAY:
-from ray import tune
+import ray.train

-run_id = tune.get_trial_id()
+run_id = ray.train.get_context().get_trial_id()
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
run_id = trial.id
elif self.hp_search_backend == HPSearchBackend.WANDB:
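For reference, a hedged end-to-end sketch of driving the updated backend from user code; the model name, datasets, and training arguments below are placeholders, not part of this PR:

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


def model_init(trial):
    # A fresh model per trial; "bert-base-uncased" is a placeholder choice.
    return AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")


trainer = Trainer(
    args=TrainingArguments(output_dir="hp_search_out", evaluation_strategy="epoch"),
    train_dataset=train_dataset,  # assumed to be prepared elsewhere
    eval_dataset=eval_dataset,
    model_init=model_init,
)

# With ray[tune]>=2.7.0 installed, this dispatches to run_hp_search_ray,
# which now reports metrics and checkpoints via ray.train.report.
best_run = trainer.hyperparameter_search(backend="ray", n_trials=8, direction="maximize")
print(best_run.hyperparameters)
```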