From ca90cba351c4b725e6724f0441d35801f7455335 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=90=86?=
Date: Fri, 23 Feb 2024 10:58:43 +0800
Subject: [PATCH] fix 8-bit multi-gpu training bug (#1353)

* fix 8-bit multi-gpu training bug

see https://github.com/huggingface/trl/issues/1348

* Update dpo_llama2.py

make gradient_checkpointing_kwargs configurable.

* Update dpo_llama2.py

remove unnecessary config of device_map

* format with make precommit

---------

Co-authored-by: ubuntu
---
 .../research_projects/stack_llama_2/scripts/dpo_llama2.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
index 5684a876ce..474b4c1b3a 100644
--- a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+++ b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
@@ -4,6 +4,7 @@
 from typing import Dict, Optional
 
 import torch
+from accelerate import Accelerator
 from datasets import Dataset, load_dataset
 from peft import LoraConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
@@ -41,6 +42,10 @@ class ScriptArguments:
         default=True, metadata={"help": "whether to use gradient checkpointing"}
     )
 
+    gradient_checkpointing_use_reentrant: Optional[bool] = field(
+        default=True, metadata={"help": "whether to use reentrant for gradient checkpointing"}
+    )
+
     lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
     lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
     lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
@@ -129,6 +134,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
         low_cpu_mem_usage=True,
         torch_dtype=torch.float16,
         load_in_4bit=True,
+        device_map={"": Accelerator().local_process_index},
     )
     model.config.use_cache = False
 
@@ -175,6 +181,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
         bf16=True,
         remove_unused_columns=False,
         run_name="dpo_llama2",
+        gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
     )
 
     peft_config = LoraConfig(
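Note (not part of the patch): a minimal sketch of the per-process device placement pattern this change introduces. When the script is started with accelerate launch, each worker process pins its whole quantized model onto its own GPU via device_map={"": Accelerator().local_process_index}, instead of letting a device_map such as "auto" shard the weights across all visible GPUs, which is what broke quantized multi-GPU training. The checkpoint name and launch command below are placeholders for illustration only.

    # minimal_sketch.py -- assumed standalone example, not the patched script itself
    import torch
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint for illustration
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
        # Place the full model on this process's GPU so every data-parallel worker
        # holds its own replica; sharding the quantized weights across GPUs with
        # device_map="auto" is what caused the multi-GPU training failure.
        device_map={"": Accelerator().local_process_index},
    )

    # Typically launched with one process per GPU, e.g.:
    #   accelerate launch --num_processes 2 minimal_sketch.py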