clear up the parameters of supervised_finetuning.py (#1126)
no_gradient_checkpointing is always false

Signed-off-by: Wang, Yi A <[email protected]>
sywangyi authored Dec 22, 2023
1 parent c1bb1f3 commit 950ee21
Showing 2 changed files with 5 additions and 5 deletions.
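
Why the old flags misbehaved: with `action="store_false"` plus an explicit `default=False`, `args.no_gradient_checkpointing` can never be anything but False, so `not args.no_gradient_checkpointing` was always True and the flag did nothing; `--no_fp16` defaulted to True, so fp16 was actually turned on by passing a flag whose name suggests it turns fp16 off. A minimal, self-contained sketch of the old behavior (not part of the repository, argparse only):

import argparse

old = argparse.ArgumentParser()
# store_false with default=False can never yield True: the default is False,
# and passing the flag stores False again.
old.add_argument("--no_gradient_checkpointing", action="store_false", default=False)
# store_false without an explicit default defaults to True.
old.add_argument("--no_fp16", action="store_false")

args = old.parse_args([])
print(not args.no_gradient_checkpointing)  # True  -> checkpointing always enabled
print(not args.no_fp16)                    # False -> fp16 disabled by default

args = old.parse_args(["--no_gradient_checkpointing", "--no_fp16"])
print(not args.no_gradient_checkpointing)  # still True: the flag has no effect
print(not args.no_fp16)                    # True  -> "--no_fp16" actually enables fp16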
2 changes: 1 addition & 1 deletion examples/research_projects/stack_llama/scripts/README.md
@@ -1,7 +1,7 @@
 # RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model.
 There were three main steps to the training process:
 1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se:
-    - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --no_gradient_checkpointing --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
+    - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
 2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm:
     - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=<LLAMA_SE_MODEL>`
 3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model:
8 changes: 4 additions & 4 deletions examples/research_projects/stack_llama/scripts/supervised_finetuning.py
@@ -38,9 +38,9 @@ def get_args():
parser.add_argument("--weight_decay", type=float, default=0.05)

parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--no_fp16", action="store_false")
parser.add_argument("--fp16", action="store_true", default=False)
parser.add_argument("--bf16", action="store_true", default=False)
parser.add_argument("--no_gradient_checkpointing", action="store_false", default=False)
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--num_workers", type=int, default=None)
parser.add_argument("--output_dir", type=str, default="./checkpoints")
@@ -159,8 +159,8 @@ def run_training(args, train_data, val_data):
         lr_scheduler_type=args.lr_scheduler_type,
         warmup_steps=args.num_warmup_steps,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        gradient_checkpointing=not args.no_gradient_checkpointing,
-        fp16=not args.no_fp16,
+        gradient_checkpointing=args.gradient_checkpointing,
+        fp16=args.fp16,
         bf16=args.bf16,
         weight_decay=args.weight_decay,
         run_name="llama-7b-finetuned",
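
After the rename, the parsed flags map one-to-one onto `TrainingArguments`, with no negation. A hedged sketch of the new wiring (only the three renamed arguments are taken from the diff; the output_dir and the example command-line values are illustrative):

import argparse

from transformers import TrainingArguments

parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true", default=False)
parser.add_argument("--bf16", action="store_true", default=False)
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)

# Mixed precision is left off here so the sketch also runs on CPU-only machines.
args = parser.parse_args(["--gradient_checkpointing"])

training_args = TrainingArguments(
    output_dir="./checkpoints",
    fp16=args.fp16,                                      # False unless --fp16 is passed
    bf16=args.bf16,                                      # False unless --bf16 is passed
    gradient_checkpointing=args.gradient_checkpointing,  # True here, False by default
)
print(training_args.gradient_checkpointing)  # True

One behavioral consequence worth noting: before this change `gradient_checkpointing` effectively always ended up True, so anyone who wants it after the rename must pass `--gradient_checkpointing` explicitly; the updated README command above does not.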
