diff --git a/k8s/peft/karpathy_speedups_example.yml b/k8s/peft/karpathy_speedups_example.yml
new file mode 100644
index 0000000..f486b43
--- /dev/null
+++ b/k8s/peft/karpathy_speedups_example.yml
@@ -0,0 +1,95 @@
+apiVersion: batch/v1
+kind: Job
+# This is used for naming the job and pod, and letting other cluster/namespace users know I created it
+metadata:
+  generateName: bking2--hf-libraries-demo-
+  labels:
+    user: bking2
+    k8s-app: bking2-hf-libraries-demo
+spec:
+  template:
+    spec:
+      # Here we additionally specify that we need our pod (created by the job) to attach to a node with an A100
+      affinity:
+        nodeAffinity:
+          requiredDuringSchedulingIgnoredDuringExecution:
+            nodeSelectorTerms:
+              - matchExpressions:
+                  - key: nvidia.com/gpu.product
+                    operator: In
+                    values:
+                      - NVIDIA-A100-SXM4-80GB
+      # Here is where we define the core parts of the job. We need 1) the Docker image, 2) its environment requirements
+      # (CPU/Memory/GPU), and 3) the command that gets run
+      containers:
+        - name: bking2-hf-libraries-demo
+          image: kingb12/hf_libraries_demo:latest
+          # Here I've added a secret for my Weights & Biases API key, so the job
+          # can create logs, and my Huggingface API key, so I can download weights
+          envFrom:
+            - secretRef:
+                name: bking2-wandb-api-key-71a5
+            - secretRef:
+                name: bking2-hf-api-token
+          resources:
+            limits:
+              memory: 64Gi
+              cpu: 32
+              nvidia.com/gpu: "1"
+            requests:
+              memory: 32Gi
+              cpu: 16
+              nvidia.com/gpu: "1"
+          command: [ "/bin/sh" ]
+          # This includes further setup to 1) cache transformers and datasets on my volume so weights don't need to be
+          # re-downloaded on each run and 2) log in to Huggingface, since StarCoder is agreement-protected.
+          # Everything after 'job ready to start' is the script we want to run. Using
+          # conda run --no-capture-output -p ./venv runs things with the correct conda environment.
+
+          # Note: rather than clone this job over different arguments to batch size, I just modified them here as I created
+          # things.
+          args:
+            - -c
+            - >-
+              cd /home/bking2/hf_libraries_demo &&
+              export TRANSFORMERS_CACHE=/data/users/bking2/.cache/huggingface &&
+              export HF_HOME=/data/users/bking2/.cache/huggingface &&
+              pip install huggingface_hub &&
+              python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('${HF_API_TOKEN}')" &&
+              echo "job ready to start" &&
+              echo "import hf_libraries_demo.package_demo.addition_module as mymod\nprint(f'4 + 5 is {mymod.add_five_to_x(4)}')" > demo.py &&
+              conda run --no-capture-output -p ./venv python src/hf_libraries_demo/experiments/peft/karpathy_speedups_example.py --batch_size 1 &&
+              echo "job complete!"
+          # some arguments needed by kubernetes, plus some useful defaults
+          volumeMounts:
+            - mountPath: /data/users/bking2
+              name: bking2-data-volume
+      restartPolicy: Never
+      schedulerName: default-scheduler
+      securityContext: {}
+      serviceAccount: default
+      serviceAccountName: default
+      terminationGracePeriodSeconds: 30
+      # tolerations are used to define what to do if the cluster isn't ready, can't be reached, etc. Other tolerations
+      # can be used to define what to do when resources are inadequate for our requests/limits
+      tolerations:
+        - effect: NoExecute
+          key: node.kubernetes.io/not-ready
+          operator: Exists
+          tolerationSeconds: 300
+        - effect: NoExecute
+          key: node.kubernetes.io/unreachable
+          operator: Exists
+          tolerationSeconds: 300
+        # We also tolerate the PreferNoSchedule taint on nvidia.com/gpu, so the job can still be scheduled on GPU nodes
+        - effect: PreferNoSchedule
+          key: nvidia.com/gpu
+          operator: Exists
+      # Here we specify the data volume as well. So far, I just use this for caching transformer/dataset weights.
+      # See https://ucsd-prp.gitlab.io/userdocs/tutorial/storage/ for info on creating a data volume to mount like
+      # this one (a pre-requisite for this job, not shown in this repo)
+      volumes:
+        - name: bking2-data-volume
+          persistentVolumeClaim:
+            claimName: bking2-data-volume
+  backoffLimit: 0
diff --git a/src/hf_libraries_demo/experiments/peft/README.md b/src/hf_libraries_demo/experiments/peft/README.md
index ae68b5a..24b0ca1 100644
--- a/src/hf_libraries_demo/experiments/peft/README.md
+++ b/src/hf_libraries_demo/experiments/peft/README.md
@@ -24,3 +24,26 @@ from achieving this.
 See [`./flops_counter.py`](flops_counter.py) for an example counter that will work with the Huggingface Trainer, its
 use in [`./base_with_tflops.py`](base_with_tflops.py), and recorded logs in W&B at
 [kingb12/hf_libraries_demo_peft_example](https://wandb.ai/kingb12/hf_libraries_demo_peft_example)
+
+## Aside: Running on a single A100 w/ Kubernetes
+
+If your compute environment is like mine, your use of a full A100 is mediated by a Kubernetes cluster
+(I use [Nautilus](https://portal.nrp-nautilus.io/)). In [/k8s](../../../../k8s/README.md), I have information and
+templates used to run `./base_with_tflops.py` in Docker as a Kubernetes job on the Nautilus cluster.
+
+## Increasing batch size and dataloader num_workers
+
+We'll start with some of the easiest-to-implement tricks from this [@karpathy tweet](https://twitter.com/karpathy/status/1299921324333170689?s=20), and then figure out what batch size works best:
+
+![Karpathy Tweet](https://pbs.twimg.com/media/Ego_hTIUwAARnS6?format=png&name=900x900)
+
+Specifically, we'll:
+- set `num_workers > 0` and default to `pin_memory=True`
+- use `torch.backends.cudnn.benchmark = True`
+- try to max out batch size on our GPU
+
+For different batch sizes, we'll plot TFLOPS and training samples per second. The code for
+this example is in [./karpathy_speedups_example.py](./karpathy_speedups_example.py), and it only differs
+from the base example in that it 1) parses command-line arguments for the settings above and 2) supplies them to
+the appropriate [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments).
+
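For reference, here is a minimal sketch (not part of this diff) of how those three settings map onto code: the cuDNN benchmark flag is a plain PyTorch global, while the dataloader and batch-size knobs are ordinary `TrainingArguments` fields. The concrete values below are placeholders, not the ones used in the experiments.

```python
import torch
from transformers import TrainingArguments

# Let cuDNN autotune kernels for fixed input shapes. This mostly matters for convolutional
# workloads, and note it is not actually set in karpathy_speedups_example.py below.
torch.backends.cudnn.benchmark = True

training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=8,  # placeholder: raise this until the A100 runs out of memory
    dataloader_num_workers=8,       # num_workers > 0 overlaps data loading with GPU compute
    dataloader_pin_memory=True,     # page-locked host memory speeds up host-to-GPU copies
)
```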
diff --git a/src/hf_libraries_demo/experiments/peft/karpathy_speedups_example.py b/src/hf_libraries_demo/experiments/peft/karpathy_speedups_example.py
new file mode 100644
index 0000000..11136e7
--- /dev/null
+++ b/src/hf_libraries_demo/experiments/peft/karpathy_speedups_example.py
@@ -0,0 +1,121 @@
+"""
+Adding @karpathy's speed-ups, and taking batch size as an argument for fine-tuning StarCoder with the PEFT library
+"""
+import argparse
+import os
+
+import wandb
+from datasets import load_dataset, Dataset, DatasetDict
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer
+from transformers import TrainingArguments
+
+from hf_libraries_demo.experiments.peft.flops_counter import TFLOPSCallback
+from hf_libraries_demo.experiments.peft.utils import SavePeftModelCallback, LoadBestPeftModelCallback, \
+    print_trainable_parameters
+
+if __name__ == "__main__":
+    # parse arguments
+    parser = argparse.ArgumentParser(description="Training arguments parser")
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size for training (default: 1)')
+    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for data loading (default: 8)')
+    parser.add_argument('--pin_memory', action='store_true', default=True,
+                        help='Use pinned (page-locked) memory. If not set, defaults to True.')
+    args = parser.parse_args()
+
+    # Load and process the dataset. Added more training data points to get a more complete test.
+    full_dataset: Dataset = load_dataset("HuggingFaceH4/CodeAlpaca_20K", split=f"train[0:{128*10}]", use_auth_token=True)
+    split_dataset: DatasetDict = full_dataset.train_test_split(test_size=0.1)
+
+    # take each prompt and completion and form a single text with a 'Question' and 'Answer', drop existing columns
+    split_dataset = split_dataset.map(
+        lambda item: {'text': f"Question: {item['prompt']}\n\nAnswer: {item['completion']}"},
+        remove_columns=split_dataset['train'].column_names
+    )
+
+    # set up the tokenizer and tokenize the dataset; ignore padding/truncation for now since we're using batch size 1
+    tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder", use_auth_token=True)
+    tokenized_dataset = split_dataset.map(lambda batch: tokenizer(batch['text']), batched=True)
+
+    # set the labels to the inputs. In this case, the model will know to do appropriate shifting for causal LM
+    tokenized_dataset = tokenized_dataset.map(lambda batch: {'labels': batch['input_ids']}, batched=True)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        "bigcode/starcoder",
+        use_auth_token=True,
+        use_cache=True,
+        # note this argument for loading the model in 8-bit mode
+        load_in_8bit=True,
+        device_map="auto",
+    )
+
+    # some model preparation work done by `peft`
+    model = prepare_model_for_kbit_training(model)
+
+    # For our parameter-efficient tuning method, we'll use LoRA
+    lora_config = LoraConfig(
+        r=16,
+        lora_alpha=32,
+        lora_dropout=.05,
+        bias="none",
+        task_type="CAUSAL_LM",
+        target_modules=["c_proj", "c_attn", "q_attn"]
+    )
+
+    # get a peft model based on our config and base model
+    model = get_peft_model(model, lora_config)
+
+    # for information, we'll log the total number of parameters and those that are trainable (requires_grad=True)
+    print_trainable_parameters(model)
+
+    # wandb init for logging (log as this file name, no hyperparameters)
+    run = wandb.init(project="hf_libraries_demo_peft_example", name=os.path.basename(__file__))
+
+    wandb.log(vars(args))
+
+    # Finally, set up a Trainer and train as in typical fine-tuning. Taking very few steps again
+    output_dir: str = "./outputs"
+    os.makedirs(output_dir, exist_ok=True)
+    training_args = TrainingArguments(
+        output_dir=output_dir,
+        evaluation_strategy="steps",
+        save_strategy="steps",
+        load_best_model_at_end=True,
+        max_steps=32,
+        eval_steps=16,
+        save_steps=16,
+        logging_steps=1,
+        # We're optimizing training speed, but in a real setup you can increase eval batch size beyond train batch size
+        per_device_train_batch_size=args.batch_size,
+        per_device_eval_batch_size=args.batch_size,
+        learning_rate=5e-6,
+        lr_scheduler_type="cosine",
+        warmup_steps=100,
+        gradient_accumulation_steps=4,  # our effective batch size will be 4 * batch_size as a result
+        fp16=True,
+        weight_decay=0.05,
+        report_to="wandb",
+        # implementing @karpathy's simple speed-ups for the dataloader. If using k8s, make sure cpu requests > this val
+        dataloader_num_workers=args.num_workers,
+        dataloader_pin_memory=args.pin_memory
+    )
+
+
+    # Create a TFLOPS callback which logs to wandb
+    tflops_callback: TFLOPSCallback = TFLOPSCallback(logging_callback=wandb.log)
+
+    # set up the trainer and initiate training
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_dataset['train'],
+        eval_dataset=tokenized_dataset['test'],
+        # these are defined in utils.py, and are convenience methods for saving and loading peft models without
+        # saving/loading the large model over again
+        callbacks=[
+            SavePeftModelCallback(checkpoint_dir=output_dir),
+            LoadBestPeftModelCallback(),
+            tflops_callback
+        ]
+    )
+    trainer.train()
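One caveat when pushing `--batch_size` above 1: the script tokenizes without padding or truncation, which only works while each batch holds a single example. Below is a minimal sketch of one way to handle larger batches with a causal-LM data collator; it is an illustration rather than part of this diff, and the pad-token fallback is an assumption about the StarCoder tokenizer.

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder", use_auth_token=True)

# Assumption: if the tokenizer defines no pad token, reuse EOS so batches can be padded
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# mlm=False gives causal-LM collation: each batch's input_ids are dynamically padded and
# copied to labels, with padded positions set to -100 so the loss ignores them
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# The collator would then be passed to the Trainer, e.g. Trainer(..., data_collator=data_collator),
# which also makes the explicit 'labels' mapping in the script above unnecessary.
```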