Describe the bug
I am using ZeRO-3 and want to set different learning rates for different parameters in the model. I rewrote the optimizer creation to achieve this, and it works when ZeRO-3 is not used. However, when ZeRO-3 is enabled, it does not work: all parameters end up with the same learning rate.

To Reproduce
I am using the MS-Swift training framework and rewrote create_optimizers so that the parameters of the activation function I designed get a higher learning rate:
from transformers import Trainer


def create_optimizers(args, model, dataset):
    args = args.training_args
    optimizer_grouped_parameters = None
    if hasattr(model, 'create_optimizer_param_groups'):
        # Lora+ parameter groups
        optimizer_grouped_parameters = model.create_optimizer_param_groups(
            lr=args.learning_rate, weight_decay=args.weight_decay)
    if optimizer_grouped_parameters is None:
        # Default parameter groups
        decay_parameters = Trainer.get_decay_parameter_names(None, model)
        optimizer_grouped_parameters = [
            {
                # Decayed parameters, excluding the custom activation function
                'params': [p for n, p in model.named_parameters()
                           if n in decay_parameters and p.requires_grad and 'act_fn' not in n],
                'weight_decay': args.weight_decay,
            },
            {
                # Parameters excluded from weight decay (biases, norms, ...)
                'params': [p for n, p in model.named_parameters()
                           if n not in decay_parameters and p.requires_grad],
                'weight_decay': 0.0,
            },
            {
                # Activation-function parameters: no weight decay, higher learning rate
                'params': [p for n, p in model.named_parameters()
                           if n in decay_parameters and p.requires_grad and 'act_fn' in n],
                'weight_decay': 0.0,
                'lr': 0.5,
            },
        ]
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    opt = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    # No custom scheduler is returned; the trainer is left to create one
    return opt, None
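The grouping rule in create_optimizers can be checked without a model or DeepSpeed. The sketch below is a hypothetical, names-only mirror of the three groups: the helper names group_param_names and summarize, the parameter names, and the learning-rate values are illustrative assumptions, and the requires_grad filter is omitted. summarize should also work on the param_groups of a plain torch.optim optimizer such as AdamW, whose groups are dicts with 'params', 'lr' and 'weight_decay' keys, which makes a before/after comparison of param_groups easier to read.

```python
# Names-only mirror of the three parameter groups built in create_optimizers.
# All helper names, parameter names and numbers here are hypothetical.


def group_param_names(names, decay_names, base_lr, act_fn_lr, weight_decay):
    """Split parameter names into the same three groups as create_optimizers.

    base_lr stands in for the optimizer's default learning rate, which
    torch.optim fills into any group that does not set 'lr' itself.
    """
    groups = [
        {"params": [], "weight_decay": weight_decay, "lr": base_lr},  # decayed, not act_fn
        {"params": [], "weight_decay": 0.0, "lr": base_lr},  # excluded from decay
        {"params": [], "weight_decay": 0.0, "lr": act_fn_lr},  # decayed act_fn params
    ]
    for name in names:
        if name not in decay_names:
            groups[1]["params"].append(name)
        elif "act_fn" in name:
            groups[2]["params"].append(name)
        else:
            groups[0]["params"].append(name)
    return groups


def summarize(param_groups):
    """Return one (lr, weight_decay, number of params) tuple per group."""
    return [(g["lr"], g["weight_decay"], len(g["params"])) for g in param_groups]


names = [
    "layers.0.mlp.act_fn.alpha",
    "layers.0.mlp.up_proj.weight",
    "layers.0.mlp.up_proj.bias",
    "norm.weight",
]
decay_names = {"layers.0.mlp.act_fn.alpha", "layers.0.mlp.up_proj.weight"}
print(summarize(group_param_names(names, decay_names, 1e-4, 0.5, 0.1)))
# [(0.0001, 0.1, 1), (0.0001, 0.0, 2), (0.5, 0.0, 1)]
```

In both the original code and this sketch, an act_fn parameter that is not in the decay set (for example a bias) lands in the second group and keeps the default learning rate, so only act_fn parameters that are also subject to weight decay receive lr 0.5.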
Right after initialization, I checked optimizers.param_groups, and it was:
At this point, the parameters I specified were correctly set to a learning rate of 0.5. However, after training started, when I checked optimizers.param_groups again during backpropagation, its value was:
It is completely different from what I set. It seems that only one entry in param_groups was retained, which makes it impossible to set different learning rates for different parameters.

Expected behavior
I want to be able to set different learning rates for different parameters.

ds_report output
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
[WARNING] FP Quantizer is using an untested triton version (3.1.0), only 2.3.0 and 2.3.1 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
[WARNING] using untested triton version (3.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/longchen/miniconda3/envs/swift/lib/python3.11/site-packages/torch']
torch version .................... 2.5.1+cu124
deepspeed install path ........... ['/home/longchen/miniconda3/envs/swift/lib/python3.11/site-packages/deepspeed']
deepspeed info ................... 0.14.5, unknown, unknown
torch cuda version ............... 12.4
torch hip version ................ None
nvcc version ..................... 12.4
deepspeed wheel compiled w. ...... torch 2.5, cuda 12.4
shared memory (/dev/shm) size .... 125.75 GB
System info
OS: not specified
GPU count and types: 8 × RTX 3090
Interconnects: none
Python version: 3.11.0