diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 92cc1a4b0e5947..07d0a5b5e37a57 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -143,14 +143,25 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
             "per_device_train_batch_size",
             not auto_find_batch_size,
         )
-        self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps")
         self.fill_match(
-            "train_batch_size", train_batch_size, "train_batch_size (calculated)", not auto_find_batch_size
+            "gradient_accumulation_steps",
+            args.gradient_accumulation_steps,
+            "gradient_accumulation_steps",
+        )
+        self.fill_match(
+            "train_batch_size",
+            train_batch_size,
+            "train_batch_size (calculated)",
+            not auto_find_batch_size,
         )
         self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm")
 
         self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate")
-        self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2")
+        self.fill_match(
+            "optimizer.params.betas",
+            [args.adam_beta1, args.adam_beta2],
+            "adam_beta1+adam_beta2",
+        )
         self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
         self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")
 
@@ -225,12 +236,26 @@ def trainer_config_finalize(self, args, model, num_training_steps):
             self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
             if self.is_zero3():
                 # automatically assign the optimal config values based on model config
-                self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
-                self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size)
+                self.fill_only(
+                    "zero_optimization.stage3_prefetch_bucket_size",
+                    0.9 * hidden_size * hidden_size,
+                )
+                self.fill_only(
+                    "zero_optimization.stage3_param_persistence_threshold",
+                    10 * hidden_size,
+                )
 
         # scheduler
-        self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)")
-        self.fill_match("scheduler.params.warmup_num_steps", args.get_warmup_steps(num_training_steps), "warmup_steps")
+        self.fill_match(
+            "scheduler.params.total_num_steps",
+            num_training_steps,
+            "num_training_steps (calculated)",
+        )
+        self.fill_match(
+            "scheduler.params.warmup_num_steps",
+            args.get_warmup_steps(num_training_steps),
+            "warmup_steps",
+        )
 
         if len(self.mismatches) > 0:
             mismatches = "\n".join(self.mismatches)
@@ -387,7 +412,7 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     return optimizer, lr_scheduler
 
 
-def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
+def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
     # it's possible that the user is trying to resume from model_path, which doesn't necessarily
     # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
     # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
@@ -400,7 +425,10 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
         logger.info(f"Attempting to resume from {checkpoint_path}")
         # this magically updates self.optimizer and self.lr_scheduler
         load_path, _ = deepspeed_engine.load_checkpoint(
-            checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True
+            checkpoint_path,
+            load_module_strict=load_module_strict,
+            load_optimizer_states=True,
+            load_lr_scheduler_states=True,
         )
         if load_path is None:
             raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 92c97dc065da43..de7a736293f4d2 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1727,7 +1727,9 @@ def _inner_training_loop(
         # ckpt loading
         if resume_from_checkpoint is not None:
             if self.is_deepspeed_enabled:
-                deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
+                deepspeed_load_checkpoint(
+                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
+                )
             elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                 self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
 
@@ -2193,7 +2195,11 @@ def _load_best_model(self):
 
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.is_deepspeed_enabled:
-            deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
+            deepspeed_load_checkpoint(
+                self.model_wrapped,
+                self.state.best_model_checkpoint,
+                load_module_strict=not _is_peft_model(self.model),
+            )
         elif self.is_fsdp_enabled:
             load_result = load_fsdp_model(
                 self.accelerator.state.fsdp_plugin,
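
Note (not part of the patch): a minimal sketch of how the new load_module_strict flag is meant to be used when resuming a PEFT run under DeepSpeed. It assumes the patched signature above; the names engine, model, and checkpoint_dir are placeholders, and _is_peft_model is the private helper already referenced in trainer.py.

    # Illustrative only: resume a DeepSpeed engine from a Trainer checkpoint,
    # relaxing strict key matching when the wrapped module is a PEFT model
    # (its state-dict keys do not line up one-to-one with the saved module).
    from transformers.integrations.deepspeed import deepspeed_load_checkpoint
    from transformers.trainer import _is_peft_model

    def resume_from_deepspeed_checkpoint(engine, model, checkpoint_dir):
        deepspeed_load_checkpoint(
            engine,
            checkpoint_dir,
            load_module_strict=not _is_peft_model(model),
        )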