SFTTrainer not using both GPUs #1303

Closed
johnowhitaker opened this issue Feb 1, 2024 · 9 comments · Fixed by #1308

Comments

@johnowhitaker
Contributor

I am trying to fine-tune Llama 2 7B with QLoRA on 2 GPUs. From what I've read, SFTTrainer should support multiple GPUs just fine, but when I run this I see one GPU with high utilization and one with almost none:

[Screenshot: GPU utilization, one GPU heavily used and the other nearly idle]

Expected behaviour would be that both GPUs get used during training and it would be about 2x as fast as single-GPU training. I'm running this with python train.py, which I think means Trainer uses DP? I get an error when launching with python -m torch.distributed.launch train.py (RuntimeError: Expected to mark a variable ready only once...), which makes me think DDP would need a bit more work...

This is an older machine without any fast interconnect, but I saw similar usage on a cloud machine with 2xA5000s so I don't think it's that. Anyway, maybe someone can help by explaining why DP might be so slow in this case and/or how to test DDP instead :)
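In case it helps anyone debugging the same thing, a quick way to see which parallelism mode Trainer will pick up is to inspect the attributes TrainingArguments exposes (a small diagnostic sketch, no training involved):

import torch
from transformers import TrainingArguments

args = TrainingArguments(output_dir="./results")
print("visible GPUs:", torch.cuda.device_count())   # 2 on this machine
print("parallel_mode:", args.parallel_mode)         # NOT_DISTRIBUTED under plain `python train.py`
print("n_gpu:", args.n_gpu)                         # 2 -> Trainer wraps the model in nn.DataParallel
print("world_size:", args.world_size)               # stays 1 unless a distributed launcher is used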

Script:

from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer
from transformers import TrainingArguments

# Load the dataset
dataset_name = "timdettmers/openassistant-guanaco"
dataset = load_dataset(dataset_name, split="train")


# Load the model + tokenizer
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    use_cache = False
)

# PEFT config
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"]
)


# Args 
max_seq_length = 512
output_dir = "./results"
per_device_train_batch_size = 8
gradient_accumulation_steps = 2
optim = "adamw_hf"
save_steps = 10
logging_steps = 1
learning_rate = 2e-4
max_grad_norm = 0.3
max_steps = 300 # Approx the size of guanaco at bs 8, ga 2, 2 GPUs. 
warmup_ratio = 0.1
lr_scheduler_type = "cosine"
training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    lr_scheduler_type=lr_scheduler_type,
    gradient_checkpointing=True,
    report_to="wandb",
)

# Trainer 
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)

# Not sure if needed but noticed this in https://colab.research.google.com/drive/1t3exfAVLQo4oKIopQT1SKxK4UcYg7rC1#scrollTo=7OyIvEx7b1GT
for name, module in trainer.model.named_modules():
    if "norm" in name:
        module = module.to(torch.float32)

# Train :)
trainer.train()
@johnowhitaker
Contributor Author

Ah, belatedly I found #921 which looks promising

@johnowhitaker
Contributor Author

Fix:

1. When loading the model, make sure it's on the right device:

import os

local_rank = os.getenv("LOCAL_RANK")
device_string = "cuda:" + str(local_rank)
...
model = AutoModelForCausalLM.from_pretrained(
    ...
    device_map={'': device_string}
)

2. If using gradient_checkpointing, add gradient_checkpointing_kwargs={'use_reentrant': False} to the TrainingArguments, otherwise DDP won't work (see transformers#26969, "Need to explicitly set use_reentrant when calling checkpoint").

3. Run with accelerate launch hf_train.py --num_processes 2

I'm happy to close the issue but maybe it's useful for others in the same boat.

@younesbelkada
Contributor

Yes, this is exactly the way to load the model across multiple GPUs and train it. More specifically, you should do:

from accelerate import PartialState
device_string = PartialState().process_index

I also feel this is a common scenario. Would you be happy to update the SFTTrainer documentation by adding a new section to this file: https://github.com/huggingface/trl/blob/main/docs/source/sft_trainer.mdx ?
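Putting the pieces above together, a minimal sketch of the DDP-ready setup (assuming two GPUs, the same Llama 2 QLoRA configuration as the first script in this thread, and a launch via accelerate launch):

import torch
from accelerate import PartialState
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments

device_string = PartialState().process_index  # one process per GPU under accelerate launch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    ),
    device_map={"": device_string},  # place the whole model on this process's GPU
    use_cache=False,
)

training_arguments = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    fp16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # required for DDP + checkpointing
)

Each process then builds its own SFTTrainer and calls trainer.train() as in the scripts above; DDP handles the gradient synchronization.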

@johnowhitaker
Contributor Author

@younesbelkada something like this OK? #1308

@younesbelkada
Contributor

Amazing, thank you @johnowhitaker !

@RonanKMcGovern

RonanKMcGovern commented Apr 8, 2024

Thanks for the clear issue and resolution - very helpful in getting DDP to work.

@younesbelkada, I noticed that using DDP (for this case) seems to take up more VRAM (it runs into CUDA OOM more easily) than running with PP (just setting device_map='auto'). That said, DDP does seem to be faster than PP (less time for the same number of steps).

Is that to be expected? It's not entirely intuitive to me from the docs.

I have run a similar script to above using PP (with python script.py) and DDP (accelerate launch script.py):

# Script from: https://github.com/huggingface/trl/issues/1303
# Run this with DDP with "accelerate launch test_sft.py"
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer
from transformers import TrainingArguments
from accelerate import PartialState

device_map="DDP" # for DDP and running with `accelerate launch test_sft.py`
# device_map='auto' # for PP and running with `python test_sft.py`

if device_map == "DDP":
    device_string = PartialState().process_index
    device_map={'':device_string}

# Load the dataset
dataset_name = "timdettmers/openassistant-guanaco"
dataset = load_dataset(dataset_name, split="train")

# Load the model + tokenizer
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    use_cache = False,
    device_map = device_map,
)

# PEFT config
lora_alpha = 8
lora_dropout = 0.1
lora_r = 32
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"],
    modules_to_save=["embed_tokens", "input_layernorm", "post_attention_layernorm", "norm"],
)

# Args 
max_seq_length = 512
output_dir = "./results"
per_device_train_batch_size = 8
gradient_accumulation_steps = 2
optim = "adamw_torch"
save_steps = 10
logging_steps = 1
learning_rate = 2e-4
max_grad_norm = 0.3
max_steps = 1 # Short test run for comparing PP vs DDP (the original script used 300)
warmup_ratio = 0.1
lr_scheduler_type = "cosine"
training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    lr_scheduler_type=lr_scheduler_type,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs = {"use_reentrant": False}, #must be false for DDP
    report_to="wandb",
)

# Trainer 
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)

# Train :)
trainer.train()

The script above runs fine in PP even when I train/save other modules in the LoRA config. But, for DDP, that results in OOM.

For comparison, when I ran the script above without other modules being saved, but varying the batch size up to 16, I got OOM with both the PP and DDP approaches.

For another comparison, I was able to run DDP with trainable/saveable added modules on TinyLlama with no OOM issues (obviously that's a much smaller model, but it tests whether the added modules pose an issue).

So, I'm a bit puzzled why DDP seems to take more VRAM than PP (especially when adding trainable modules). Why is this?

EDIT: I'm unclear on whether setting device_map = 'auto' and running python script.py defaults to pipeline parallel or DP (see issue). I'm referring to PP above, but maybe I really mean DP.
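One way to check what device_map='auto' actually did is to print the model's hf_device_map after loading (a small sketch, reusing model_name and bnb_config from the script above): if the layers are spread across devices 0 and 1, the model is split layer-wise (naive model parallelism) rather than replicated as in DP/DDP.

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
# e.g. {'model.embed_tokens': 0, 'model.layers.0': 0, ..., 'model.layers.31': 1, 'lm_head': 1}
print(model.hf_device_map)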

@aveerago

I'm unclear as well! I'm guessing that setting device_map = 'auto' and running python script.py defaults to naive pipeline parallelism.
Btw, @younesbelkada, are all GPUs utilized when fine-tuning LLMs like Mistral-7B with PP? I'm using the SageMaker HF estimator and I don't think I can use the DDP approach on a multi-GPU node like g5.24xlarge. Thanks.

@manliu1225

manliu1225 commented Apr 29, 2024

This works for me:

from accelerate import PartialState
device_string = PartialState().process_index
model = AutoModelForCausalLM.from_pretrained(
    ...
    device_map={'': device_string}
)

# In TrainingArguments:
gradient_checkpointing = False,
gradient_checkpointing_kwargs = {"use_reentrant": False},  # must be False for DDP
ddp_find_unused_parameters = False,  # False when using DDP, otherwise True

Then run with accelerate launch script.py

@sunilswain

I see.

Can we only run scripts through accelerate, or is there a way to run this from Python code? For example:

I'm using a notebook and I have all the configurations there. All I want is to run the trainer.train() function with accelerate; how can I do that?
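For what it's worth, Accelerate also provides notebook_launcher for spawning the training processes from inside a notebook. A rough sketch (assuming the model/trainer setup from the scripts above is wrapped in a function, and that CUDA has not yet been initialized in the notebook process):

from accelerate import notebook_launcher

def training_function():
    # Build the tokenizer, model (with device_map={'': PartialState().process_index}),
    # SFTTrainer, etc. exactly as in the scripts above, then call:
    # trainer.train()
    pass

# Spawns one process per GPU; CUDA must not already be initialized in the notebook.
notebook_launcher(training_function, args=(), num_processes=2)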
