diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index 4b236b9155f158..0cc2685a55206f 100755
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -208,7 +208,7 @@ def hp_params(trial):
     if is_optuna_available():
         import optuna
 
-        if isinstance(trial, optuna.Trial):
+        if isinstance(trial, optuna.trial.BaseTrial):
             return trial.params
     if is_ray_tune_available():
         if isinstance(trial, dict):
@@ -230,7 +230,7 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
 
     if trainer.args.process_index == 0:
 
-        def _objective(trial, checkpoint_dir=None):
+        def _objective(trial: optuna.Trial, checkpoint_dir=None):
             checkpoint = None
             if checkpoint_dir:
                 for subdir in os.listdir(checkpoint_dir):
@@ -240,10 +240,11 @@ def _objective(trial, checkpoint_dir=None):
             if trainer.args.world_size > 1:
                 if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                     raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-                trainer._hp_search_setup(trial)
-                args_main_rank_list = [pickle.dumps(trainer.args)]
-                torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-                trainer.train(resume_from_checkpoint=checkpoint)
+                trainer.hp_space(trial)
+                fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
+                trial_main_rank_list = [fixed_trial]
+                torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             else:
                 trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             # If there hasn't been any evaluation during the training loop.
@@ -268,15 +269,11 @@ def _objective(trial, checkpoint_dir=None):
     else:
         for i in range(n_trials):
             trainer.objective = None
-            args_main_rank_list = [None]
+            trial_main_rank_list = [None]
             if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-            torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-            args = pickle.loads(bytes(args_main_rank_list[0]))
-            for key, value in asdict(args).items():
-                if key != "local_rank":
-                    setattr(trainer.args, key, value)
-            trainer.train(resume_from_checkpoint=None)
+            torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+            trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
             # If there hasn't been any evaluation during the training loop.
             if getattr(trainer, "objective", None) is None:
                 metrics = trainer.evaluate()
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 129398e374be73..f2e0a90acddd16 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1725,6 +1725,9 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
         if self.is_deepspeed_enabled:
             if self.args.deepspeed is None:
                 raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
+
+            self.accelerator.free_memory()
+
             # Rebuild the deepspeed config to reflect the updated training parameters
             from accelerate.utils import DeepSpeedPlugin
 
@@ -1748,7 +1751,7 @@ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], ste
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna
 
-            if not trial.study._is_multi_objective():
+            if hasattr(trial, "study") and not trial.study._is_multi_objective():
                 trial.report(self.objective, step)
                 if trial.should_prune():
                     self.callback_handler.on_train_end(self.args, self.state, self.control)
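
A minimal usage sketch, not part of the patch, of the DDP Optuna HPO path the diff above targets, assuming a script launched with torchrun (e.g. `torchrun --nproc_per_node=2 run_hpo.py`); the script name, model checkpoint, and dataset variables are illustrative placeholders. Rank 0 drives the Optuna study and, per the diff, broadcasts each trial's parameters to the other ranks as a FixedTrial before training.

# Hypothetical sketch, not part of this diff: Trainer-based Optuna HPO under DDP.
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


def model_init(trial):
    # A fresh model is instantiated for every trial.
    return AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")


def hp_space(trial):
    # Search space; sampled on rank 0 and replayed on the other ranks via the broadcast trial.
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32]),
    }


training_args = TrainingArguments(output_dir="hpo_out", evaluation_strategy="epoch")
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=train_dataset,  # assumed to be defined elsewhere
    eval_dataset=eval_dataset,  # assumed to be defined elsewhere
)

best_run = trainer.hyperparameter_search(
    hp_space=hp_space,
    n_trials=4,
    direction="minimize",
    backend="optuna",
)

As the diff shows, the non-main ranks now receive the whole trial object rather than pickled TrainingArguments, so every rank applies the trial's hyperparameters through the same `trainer.train(..., trial=...)` path as rank 0.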