fix: format code
lh0x00 committed Jan 28, 2024
1 parent 0aa3bc9 commit 8d8e7a1
Showing 2 changed files with 18 additions and 50 deletions.
64 changes: 15 additions & 49 deletions src/transformers/integrations/deepspeed.py
@@ -125,9 +125,7 @@ def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):

         ds_val = config.get(ds_key)
         if ds_val is not None and ds_val != hf_val:
-            self.mismatches.append(
-                f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}"
-            )
+            self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")

     fill_only = partialmethod(fill_match, must_match=False)

@@ -138,11 +136,7 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
         """
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
-        train_batch_size = (
-            args.world_size
-            * args.per_device_train_batch_size
-            * args.gradient_accumulation_steps
-        )
+        train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
         self.fill_match(
             "train_micro_batch_size_per_gpu",
             args.per_device_train_batch_size,
@@ -169,14 +163,10 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
             "adam_beta1+adam_beta2",
         )
         self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
-        self.fill_match(
-            "optimizer.params.weight_decay", args.weight_decay, "weight_decay"
-        )
+        self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")

         self.fill_only("scheduler.params.warmup_min_lr", 0)  # not a trainer arg
-        self.fill_match(
-            "scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate"
-        )
+        self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
         # total_num_steps - will get set in trainer_config_finalize

         # fp16
@@ -200,14 +190,10 @@ def trainer_config_process(self, args, auto_find_batch_size=False):

         # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
         # ZeRO features
-        self.fill_match(
-            "amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)"
-        )
+        self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
         self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")

-        self.fill_match(
-            "bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval"
-        )
+        self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval")

         # deepspeed's default mode is fp16 unless there is a config that says differently
         if self.is_true("bf16.enabled"):
@@ -247,9 +233,7 @@ def trainer_config_finalize(self, args, model, num_training_steps):
                     "`auto` values for these keys with an integer value of your choice."
                 )

-            self.fill_only(
-                "zero_optimization.reduce_bucket_size", hidden_size * hidden_size
-            )
+            self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
             if self.is_zero3():
                 # automatically assign the optimal config values based on model config
                 self.fill_only(
@@ -300,28 +284,20 @@ def unset_hf_deepspeed_config():


 def is_deepspeed_zero3_enabled():
-    if (
-        _hf_deepspeed_config_weak_ref is not None
-        and _hf_deepspeed_config_weak_ref() is not None
-    ):
+    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
         return _hf_deepspeed_config_weak_ref().is_zero3()
     else:
         return False


 def deepspeed_config():
-    if (
-        _hf_deepspeed_config_weak_ref is not None
-        and _hf_deepspeed_config_weak_ref() is not None
-    ):
+    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
         return _hf_deepspeed_config_weak_ref().config
     else:
         return None


-def deepspeed_optim_sched(
-    trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
-):
+def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
     """
     A convenience wrapper that deals with optimizer and lr scheduler configuration.
     """
@@ -372,13 +348,9 @@ def _lr_scheduler_callable(optimizer):
                     num_training_steps=num_training_steps,
                 )

-            lr_scheduler = DummyScheduler(
-                optimizer, lr_scheduler_callable=_lr_scheduler_callable
-            )
+            lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
         else:
-            lr_scheduler = trainer.create_scheduler(
-                num_training_steps=num_training_steps, optimizer=optimizer
-            )
+            lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

     return optimizer, lr_scheduler

@@ -420,9 +392,7 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     if inference:
         # only Z3 makes sense for the inference
         if not hf_deepspeed_config.is_zero3():
-            raise ValueError(
-                "ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config"
-            )
+            raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config")

         # in case the training config is re-used for inference
         hf_deepspeed_config.del_config_sub_tree("optimizer")
@@ -442,9 +412,7 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     return optimizer, lr_scheduler


-def deepspeed_load_checkpoint(
-    deepspeed_engine, checkpoint_path, load_module_strict = True
-):
+def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
     # it's possible that the user is trying to resume from model_path, which doesn't necessarily
     # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
     # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
@@ -463,8 +431,6 @@ def deepspeed_load_checkpoint(
             load_lr_scheduler_states=True,
         )
         if load_path is None:
-            raise ValueError(
-                f"[deepspeed] failed to resume from checkpoint {checkpoint_path}"
-            )
+            raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
     else:
         raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
4 changes: 3 additions & 1 deletion src/transformers/trainer.py
@@ -2182,7 +2182,9 @@ def _load_best_model(self):
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.is_deepspeed_enabled:
             deepspeed_load_checkpoint(
-                self.model_wrapped, self.state.best_model_checkpoint, load_module_strict=not _is_peft_model(self.model),
+                self.model_wrapped,
+                self.state.best_model_checkpoint,
+                load_module_strict=not _is_peft_model(self.model),
             )
         elif self.is_fsdp_enabled:
             load_result = load_fsdp_model(
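
The only substantive detail in the trainer.py hunk is load_module_strict=not _is_peft_model(self.model): when the wrapped model is a PEFT model, the checkpoint may hold only a subset of the module's parameters (e.g. adapter weights), so strict module loading is relaxed. As a rough analogy only, using plain PyTorch rather than the DeepSpeed engine, the sketch below shows what relaxing strictness buys; TinyNet is a hypothetical module used purely for illustration.

# Rough analogy for load_module_strict, using plain PyTorch instead of the DeepSpeed
# engine: a partial state dict (e.g. only adapter weights) fails under strict loading
# but is accepted when strictness is relaxed.
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4)
        self.adapter = nn.Linear(4, 4)


model = TinyNet()
partial_state = {k: v for k, v in model.state_dict().items() if k.startswith("adapter")}

try:
    model.load_state_dict(partial_state, strict=True)  # raises: base.* keys are missing
except RuntimeError as e:
    print("strict load failed:", type(e).__name__)

result = model.load_state_dict(partial_state, strict=False)  # accepted; reports missing keys
print("missing keys:", result.missing_keys)

With strict=True the partial state dict raises a RuntimeError for the missing base.* keys; with strict=False it loads and merely reports them, which loosely mirrors what passing load_module_strict=False to the DeepSpeed checkpoint load allows.
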
