From 36f1e21f6b764eb361e0cddf77974f1975435b24 Mon Sep 17 00:00:00 2001
From: Lam Hieu
Date: Sun, 28 Jan 2024 15:16:28 +0700
Subject: [PATCH 1/5] fix: resolve deepspeed resume peft model issues

---
 src/transformers/integrations/deepspeed.py | 108 ++++++++++++++++-----
 src/transformers/trainer.py                |   8 +-
 2 files changed, 91 insertions(+), 25 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 92cc1a4b0e5947..6efc6f9cbeb2bd 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -125,7 +125,9 @@ def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
 
         ds_val = config.get(ds_key)
         if ds_val is not None and ds_val != hf_val:
-            self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")
+            self.mismatches.append(
+                f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}"
+            )
 
     fill_only = partialmethod(fill_match, must_match=False)
 
@@ -136,26 +138,45 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
         """
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
-        train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
+        train_batch_size = (
+            args.world_size
+            * args.per_device_train_batch_size
+            * args.gradient_accumulation_steps
+        )
         self.fill_match(
             "train_micro_batch_size_per_gpu",
             args.per_device_train_batch_size,
             "per_device_train_batch_size",
             not auto_find_batch_size,
         )
-        self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps")
         self.fill_match(
-            "train_batch_size", train_batch_size, "train_batch_size (calculated)", not auto_find_batch_size
+            "gradient_accumulation_steps",
+            args.gradient_accumulation_steps,
+            "gradient_accumulation_steps",
+        )
+        self.fill_match(
+            "train_batch_size",
+            train_batch_size,
+            "train_batch_size (calculated)",
+            not auto_find_batch_size,
         )
         self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm")
 
         self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate")
-        self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2")
+        self.fill_match(
+            "optimizer.params.betas",
+            [args.adam_beta1, args.adam_beta2],
+            "adam_beta1+adam_beta2",
+        )
         self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
-        self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")
+        self.fill_match(
+            "optimizer.params.weight_decay", args.weight_decay, "weight_decay"
+        )
 
         self.fill_only("scheduler.params.warmup_min_lr", 0)  # not a trainer arg
-        self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
+        self.fill_match(
+            "scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate"
+        )
         # total_num_steps - will get set in trainer_config_finalize
 
         # fp16
@@ -179,10 +200,14 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
 
         # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
         # ZeRO features
-        self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
+        self.fill_match(
+            "amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)"
+        )
         self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")
 
-        self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval")
+        self.fill_match(
+            "bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval"
+        )
 
         # deepspeed's default mode is fp16 unless there is a config that says differently
         if self.is_true("bf16.enabled"):
@@ -222,15 +247,31 @@ def trainer_config_finalize(self, args, model, num_training_steps):
                     "`auto` values for these keys with an integer value of your choice."
                 )
 
-            self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
+            self.fill_only(
+                "zero_optimization.reduce_bucket_size", hidden_size * hidden_size
+            )
             if self.is_zero3():
                 # automatically assign the optimal config values based on model config
-                self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
-                self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size)
+                self.fill_only(
+                    "zero_optimization.stage3_prefetch_bucket_size",
+                    0.9 * hidden_size * hidden_size,
+                )
+                self.fill_only(
+                    "zero_optimization.stage3_param_persistence_threshold",
+                    10 * hidden_size,
+                )
 
         # scheduler
-        self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)")
-        self.fill_match("scheduler.params.warmup_num_steps", args.get_warmup_steps(num_training_steps), "warmup_steps")
+        self.fill_match(
+            "scheduler.params.total_num_steps",
+            num_training_steps,
+            "num_training_steps (calculated)",
+        )
+        self.fill_match(
+            "scheduler.params.warmup_num_steps",
+            args.get_warmup_steps(num_training_steps),
+            "warmup_steps",
+        )
 
         if len(self.mismatches) > 0:
             mismatches = "\n".join(self.mismatches)
@@ -259,20 +300,28 @@ def unset_hf_deepspeed_config():
 
 
 def is_deepspeed_zero3_enabled():
-    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
+    if (
+        _hf_deepspeed_config_weak_ref is not None
+        and _hf_deepspeed_config_weak_ref() is not None
+    ):
         return _hf_deepspeed_config_weak_ref().is_zero3()
     else:
         return False
 
 
 def deepspeed_config():
-    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
+    if (
+        _hf_deepspeed_config_weak_ref is not None
+        and _hf_deepspeed_config_weak_ref() is not None
+    ):
         return _hf_deepspeed_config_weak_ref().config
     else:
         return None
 
 
-def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
+def deepspeed_optim_sched(
+    trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
+):
     """
     A convenience wrapper that deals with optimizer and lr scheduler configuration.
     """
@@ -323,9 +372,13 @@ def _lr_scheduler_callable(optimizer):
                 num_training_steps=num_training_steps,
             )
 
-            lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
+            lr_scheduler = DummyScheduler(
+                optimizer, lr_scheduler_callable=_lr_scheduler_callable
+            )
         else:
-            lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+            lr_scheduler = trainer.create_scheduler(
+                num_training_steps=num_training_steps, optimizer=optimizer
+            )
 
     return optimizer, lr_scheduler
 
@@ -367,7 +420,9 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     if inference:
         # only Z3 makes sense for the inference
        if not hf_deepspeed_config.is_zero3():
-            raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config")
+            raise ValueError(
+                "ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config"
+            )
 
         # in case the training config is re-used for inference
         hf_deepspeed_config.del_config_sub_tree("optimizer")
@@ -387,7 +442,9 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     return optimizer, lr_scheduler
 
 
-def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
+def deepspeed_load_checkpoint(
+    deepspeed_engine, checkpoint_path, load_module_strict = True
+):
     # it's possible that the user is trying to resume from model_path, which doesn't necessarily
     # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
     # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
@@ -400,9 +457,14 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
         logger.info(f"Attempting to resume from {checkpoint_path}")
         # this magically updates self.optimizer and self.lr_scheduler
         load_path, _ = deepspeed_engine.load_checkpoint(
-            checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True
+            checkpoint_path,
+            load_module_strict=load_module_strict,
+            load_optimizer_states=True,
+            load_lr_scheduler_states=True,
         )
         if load_path is None:
-            raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
+            raise ValueError(
+                f"[deepspeed] failed to resume from checkpoint {checkpoint_path}"
+            )
     else:
         raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 93113c64d6c1e8..ced74b5988fcec 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1719,7 +1719,9 @@ def _inner_training_loop(
         # ckpt loading
         if resume_from_checkpoint is not None:
             if self.is_deepspeed_enabled:
-                deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
+                deepspeed_load_checkpoint(
+                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(model)
+                )
             elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                 self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
 
@@ -2179,7 +2181,9 @@ def _load_best_model(self):
 
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.is_deepspeed_enabled:
-            deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
+            deepspeed_load_checkpoint(
+                self.model_wrapped, self.state.best_model_checkpoint, load_module_strict=not _is_peft_model(model),
+            )
         elif self.is_fsdp_enabled:
             load_result = load_fsdp_model(
                 self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint

From 4db8f9060d56863c685f1a446fd69f2879067450 Mon Sep 17 00:00:00 2001
From: Lam Hieu
Date: Sun, 28 Jan 2024 16:04:03 +0700
Subject: [PATCH 2/5] chore: update something

---
 src/transformers/integrations/deepspeed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 6efc6f9cbeb2bd..e513de214e2f3b 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -458,7 +458,7 @@ def deepspeed_load_checkpoint(
         # this magically updates self.optimizer and self.lr_scheduler
         load_path, _ = deepspeed_engine.load_checkpoint(
             checkpoint_path,
-            load_module_strict=load_module_strict,
+            load_module_strict=False,
             load_optimizer_states=True,
             load_lr_scheduler_states=True,
         )

From 00905fcec30bd3a14724e819f5895ff72f305e72 Mon Sep 17 00:00:00 2001
From: Lam Hieu
Date: Sun, 28 Jan 2024 16:12:16 +0700
Subject: [PATCH 3/5] chore: update model instance pass into is peft model checks

---
 src/transformers/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ced74b5988fcec..28d1f7b84b6c1c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1720,7 +1720,7 @@ def _inner_training_loop(
         if resume_from_checkpoint is not None:
             if self.is_deepspeed_enabled:
                 deepspeed_load_checkpoint(
-                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(model)
+                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
                 )
             elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                 self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
@@ -2182,7 +2182,7 @@ def _load_best_model(self):
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.is_deepspeed_enabled:
             deepspeed_load_checkpoint(
-                self.model_wrapped, self.state.best_model_checkpoint, load_module_strict=not _is_peft_model(model),
+                self.model_wrapped, self.state.best_model_checkpoint, load_module_strict=not _is_peft_model(self.model),
             )
         elif self.is_fsdp_enabled:
             load_result = load_fsdp_model(

From 0aa3bc9983897cf1f3d35ba760babed1786b51c2 Mon Sep 17 00:00:00 2001
From: Lam Hieu
Date: Sun, 28 Jan 2024 16:12:45 +0700
Subject: [PATCH 4/5] chore: remove hard code value to tests

---
 src/transformers/integrations/deepspeed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index e513de214e2f3b..6efc6f9cbeb2bd 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -458,7 +458,7 @@ def deepspeed_load_checkpoint(
         # this magically updates self.optimizer and self.lr_scheduler
         load_path, _ = deepspeed_engine.load_checkpoint(
             checkpoint_path,
-            load_module_strict=False,
+            load_module_strict=load_module_strict,
             load_optimizer_states=True,
             load_lr_scheduler_states=True,
         )

From 8d8e7a1bc1c17c25f3ded8e33ccb0f3b579df3d7 Mon Sep 17 00:00:00 2001
From: Lam Hieu
Date: Sun, 28 Jan 2024 18:40:59 +0700
Subject: [PATCH 5/5] fix: format code

---
 src/transformers/integrations/deepspeed.py | 64 +++++----------------
 src/transformers/trainer.py                |  4 +-
 2 files changed, 18 insertions(+), 50 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 6efc6f9cbeb2bd..07d0a5b5e37a57 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -125,9 +125,7 @@ def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
 
         ds_val = config.get(ds_key)
         if ds_val is not None and ds_val != hf_val:
-            self.mismatches.append(
-                f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}"
-            )
+            self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")
 
     fill_only = partialmethod(fill_match, must_match=False)
 
@@ -138,11 +136,7 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
         """
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
-        train_batch_size = (
-            args.world_size
-            * args.per_device_train_batch_size
-            * args.gradient_accumulation_steps
-        )
+        train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
         self.fill_match(
             "train_micro_batch_size_per_gpu",
             args.per_device_train_batch_size,
@@ -169,14 +163,10 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
             "adam_beta1+adam_beta2",
         )
         self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
-        self.fill_match(
-            "optimizer.params.weight_decay", args.weight_decay, "weight_decay"
-        )
+        self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")
 
         self.fill_only("scheduler.params.warmup_min_lr", 0)  # not a trainer arg
-        self.fill_match(
-            "scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate"
-        )
+        self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
         # total_num_steps - will get set in trainer_config_finalize
 
         # fp16
@@ -200,14 +190,10 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
 
         # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
         # ZeRO features
-        self.fill_match(
-            "amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)"
-        )
+        self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
         self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")
 
-        self.fill_match(
-            "bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval"
-        )
+        self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval")
 
         # deepspeed's default mode is fp16 unless there is a config that says differently
         if self.is_true("bf16.enabled"):
@@ -247,9 +233,7 @@ def trainer_config_finalize(self, args, model, num_training_steps):
                     "`auto` values for these keys with an integer value of your choice."
                 )
 
-            self.fill_only(
-                "zero_optimization.reduce_bucket_size", hidden_size * hidden_size
-            )
+            self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
             if self.is_zero3():
                 # automatically assign the optimal config values based on model config
                 self.fill_only(
@@ -300,28 +284,20 @@ def unset_hf_deepspeed_config():
 
 
 def is_deepspeed_zero3_enabled():
-    if (
-        _hf_deepspeed_config_weak_ref is not None
-        and _hf_deepspeed_config_weak_ref() is not None
-    ):
+    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
         return _hf_deepspeed_config_weak_ref().is_zero3()
     else:
         return False
 
 
 def deepspeed_config():
-    if (
-        _hf_deepspeed_config_weak_ref is not None
-        and _hf_deepspeed_config_weak_ref() is not None
-    ):
+    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
         return _hf_deepspeed_config_weak_ref().config
     else:
         return None
 
 
-def deepspeed_optim_sched(
-    trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
-):
+def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
     """
     A convenience wrapper that deals with optimizer and lr scheduler configuration.
     """
@@ -372,13 +348,9 @@ def _lr_scheduler_callable(optimizer):
                 num_training_steps=num_training_steps,
             )
 
-            lr_scheduler = DummyScheduler(
-                optimizer, lr_scheduler_callable=_lr_scheduler_callable
-            )
+            lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
         else:
-            lr_scheduler = trainer.create_scheduler(
-                num_training_steps=num_training_steps, optimizer=optimizer
-            )
+            lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
 
     return optimizer, lr_scheduler
 
@@ -420,9 +392,7 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     if inference:
         # only Z3 makes sense for the inference
         if not hf_deepspeed_config.is_zero3():
-            raise ValueError(
-                "ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config"
-            )
+            raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config")
 
         # in case the training config is re-used for inference
         hf_deepspeed_config.del_config_sub_tree("optimizer")
@@ -442,9 +412,7 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     return optimizer, lr_scheduler
 
 
-def deepspeed_load_checkpoint(
-    deepspeed_engine, checkpoint_path, load_module_strict = True
-):
+def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
     # it's possible that the user is trying to resume from model_path, which doesn't necessarily
     # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
     # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
@@ -463,8 +431,6 @@ def deepspeed_load_checkpoint(
             load_lr_scheduler_states=True,
         )
         if load_path is None:
-            raise ValueError(
-                f"[deepspeed] failed to resume from checkpoint {checkpoint_path}"
-            )
+            raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
     else:
         raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 28d1f7b84b6c1c..f04ec27b217067 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2182,7 +2182,9 @@ def _load_best_model(self):
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.is_deepspeed_enabled:
             deepspeed_load_checkpoint(
-                self.model_wrapped, self.state.best_model_checkpoint, load_module_strict=not _is_peft_model(self.model),
+                self.model_wrapped,
+                self.state.best_model_checkpoint,
+                load_module_strict=not _is_peft_model(self.model),
             )
         elif self.is_fsdp_enabled:
             load_result = load_fsdp_model(
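
Note for reviewers (not part of the patches): a minimal sketch of how the `load_module_strict` plumbing introduced by this series is expected to be exercised when resuming a PEFT run under DeepSpeed. The `trainer` object, the `checkpoint-500` path, and the `resume_deepspeed_engine` helper name are placeholders for illustration only; the sketch assumes `deepspeed_load_checkpoint` and the private `_is_peft_model` helper are importable as shown.

```python
# Illustrative sketch only -- mirrors what Trainer does after this series, not a new API.
from transformers.integrations.deepspeed import deepspeed_load_checkpoint
from transformers.trainer import _is_peft_model  # private helper, used here for illustration


def resume_deepspeed_engine(trainer, checkpoint_path="checkpoint-500"):
    # A PEFT checkpoint's module state only contains the adapter weights, so the wrapped
    # module must be loaded non-strictly; full models keep the previous strict behaviour.
    strict = not _is_peft_model(trainer.model)
    deepspeed_load_checkpoint(trainer.model_wrapped, checkpoint_path, load_module_strict=strict)
```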