Probably a mistake in DPOTrainer when computing/logging grad_norm #2456

Open
AIR-hl opened this issue Dec 10, 2024 · 10 comments · Fixed by huggingface/transformers#35207
Assignees
Labels
🏋 DPO Related to DPO ❓ question Seeking clarification or more information

Comments

@AIR-hl
Contributor

AIR-hl commented Dec 10, 2024

System Info

  • Platform: Linux-5.4.0-155-generic-x86_64-with-glibc2.35
  • Python version: 3.10.8
  • PyTorch version: 2.4.0
  • Transformers version: 4.46.2
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • Datasets version: 3.1.0
  • HF Hub version: 0.24.6
  • TRL version: 0.12.1
  • bitsandbytes version: 0.44.1
  • DeepSpeed version: 0.15.4
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.54.3
  • PEFT version: 0.13.2

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

DPO script

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )

    quantization_config = get_quantization_config(model_config)

    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

    peft_config = get_peft_config(model_config)
    if peft_config is None:
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
        )
    else:
        ref_model = None

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    if script_args.ignore_bias_buffers:
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    dataset = load_dataset(script_args.dataset_name,
                           split=script_args.dataset_train_split)
    dataset = dataset.select_columns(["prompt", "chosen", "rejected"])
    
    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)

bash script

python dpo.py --model_name_or_path AIR/Llama-3.2-1B-ultrachat200k \
    --dataset_name HuggingFaceH4/ultrafeedback_binarized \
    --output_dir test \
    --attn_implementation flash_attention_2 \
    --beta 0.05 \
    --bf16 \
    --dataset_train_split train_prefs \
    --do_train \
    --gradient_checkpointing \
    --gradient_accumulation_steps 16 \
    --learning_rate 0.00001 \
    --lr_scheduler_type cosine \
    --logging_steps 5 \
    --loss_type sigmoid \
    --max_prompt_length 512 \
    --max_length 1024 \
    --max_steps -1 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --report_to tensorboard \
    --save_strategy epoch \
    --save_total_limit 1 \
    --save_only_model \
    --torch_dtype bfloat16 \
    --warmup_ratio 0.05

Expected behavior

When using DPOTrainer I found unexpected behavior of grad_norm.
Specifically, I keep the global batch size fixed at 32 and vary per_device_train_batch_size and gradient_accumulation_steps. The reported grad_norm is positively correlated with gradient_accumulation_steps, but the same effect does not appear with SFTTrainer. As far as I know, grad_norm should not change this dramatically under the same global batch size.

batch_size=4, accumulation=8
[screenshot: grad_norm curve]

batch_size=2, accumulation=16
[screenshot: grad_norm curve]

batch_size=1, accumulation=32
[screenshot: grad_norm curve]
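
For illustration, here is a minimal standalone PyTorch sketch (not TRL code, just a toy linear model) of the effect above: when each micro-batch's mean loss is backpropagated without being divided by the number of accumulation steps, the accumulated gradient, and therefore the logged grad_norm, grows with gradient_accumulation_steps even though the effective batch size stays at 32.

import torch

torch.manual_seed(0)
data = torch.randn(32, 8)      # one effective batch of 32 samples
target = torch.randn(32, 1)

def grad_norm_for(accum_steps):
    model = torch.nn.Linear(8, 1)
    torch.nn.init.constant_(model.weight, 0.1)   # identical init so runs are comparable
    torch.nn.init.zeros_(model.bias)
    micro = 32 // accum_steps
    for i in range(accum_steps):
        x = data[i * micro:(i + 1) * micro]
        y = target[i * micro:(i + 1) * micro]
        loss = torch.nn.functional.mse_loss(model(x), y)  # mean over the micro-batch
        loss.backward()                                   # note: no division by accum_steps
    # clip_grad_norm_ with a huge max_norm just reports the total norm without clipping
    return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9).item()

for steps in (8, 16, 32):
    print(f"accumulation_steps={steps:2d}  grad_norm={grad_norm_for(steps):.3f}")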

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
@qgallouedec qgallouedec added ❓ question Seeking clarification or more information 🏋 DPO Related to DPO labels Dec 10, 2024
@qgallouedec qgallouedec self-assigned this Dec 10, 2024
@qgallouedec
Member

qgallouedec commented Dec 10, 2024

This is an interesting finding! I suspect it's related to #2175. I'm investigating.

@qgallouedec
Member

qgallouedec commented Dec 10, 2024

The issue arises from how the accelerator is configured in create_accelerator_and_postprocess.

To set the number of gradient accumulation steps, users can either:

  1. Specify num_steps in AcceleratorConfig, or
  2. Use TrainingArguments.gradient_accumulation_steps when initializing the transformers.Trainer.

However, in both cases, the gradient norm (grad_norm) is computed using the accelerator here. When using TrainingArguments.gradient_accumulation_steps to define the accumulation steps, the accelerator does not account for the specified value when calculating the gradient norm.

Adding a gradient_accumulation_steps argument to the Accelerator initialization here resolves the issue (as shown in the curves below). However, I'm pretty sure it's not what we want to do.

- self.accelerator = Accelerator(**args)
+ self.accelerator = Accelerator(**args, gradient_accumulation_steps=self.args.gradient_accumulation_steps)
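
For context, a standalone sketch of what that extra argument enables, using Accelerate's documented gradient-accumulation API rather than the Trainer internals (toy model and data, illustration only): once the Accelerator knows gradient_accumulation_steps, accelerator.backward() scales each micro-batch loss by 1/gradient_accumulation_steps, so the norm returned by accelerator.clip_grad_norm_ (the same accelerator call the Trainer uses to compute grad_norm) no longer depends on how the effective batch is split.

import torch
from accelerate import Accelerator

# Passing the accumulation steps here is what the one-line diff above does inside the Trainer.
accelerator = Accelerator(gradient_accumulation_steps=8)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

micro_batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(8)]  # effective batch of 32

for x, y in micro_batches:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)   # scales the loss by 1/gradient_accumulation_steps internally
        if accelerator.sync_gradients:
            grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
            print("grad_norm:", grad_norm.item())
        optimizer.step()
        optimizer.zero_grad()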

@muellerzr, could you review this and share your thoughts?

--

--gradient_accumulation_steps 8 --per_device_train_batch_size 4

[screenshot: grad_norm curves]

--gradient_accumulation_steps 32 --per_device_train_batch_size 1

[screenshot: grad_norm curves]

Before the fix: red/pink; after the fix: blue
[screenshot: grad_norm curves before and after the fix]

@muellerzr
Contributor

Correct, that's not what we want to do: with the fix to how we calculate the number of items in the batch, the losses would no longer align and things would be off, because we don't divide the loss by the accumulation steps when we know that value. I'd need to play with this a bit, as I'm not 100% sure we can just modify the grads for clipping without modifying the overall loss we just calculated 🤔
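
To make the two normalizations concrete, an illustrative sketch (not the actual transformers code; the helper here is made up for the example): dividing each micro-batch's mean loss by the number of accumulation steps and dividing each micro-batch's summed loss by the total item count of the whole effective batch give the same accumulated gradient for equal-sized micro-batches, but only the latter stays correct when micro-batches contain different numbers of items, which is why the loss is no longer divided by the accumulation steps once that total count is known.

import torch

torch.manual_seed(0)
x = torch.randn(32, 8)
y = torch.randn(32, 1)
micro_batches = list(zip(x.split(4), y.split(4)))             # 8 micro-batches of 4 items
num_items_in_batch = sum(len(mx) for mx, _ in micro_batches)  # 32 items in the effective batch

def accumulated_grad(scale_loss):
    model = torch.nn.Linear(8, 1)
    torch.nn.init.constant_(model.weight, 0.1)   # identical init so both runs are comparable
    torch.nn.init.zeros_(model.bias)
    for mx, my in micro_batches:
        summed = torch.nn.functional.mse_loss(model(mx), my, reduction="sum")
        scale_loss(summed, len(mx)).backward()
    return model.weight.grad.clone()

# (a) mean loss per micro-batch, divided by the number of accumulation steps
grad_a = accumulated_grad(lambda s, n: (s / n) / len(micro_batches))
# (b) summed loss divided by the total item count of the whole effective batch
grad_b = accumulated_grad(lambda s, n: s / num_items_in_batch)
print(torch.allclose(grad_a, grad_b))   # True: the two conventions match here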

@AIR-hl
Contributor Author

AIR-hl commented Dec 11, 2024

The issue arises from how the accelerator is configured in create_accelerator_and_postprocess.

@qgallouedec I have a new question: if the problem arises from create_accelerator_and_postprocess in transformers.Trainer, why is trl.SFTTrainer's behavior normal while trl.DPOTrainer's is not? They both inherit from transformers.Trainer.

sft, batch_size=4, accumulation=8
[screenshot: grad_norm curve]

sft, batch_size=2, accumulation=16
[screenshot: grad_norm curve]

sft, batch_size=1, accumulation=32
[screenshot: grad_norm curve]

@qgallouedec
Member

@qgallouedec I have a new question: if the problem arises from create_accelerator_and_postprocess in transformers.Trainer, why is trl.SFTTrainer's behavior normal while trl.DPOTrainer's is not? They both inherit from transformers.Trainer.

I can't explain it right now. Any idea?

@qgallouedec
Member

I may have found the solution: huggingface/transformers#35207

Running some experiments...

@qgallouedec
Member

qgallouedec commented Dec 11, 2024

Does it solve the issue?

Before the fix

same effective batch size (32)

  • grad accumulation = 32 / batch_size = 1
  • grad accumulation = 8 / batch_size = 4

[screenshot: grad_norm curves]

We can see here that the grad_norm differs, even though it should be the same.

After the fix

same effective batch size (32)

  • grad accumulation = 32 / batch_size = 1
  • grad accumulation = 8 / batch_size = 4

[screenshot: grad_norm curves]

Now the grad_norm is the same.

Does it impact the results?

Config 1

grad accumulation = 32 / batch_size = 1 (effective batch size = 32). Curves are before the fix and after the fix

[screenshot: training curves before and after the fix]

The only value impacted is the grad_norm; there is no impact on the loss.

Config 2

grad accumulation = 8 / batch_size = 4 (effective batch size = 32). Curves are before the fix and after the fix

[screenshot: training curves before and after the fix]

The only value impacted is the grad_norm; there is no impact on the loss.

@AIR-hl
Contributor Author

AIR-hl commented Dec 11, 2024

@qgallouedec Thanks for your work! So this bug actually only affects the reported logs and not the training results, right? :)

@qgallouedec
Member

That's what the results suggest, yes.

@AIR-hl AIR-hl closed this as completed Dec 11, 2024
@qgallouedec qgallouedec reopened this Dec 11, 2024
@qgallouedec
Member

Leaving the issue open until huggingface/transformers#35207 is properly merged
