diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index bf43f2eb..e75d0c7c 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -30,8 +30,8 @@ class DeepSpeedOffloadStrategy(Enum):
 
 # public API
 class DistributedBackend(Enum):
-    FSDP: str = "fsdp"
-    DEEPSPEED: str = "deepspeed"
+    FSDP = "fsdp"
+    DEEPSPEED = "deepspeed"
 
 
 # public API
@@ -121,6 +121,17 @@ class DeepSpeedOptions(BaseModel):
     save_samples: int | None = None
 
 
+# public API
+class DistillationConfig(BaseModel):
+    """
+    Config to use when performing knowledge distillation during training.
+    """
+
+    temperature: float = Field(1.0, gt=0.0)
+    alpha: float = Field(1.0, le=1.0, ge=0.0)
+    teacher_path: str
+
+
 # public API
 class ShardingStrategies(Enum):
     FULL_SHARD = "FULL_SHARD"
@@ -179,6 +190,11 @@ class TrainingArgs(BaseModel):
     is_padding_free: bool = False  # TODO: deprecate
     checkpoint_at_epoch: bool = True
     accelerate_full_state_at_epoch: bool = True
+    weight_decay: float = Field(0.0, ge=0.0)
+
+    # settings for knowledge distillation
+    distillation_options: Optional[DistillationConfig] = None
+    use_distillation: bool = False
 
     mock_data: Optional[bool] = False
     mock_data_len: int = 0
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 0ad54c5f..189fdbc5 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -3,6 +3,7 @@
 # Standard
 from copy import deepcopy
 from pathlib import Path
+from typing import Optional
 import argparse
 import json
 import math
@@ -40,9 +41,11 @@ from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, get_scheduler
+from transformers import AutoModelForCausalLM, PreTrainedTokenizer, get_scheduler
+from transformers.modeling_outputs import CausalLMOutput
 import torch
-import torch.distributed
+import torch.distributed as dist
+import torch.nn.functional as F
 
 # First Party
 from instructlab.training import config
@@ -71,6 +74,7 @@
     create_lora_config,
     ensure_loadable_dolomite_checkpoint,
     load_latest_full_state,
+    log_rank_0,
     prepare_peft_model,
     prepare_universal_checkpoint_from_latest,
     retrieve_chat_template,
@@ -88,7 +92,7 @@ def setup_optimizer(args, model):
             model.parameters(),
             lr=args.learning_rate,
             betas=(0.9, 0.95),
-            weight_decay=0.0,
+            weight_decay=args.weight_decay,
         )
     elif args.distributed_training_framework == DistributedBackend.DEEPSPEED.value:
         # need to use this only when the CPU offload optimizer is enabled
@@ -110,42 +114,7 @@
     return optimizer
 
 
-def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
-    bnb_config = None
-    if args.lora_r > 0 and args.lora_quant_bits == 4:
-        # Third Party
-        from transformers import BitsAndBytesConfig
-
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_compute_dtype=torch.float16,  # if not set will throw a warning about slow speeds when training
-        )
-
-    base_model_args = {
-        "pretrained_model_name_or_path": args.model_name_or_path,
-        "torch_dtype": torch.bfloat16,
-        "quantization_config": bnb_config,
-    }
-    if flash_enabled:
-        base_model_args["attn_implementation"] = "flash_attention_2"
-
-    if args.use_dolomite:
-        with ensure_loadable_dolomite_checkpoint(
-            args.model_name_or_path, args.output_dir
-        ) as path:
-            base_model_args["pretrained_model_name_or_path"] = path
-            base_model_args["use_padding_free_transformer"] = True
-        model = GPTDolomiteForCausalLM.from_pretrained(
-            **base_model_args,
-        )
-    else:
-        model = AutoModelForCausalLM.from_pretrained(**base_model_args)
-
-    # store the base model args so we can recall them later if saving a LoRA model
-    args.base_model_args = base_model_args
-
+def extend_model_tokenizer(model, tokenizer):
     if len(tokenizer) > model.config.vocab_size:
         print(
             f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
@@ -183,6 +152,45 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
         )
     model.config.eos_token_id = tokenizer.eos_token_id
 
+
+def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
+    bnb_config = None
+    if args.lora_r > 0 and args.lora_quant_bits == 4:
+        # Third Party
+        from transformers import BitsAndBytesConfig
+
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_compute_dtype=torch.float16,  # if not set will throw a warning about slow speeds when training
+        )
+
+    base_model_args = {
+        "pretrained_model_name_or_path": args.model_name_or_path,
+        "torch_dtype": torch.bfloat16,
+        "quantization_config": bnb_config,
+    }
+    if flash_enabled:
+        base_model_args["attn_implementation"] = "flash_attention_2"
+
+    if args.use_dolomite:
+        with ensure_loadable_dolomite_checkpoint(
+            args.model_name_or_path, args.output_dir
+        ) as path:
+            base_model_args["pretrained_model_name_or_path"] = path
+            base_model_args["use_padding_free_transformer"] = True
+        model = GPTDolomiteForCausalLM.from_pretrained(
+            **base_model_args,
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(**base_model_args)
+
+    # store the base model args so we can recall them later if saving a LoRA model
+    args.base_model_args = base_model_args
+
+    extend_model_tokenizer(model, tokenizer)
+
     assert model.__class__.__name__ in [
         "MistralForCausalLM",
         "GPTDolomiteForCausalLM",
@@ -252,6 +260,38 @@ def make_inputs_require_grad(module, input, output):
     return model, lr_scheduler, optimizer, accelerator
 
 
+def setup_teacher_model(
+    model_name_or_path: str, device: torch.device, tokenizer: PreTrainedTokenizer
+):
+    """
+    Instantiates a teacher model to be used for distillation training.
+ """ + teacher_model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, torch_dtype=torch.bfloat16 + ).to(device) + model_dev = next(teacher_model.parameters()).device + + # teacher model needs to live on the same GPU as where it's being trained + if ( + torch.cuda.is_available() + and dist.is_initialized() + and model_dev is torch.device("cpu") + ): + raise RuntimeError( + "error: torch.distributed is initialized but the teacher model was found to be on the CPU" + ) + + # need to make sure we've extended the tokenizer so our logits match the student's + extend_model_tokenizer(teacher_model, tokenizer) + + # disable gradient for all the parameters + for p in teacher_model.parameters(): + p.requires_grad = False + teacher_model.eval() + + return teacher_model + + # this function is to check if the checkpoint provided can be resumed def maybe_resume_training(args, model): local_rank = int(os.environ["LOCAL_RANK"]) @@ -312,6 +352,37 @@ def maybe_resume_training(args, model): return model +def distillation_loss( + student_output: CausalLMOutput, + teacher_output: CausalLMOutput, + alpha: float, + temp: float, +) -> torch.Tensor: + """ + Given a student and teacher model output, compute the KL divergence and return it + as a ratio with the existing loss. + + For reference: https://intellabs.github.io/distiller/knowledge_distillation.html + """ + # 1) get the standard loss + student_loss = student_output.loss + + # 2) Convert student and teacher logits into log-probabilities (log_softmax) at temperature T + teacher_probs = F.softmax(teacher_output.logits / temp, dim=-1).detach() + student_log_probs = F.log_softmax(student_output.logits / temp, dim=-1) + + # 3) Compute KL divergence + # 'reduction="batchmean"' will produce the average KL over the batch + kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") + + # Often people also multiply by temperature^2 to stabilize the gradient scale + # especially in the typical "T>1" scenario. So we do: + distillation_loss = kl * (temp**2) + loss = distillation_loss * alpha + (1 - alpha) * student_loss + + return loss + + def train( args, model, @@ -322,6 +393,7 @@ def train( train_loader: DataLoader, grad_accum, metric_logger, + teacher_model: Optional[AutoModelForCausalLM] = None, ): model.train() @@ -382,11 +454,36 @@ def train( if not args.use_dolomite: for k in batch: batch[k] = batch[k].to(local_rank) - output = model( + + # get the training loss from running on data + output: CausalLMOutput = model( **batch, use_cache=False, ) - loss = output.loss + + loss = None + if args.distill: + # teacher_model should always be provided when `args.distill` is enabled + if TYPE_CHECKING: + assert ( + teacher_model is not None + ), "teacher model cannot be None when `distill` is enabled" + + with torch.no_grad(): + teacher_output: CausalLMOutput = teacher_model( + **batch, use_cache=False + ) + loss = distillation_loss( + student_output=output, + teacher_output=teacher_output, + alpha=args.distill_alpha, + temp=args.distill_temp, + ) + + else: + loss = output.loss + + assert loss is not None, "loss cannot be equal to None!" log_loss = loss.detach().item() num_loss_counted_tokens, micro_batch_size, log_loss = map( @@ -407,8 +504,9 @@ def train( loss / num_loss_counted_tokens * world_size ) # dividing by the total number of non-padding tokens and multiplying by the number of GPUs so when accelerate averages by world_size, it will be the correct loss. 
             print(
-                f"Epoch: {epoch}, Step: {global_step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
+                f"Epoch: {epoch}, Step: {global_step}, Rank: {dist.get_rank()}, loss = {loss}"
             )
+
             accelerator.backward(loss)
 
             if global_step % grad_accum == 0:
@@ -441,7 +539,7 @@
                     {
                         "epoch": epoch,
                         "step": global_step,
-                        "rank": torch.distributed.get_rank(),
+                        "rank": dist.get_rank(),
                         "overall_throughput": overall_throughput,
                         "lr": current_lr,
                         "cuda_mem_allocated": cuda_mem_allocated,
@@ -511,6 +609,9 @@ def main(args):
     # Third Party
     import yaml
 
+    if args.distill and not args.teacher_model_name_or_path:
+        raise ValueError("distillation was enabled but no teacher model was provided")
+
     if args.distributed_training_framework == "deepspeed" and not FusedAdam:
         raise ImportError(
             "DeepSpeed was selected but we cannot import the `FusedAdam` optimizer"
         )
@@ -545,11 +646,11 @@ def main(args):
     #### distributed init #####
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     args.local_rank = int(os.environ["LOCAL_RANK"])
-    torch.distributed.init_process_group("nccl")
-    args.global_rank = torch.distributed.get_rank()
+    dist.init_process_group("nccl")
+    args.global_rank = dist.get_rank()
     tensor = torch.ByteTensor([False]).cuda()
-    torch.distributed.all_reduce(tensor)
-    torch.distributed.barrier()
+    dist.all_reduce(tensor)
+    dist.barrier()
 
     flash_enabled = check_flash_attn_enabled(args.disable_flash_attn, args.use_dolomite)
 
@@ -561,7 +662,7 @@ def main(args):
     try:
         packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
-            num_gpus=torch.distributed.get_world_size(),
+            num_gpus=dist.get_world_size(),
             avg_sample_len=dataset.get_lengths().mean(),
             effective_batch_size=args.effective_batch_size,
             max_batch_len_per_gpu=args.max_batch_len,
@@ -581,7 +682,7 @@ def main(args):
         args.sampler = "distributed"
 
     args.samples_per_gpu = (
-        args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
+        args.effective_batch_size // grad_accum // dist.get_world_size()
     )
 
     train_loader = setup_dataloader(
@@ -620,7 +721,7 @@ def main(args):
     if args.local_rank == 0:
         metric_logger.log_sync(
             {
-                "num_gpus": torch.distributed.get_world_size(),
+                "num_gpus": dist.get_world_size(),
                 "avg_sample_len": dataset.get_lengths().mean(),
                 "effective_batch_size": args.effective_batch_size,
                 "max_batch_len_per_gpu": args.max_batch_len,
@@ -637,6 +738,15 @@ def main(args):
         args, tokenizer, train_loader, grad_accum, flash_enabled
    )
 
+    # bring in the teacher model
+    teacher_model = None
+    if args.distill:
+        log_rank_0("distillation was enabled, instantiating teacher model")
+        teacher_model = setup_teacher_model(
+            args.teacher_model_name_or_path, accelerator.device, tokenizer
+        )
+        dist.barrier()
+
     load_latest_full_state(args=args, accelerator=accelerator)
 
     train(
@@ -649,10 +759,11 @@ def main(args):
         train_loader,
         grad_accum,
         metric_logger,
+        teacher_model=teacher_model,
     )
 
-    torch.distributed.barrier()
-    torch.distributed.destroy_process_group()
+    dist.barrier()
+    dist.destroy_process_group()
 
 
 # public API
@@ -707,6 +818,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         f"--max_batch_len={train_args.max_batch_len}",
         f"--seed={train_args.random_seed}",
         f"--chat-tmpl-path={train_args.chat_tmpl_path}",
+        f"--weight_decay={train_args.weight_decay}",
     ]
 
     if train_args.checkpoint_at_epoch:
@@ -787,6 +899,21 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
             f"--fsdp_sharding_strategy={train_args.fsdp_options.sharding_strategy.value}"
         )
 
+    # knowledge distillation settings
+    if train_args.use_distillation:
+        if not train_args.distillation_options:
+            raise ValueError(
+                "`use_distillation` was enabled but `distillation_options` was not set"
+            )
+        command.extend(
+            [
+                "--distill",
+                f"--distill_temp={train_args.distillation_options.temperature}",
+                f"--teacher_model_name_or_path={train_args.distillation_options.teacher_path}",
+                f"--distill_alpha={train_args.distillation_options.alpha}",
+            ]
+        )
+
     print(f"\033[92mRunning training command as subprocess: {' '.join(command)}\033[0m")
     process = None
     interrupt: KeyboardInterrupt | Exception | None = None
@@ -837,6 +964,33 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     # TODO(osilkin): Configure a type that these args must adhere to for the sake of type checking
     #   Maybe switch out from argparse to something smarter
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--teacher_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to the teacher model, or a reference to its HuggingFace repo.",
+    )
+    parser.add_argument(
+        "--distill",
+        default=False,
+        action="store_true",
+        help="Train with knowledge distillation from a teacher model.",
+    )
+    parser.add_argument(
+        "--distill_temp",
+        type=float,
+        default=1.0,
+        help="Temperature used to soften the student and teacher distributions. Values greater than 1.0 help with knowledge transfer.",
+    )
+    parser.add_argument(
+        "--distill_alpha",
+        type=float,
+        default=1.0,
+        help=(
+            "Weight of the distillation loss relative to the standard cross-entropy loss. "
+            "Use 1.0 for pure distillation, and 0.0 for pure cross-entropy loss."
+        ),
+    )
     parser.add_argument("--model_name_or_path", type=str)
     parser.add_argument("--data_path", type=str)
     parser.add_argument("--output_dir", type=str)
@@ -929,6 +1083,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         help="Which modules we should target for injecting LoRA layers. Defaults to selecting all projection layers when no values are provided.",
     )
     parser.add_argument("--max_batch_len", type=int, default=60000)
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=0.0,
+        help="Weight decay rate for optimizers that support it.",
+    )
     parser.add_argument(
         "--cpu_offload_optimizer",
         action="store_true",
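
Usage note (illustrative, not part of the diff): the sketch below shows how a caller might enable distillation through the new public API. Only `use_distillation`, `distillation_options`, and `DistillationConfig` come from this change; the paths are placeholders and the remaining required `TrainingArgs` fields are elided.

    # hypothetical caller-side sketch -- paths and hyperparameter values are placeholders
    from instructlab.training.config import DistillationConfig, TrainingArgs

    train_args = TrainingArgs(
        # ... the usual required fields (model path, data paths, batch settings, etc.) ...
        use_distillation=True,
        distillation_options=DistillationConfig(
            teacher_path="path/to/teacher-model",  # placeholder: local dir or HuggingFace repo id
            temperature=2.0,  # T > 1 softens both distributions; the KL term is scaled by T**2
            alpha=0.5,  # loss = 0.5 * distillation loss + 0.5 * cross-entropy loss
        ),
    )

With these settings, `run_training()` forwards `--distill`, `--distill_temp=2.0`, `--distill_alpha=0.5`, and `--teacher_model_name_or_path=path/to/teacher-model` to the training subprocess, as shown in the command-building block above.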