[integration] Update Ray Tune integration for Ray 2.7 #26499
@@ -28,6 +28,7 @@
import re
import shutil
import sys
+import tempfile
import time
import warnings
from collections.abc import Mapping
@@ -595,7 +596,6 @@ def __init__(
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
-        self.use_tune_checkpoints = False
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.can_return_loss = can_return_loss(self.model.__class__)
@@ -1201,7 +1201,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
        if self.hp_search_backend is None or trial is None:
            return
-        self.objective = self.compute_objective(metrics.copy())
+        metrics = metrics.copy()
+        self.objective = self.compute_objective(metrics)
Comment on lines +1204 to +1205

Not sure I understand this change.

Oh, this is because the metrics get modified later with some extra keys, so I want to keep the original one intact: https://github.com/huggingface/transformers/pull/26499/files/9fbd0ec915455293eb4a1730e6a473d5c11b1151#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR1199
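To illustrate the point in that thread, here is a minimal standalone sketch (not code from this PR; `report_for_ray` and its arguments are made up for illustration) of why copying `metrics` up front keeps the caller's dict intact when an "objective" key is added later:

```python
# Hypothetical stand-in for the reporting path; only the copy-then-mutate
# pattern mirrors the PR, everything else is made up for illustration.
def report_for_ray(metrics: dict, objective: float) -> None:
    metrics = metrics.copy()          # same defensive copy as in _report_to_hp_search
    metrics["objective"] = objective  # extra key added further down the flow
    print("reported:", metrics)


eval_metrics = {"eval_loss": 0.42}
report_for_ray(eval_metrics, objective=0.9)
print(eval_metrics)  # {'eval_loss': 0.42} -- the caller's dict was not mutated
```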
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            import optuna
@@ -1211,24 +1212,23 @@ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
                self.callback_handler.on_train_end(self.args, self.state, self.control)
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
-
-            if self.control.should_save:
-                self._tune_save_checkpoint()
-            tune.report(objective=self.objective, **metrics)
-
-    def _tune_save_checkpoint(self):
-        from ray import tune
-
-        if not self.use_tune_checkpoints:
-            return
-        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
-            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
-            self.save_model(output_dir, _internal_call=True)
-            if self.args.should_save:
-                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
-                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
-                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+            import ray.train
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                checkpoint = None
+                if self.control.should_save:
+                    self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                    checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                metrics["objective"] = self.objective
+                ray.train.report(metrics, checkpoint=checkpoint)
+
+    def _tune_save_checkpoint(self, checkpoint_dir: str):
+        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+        self.save_model(output_dir, _internal_call=True)
+        if self.args.should_save:
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

    def call_model_init(self, trial=None):
        model_init_argcount = number_of_arguments(self.model_init)
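For context, a hedged standalone sketch of the Ray 2.7 reporting pattern this hunk switches to: save state into a temporary directory, wrap it in a `Checkpoint`, and report it alongside the metrics. This is a generic Tune trainable, not the Trainer code, and the file written into the temp directory is just a placeholder.

```python
import os
import tempfile

import ray.train
from ray import tune


def trainable(config):
    for step in range(3):
        metrics = {"objective": float(step) * config["lr"]}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save whatever state you need inside the temp dir ...
            with open(os.path.join(temp_checkpoint_dir, "state.txt"), "w") as f:
                f.write(str(step))
            # ... then hand Ray a Checkpoint built from that directory and
            # report it together with the metrics for this step.
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            ray.train.report(metrics, checkpoint=checkpoint)


tuner = tune.Tuner(trainable, param_space={"lr": tune.loguniform(1e-4, 1e-1)})
results = tuner.fit()
```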
@@ -1997,9 +1997,9 @@ def _get_output_dir(self, trial):
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            run_id = trial.number
        elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
+            import ray.train

-            run_id = tune.get_trial_id()
+            run_id = ray.train.get_context().get_trial_id()
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            run_id = trial.id
        elif self.hp_search_backend == HPSearchBackend.WANDB:
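Likewise, a small hedged sketch of the new trial-id lookup, shown as a bare Tune trainable rather than the Trainer method:

```python
import ray.train
from ray import tune


def trainable(config):
    # Replaces the old tune.get_trial_id(): the id now comes off the context.
    run_id = ray.train.get_context().get_trial_id()
    ray.train.report({"run_id_length": len(run_id)})


tune.Tuner(trainable).fit()
```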
Previously, you needed to set a number of checkpoints to keep in order to do checkpointing at all -- now this is just an optional arg.
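A hedged sketch of what that looks like on the Ray 2.7 side (generic Tune code under the same assumptions as the sketches above, not this PR's integration code): checkpoints are always reported, and the number to keep is just an optional retention setting on the run configuration.

```python
import os
import tempfile

import ray.train
from ray import tune
from ray.train import CheckpointConfig, RunConfig


def trainable(config):
    with tempfile.TemporaryDirectory() as tmp:
        # Placeholder checkpoint contents; a real trainable would save model state here.
        open(os.path.join(tmp, "weights.bin"), "wb").close()
        ray.train.report(
            {"objective": 1.0},
            checkpoint=ray.train.Checkpoint.from_directory(tmp),
        )


tuner = tune.Tuner(
    trainable,
    # Optional: keep only the two most recent checkpoints on disk.
    run_config=RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=2)),
)
results = tuner.fit()
```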