Commit

stage3: efficient compute of scaled_global_grad_norm (#5256)
using torch.norm instead of inefficient for loop

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
nelyahu and tjruwase authored Apr 14, 2024
1 parent 7b5b066 commit 54c0687
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -15,7 +15,7 @@
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -2027,7 +2027,7 @@ def step(self, closure=None):
return

norm_groups = self._get_norm_groups()
-scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
+scaled_global_grad_norm = torch.norm(torch.stack(norm_groups))

# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
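For context, here is a minimal sketch of the two approaches (illustrative only: the values are hypothetical, and get_global_norm is reduced to its conceptual loop rather than quoted verbatim):

```python
import torch

# Per-group gradient norms, as returned by self._get_norm_groups()
# (hypothetical values for illustration).
norm_groups = [torch.tensor(1.5), torch.tensor(2.0), torch.tensor(0.5)]

# Old path, conceptually: reduce the list in a Python loop, accumulating
# squared norms and taking the square root at the end.
total = 0.0
for norm in norm_groups:
    total += norm.item() ** 2
global_norm_loop = total ** 0.5

# New path from this commit: stack the norms and do one vectorized reduction.
global_norm_vec = torch.norm(torch.stack(norm_groups))

assert abs(global_norm_loop - global_norm_vec.item()) < 1e-6
```

One behavioral difference worth keeping in mind: torch.norm returns a tensor that lives on the same device as the stacked norms, which is relevant to the device-mismatch report in the comments below.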

3 comments on commit 54c0687

@ojijo commented on 54c0687 Jun 18, 2024

I found that this change causes the ZeRO-3 CPU offload feature to fail.

The error message is:
Exception has occurred: RuntimeError
Expected all tensors to be on the same device, but found at least two devices, cuda:4 and cpu!
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 2117, in unscale_and_clip_grads self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 2047, in step
self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2075, in _take_model_step
self.optimizer.step()
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2169, in step
self._take_model_step(lr_kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/deepspeed.py", line 175, in backward
self.engine.step()
File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2117, in backward
self.deepspeed_engine_wrapped.backward(loss, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3250, in training_step
self.accelerator.backward(loss)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2216, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1885, in train
return inner_training_loop(
File "/app/src/llamafactory/train/sft/workflow.py", line 73, in run_sft
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
File "/app/src/llamafactory/train/tuner.py", line 33, in run_exp
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
File "/app/src/train.py", line 5, in main
run_exp()
File "/app/src/train.py", line 14, in
main()
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:4 and cpu!

The call stack is:
unscale_and_clip_grads (\usr\local\lib\python3.10\dist-packages\deepspeed\runtime\zero\stage3.py:2117)
wrapped_fn (\usr\local\lib\python3.10\dist-packages\deepspeed\utils\nvtx.py:15)
step (\usr\local\lib\python3.10\dist-packages\deepspeed\runtime\zero\stage3.py:2047)
wrapped_fn (\usr\local\lib\python3.10\dist-packages\deepspeed\utils\nvtx.py:15)
_take_model_step (\usr\local\lib\python3.10\dist-packages\deepspeed\runtime\engine.py:2075)
step (\usr\local\lib\python3.10\dist-packages\deepspeed\runtime\engine.py:2169)
backward (\usr\local\lib\python3.10\dist-packages\accelerate\utils\deepspeed.py:175)
backward (\usr\local\lib\python3.10\dist-packages\accelerate\accelerator.py:2117)
training_step (\usr\local\lib\python3.10\dist-packages\transformers\trainer.py:3250)
_inner_training_loop (\usr\local\lib\python3.10\dist-packages\transformers\trainer.py:2216)
train (\usr\local\lib\python3.10\dist-packages\transformers\trainer.py:1885)
run_sft (\app\src\llamafactory\train\sft\workflow.py:73)
run_exp (\app\src\llamafactory\train\tuner.py:33)
main (\app\src\train.py:5)
(\app\src\train.py:14)
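
A standalone sketch of this failure mode (not the DeepSpeed code itself), under the assumption that with CPU offload the flat fp32 gradient partition sits on the CPU while the stacked norms, and therefore scaled_global_grad_norm, sit on the GPU:

```python
import torch

if torch.cuda.is_available():
    cpu_grad = torch.randn(8, dtype=torch.float32)              # stands in for the offloaded fp32 gradient (CPU)
    scaled_global_grad_norm = torch.tensor(3.0, device="cuda")  # stands in for torch.norm(...) on the GPU

    # If the clipping path folds this GPU tensor into combined_scale,
    # the in-place multiply on the CPU gradient mixes devices.
    combined_scale = scaled_global_grad_norm * 1024.0            # still a CUDA tensor
    try:
        cpu_grad.mul_(1.0 / combined_scale)
    except RuntimeError as e:
        print(e)  # "Expected all tensors to be on the same device, ..."
```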

@ojijo commented on 54c0687 Jun 18, 2024

I reverted the code to the previous version, and the error disappeared.

@tjruwase (Contributor) commented on 54c0687

@ojijo, can you please open a ticket for this so we can repro and fix it properly? Thanks.
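
Until a proper fix lands, a purely illustrative workaround (not the maintainers' solution; variable names are hypothetical) is to keep the scale factor on the same device as the offloaded gradients, or collapse it to a Python float, before the in-place multiply:

```python
import torch

norm_groups = [torch.tensor(1.5), torch.tensor(2.0)]            # per-group norms
scaled_global_grad_norm = torch.norm(torch.stack(norm_groups))  # may live on the GPU
cpu_grad = torch.randn(8, dtype=torch.float32)                  # stands in for the offloaded fp32 gradient

# Option 1: move the norm onto the gradient's device before building the scale.
combined_scale = scaled_global_grad_norm.to(cpu_grad.device) * 1024.0  # 1024.0 stands in for the loss scale
cpu_grad.mul_(1.0 / combined_scale)

# Option 2: unwrap it to a plain Python float, which mixes with any device.
combined_scale = float(scaled_global_grad_norm) * 1024.0
cpu_grad.mul_(1.0 / combined_scale)
```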
