Fix hyperparameter search when optuna+deepspeed (#34642)
* Fix hyperparameter search when optuna+deepspeed

* Adding free_memory to the search setup

---------

Co-authored-by: Corentin-Royer <[email protected]>
corentin-ryr and Corentin-Royer authored Nov 20, 2024
1 parent 67890de commit bf42c3b
Showing 2 changed files with 14 additions and 14 deletions.
23 changes: 10 additions & 13 deletions src/transformers/integrations/integration_utils.py
@@ -208,7 +208,7 @@ def hp_params(trial):
     if is_optuna_available():
         import optuna
 
-        if isinstance(trial, optuna.Trial):
+        if isinstance(trial, optuna.trial.BaseTrial):
             return trial.params
     if is_ray_tune_available():
         if isinstance(trial, dict):
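
The widened check matters because, after this change, non-zero ranks receive an optuna.trial.FixedTrial, which subclasses BaseTrial but not Trial, so the old isinstance test would have rejected it. A minimal sketch of the distinction, assuming optuna is installed (the parameter name and value are illustrative):

import optuna
from optuna.trial import FixedTrial

# A FixedTrial replays a preset parameter dict without a backing study.
fixed = FixedTrial({"learning_rate": 3e-4})

print(isinstance(fixed, optuna.Trial))            # False: the old check would reject it
print(isinstance(fixed, optuna.trial.BaseTrial))  # True: the new check accepts it
print(fixed.suggest_float("learning_rate", 1e-5, 1e-1, log=True))  # returns the preset 3e-4
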
@@ -230,7 +230,7 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
 
     if trainer.args.process_index == 0:
 
-        def _objective(trial, checkpoint_dir=None):
+        def _objective(trial: optuna.Trial, checkpoint_dir=None):
             checkpoint = None
             if checkpoint_dir:
                 for subdir in os.listdir(checkpoint_dir):
@@ -240,10 +240,11 @@ def _objective(trial, checkpoint_dir=None):
             if trainer.args.world_size > 1:
                 if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                     raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-                trainer._hp_search_setup(trial)
-                args_main_rank_list = [pickle.dumps(trainer.args)]
-                torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-                trainer.train(resume_from_checkpoint=checkpoint)
+                trainer.hp_space(trial)
+                fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
+                trial_main_rank_list = [fixed_trial]
+                torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             else:
                 trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             # If there hasn't been any evaluation during the training loop.
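
Rather than pickling the whole TrainingArguments and patching them attribute by attribute on the other ranks, rank 0 now calls trainer.hp_space(trial) first (so the trial's parameters are actually suggested and land in trial.params), freezes them into a picklable optuna.trial.FixedTrial, and broadcasts that object; every rank then passes its trial object to trainer.train, which runs _hp_search_setup locally with identical hyperparameters. A standalone sketch of that handshake, assuming torch.distributed is already initialized (broadcast_trial is an illustrative helper, not part of the patch):

import torch.distributed as dist
from optuna.trial import FixedTrial

def broadcast_trial(trial=None):
    # Rank 0 wraps the live optuna trial; the other ranks contribute a
    # placeholder and receive the FixedTrial in place via the collective.
    if dist.get_rank() == 0:
        payload = [FixedTrial(trial.params, trial.number)]
    else:
        payload = [None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
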
@@ -268,15 +269,11 @@ def _objective(trial, checkpoint_dir=None):
     else:
         for i in range(n_trials):
             trainer.objective = None
-            args_main_rank_list = [None]
+            trial_main_rank_list = [None]
             if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-            torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-            args = pickle.loads(bytes(args_main_rank_list[0]))
-            for key, value in asdict(args).items():
-                if key != "local_rank":
-                    setattr(trainer.args, key, value)
-            trainer.train(resume_from_checkpoint=None)
+            torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+            trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
             # If there hasn't been any evaluation during the training loop.
             if getattr(trainer, "objective", None) is None:
                 metrics = trainer.evaluate()
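
The else branch is the receiving half of the same handshake: rank 0's study.optimize runs _objective exactly n_trials times, so each worker rank loops n_trials times as well, joining one broadcast_object_list per trial (the collectives must stay paired or the job deadlocks) and handing the received FixedTrial straight to trainer.train, instead of copying TrainingArguments field by field as before.
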
5 changes: 4 additions & 1 deletion src/transformers/trainer.py
@@ -1725,6 +1725,9 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
         if self.is_deepspeed_enabled:
             if self.args.deepspeed is None:
                 raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
+
+            self.accelerator.free_memory()
+
             # Rebuild the deepspeed config to reflect the updated training parameters
             from accelerate.utils import DeepSpeedPlugin
 
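
The new accelerator.free_memory() call drops the Accelerator's references to the previously prepared model, optimizer, and scheduler, runs the garbage collector, and clears cached GPU memory, so the next trial can rebuild its DeepSpeed engine without inheriting the previous trial's allocations. A minimal sketch of the call pattern, assuming accelerate is installed (the model and optimizer are placeholders):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

# ... one trial trains here ...

# Drop the prepared objects and clear cached GPU memory before the next trial
# rebuilds its engine with new hyperparameters.
accelerator.free_memory()
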
@@ -1748,7 +1751,7 @@ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], ste
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna
 
-            if not trial.study._is_multi_objective():
+            if hasattr(trial, "study") and not trial.study._is_multi_objective():
                 trial.report(self.objective, step)
                 if trial.should_prune():
                     self.callback_handler.on_train_end(self.args, self.state, self.control)
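
This guard is what keeps the worker ranks alive: the FixedTrial they receive carries no study, so intermediate reporting and pruning only happen on rank 0, where the real optuna.Trial lives. A small sketch of the difference, assuming optuna is installed (the parameter value is illustrative):

from optuna.trial import FixedTrial

fixed = FixedTrial({"learning_rate": 3e-4}, number=0)

print(hasattr(fixed, "study"))  # False: the old code would raise AttributeError here
print(fixed.should_prune())     # False: a FixedTrial is never pruned
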
