
Parameter-efficient Fine Tuning for Clinical LLMs


aryopg/clinical_peft


Parameter-Efficient Fine-Tuning LLaMA for the Clinical Domain

This repository contains the code for domain adaptation fine-tuning, downstream fine-tuning, and evaluation for "Parameter-Efficient Fine-Tuning LLaMA for the Clinical Domain" (in submission).


🛠️ Setup

Cloning the codebase

Because we rely on a specific version of PEFT, we include it as a git submodule. Consequently, you must clone this repository with the --recurse-submodules flag:

git clone --recurse-submodules https://github.com/aryopg/clinical_peft.git

Python packages

This codebase requires the following dependencies:

Dependencies
- pip
- numpy
- pandas
- pytorch
- transformers
- datasets
- huggingface-hub
- evaluate
- pydantic
- scikit-learn
- python-dotenv
- black
- isort
- PyYAML
- tqdm
- wandb
- jupyterlab
- matplotlib
- peft

We use conda as our package manager. The following commands install the dependencies required for GPU training:

conda env create -f environment.yml
conda activate clinical_peft

Environment variables

The following environment variables are required to run training:

  • WANDB_API_KEY: The authorisation key to access your WandB projects
  • WANDB_PROJECT_NAME: The name you want to give this project
  • WANDB_ENTITY: The WandB entity that will host the project
  • HF_DOWNLOAD_TOKEN: Download token for Hugging Face
  • HF_UPLOAD_TOKEN: Upload token for Hugging Face
  • HF_USERNAME: Your Hugging Face username

We use the python-dotenv package to load these environment variables. To set them:

mkdir env
nano env/.env

Write all of the environment variables above, with appropriate values, into that file. You don't have to use nano, as long as the file path (env/.env) remains the same.
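Before launching a run, it can be worth checking that env/.env actually defines every required variable. The repository loads this file with python-dotenv; the sketch below is a simplified, standard-library-only stand-in that parses plain KEY=VALUE lines (python-dotenv itself also handles export prefixes, multi-line values, and interpolation) and lists any required variable that is missing or empty. The helper names parse_env_file and missing_env_vars are hypothetical, not part of the repository.

```python
from pathlib import Path

# Variables listed in this README as required for training.
REQUIRED_VARS = (
    "WANDB_API_KEY",
    "WANDB_PROJECT_NAME",
    "WANDB_ENTITY",
    "HF_DOWNLOAD_TOKEN",
    "HF_UPLOAD_TOKEN",
    "HF_USERNAME",
)


def parse_env_file(path):
    """Parse simple KEY=VALUE lines (hypothetical helper; a simplified
    stand-in for python-dotenv, which supports more syntax)."""
    values = {}
    for raw_line in Path(path).read_text().splitlines():
        line = raw_line.strip()
        # Skip blank lines, comments, and lines without an assignment.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop optional surrounding quotes from the value.
        values[key.strip()] = value.strip().strip("'\"")
    return values


def missing_env_vars(path="env/.env"):
    """Return the required variables that are absent or empty in the file."""
    values = parse_env_file(path)
    return [name for name in REQUIRED_VARS if not values.get(name)]
```

Calling missing_env_vars() from the repository root returns an empty list once every variable has a non-empty value.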

💾 Dataset

Domain adaptation fine-tuning

A combination of MIMIC-IV de-identified discharge summaries (331,794) and radiology reports (2,321,355), resulting in a collection of 2,653,149 individual clinical notes.
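As a quick sanity check on the corpus size, the two note counts above sum to the stated total, and radiology reports make up roughly 87.5% of the domain adaptation corpus:

```python
# Note counts from the MIMIC-IV domain adaptation corpus described above.
discharge_summaries = 331_794
radiology_reports = 2_321_355

total_notes = discharge_summaries + radiology_reports
radiology_share = radiology_reports / total_notes

print(f"Total notes: {total_notes:,}")            # Total notes: 2,653,149
print(f"Radiology share: {radiology_share:.1%}")  # Radiology share: 87.5%
```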

Downstream fine-tuning

  • Length of stay (LOS): a multiclass classification task to predict the length of a patient’s hospital stay, categorised into four time-bins: less than three days, three to seven days, one to two weeks, and more than two weeks (van Aken et al., 2021).
  • In-hospital mortality (MOR): a binary classification task to predict whether a patient will survive during their hospital stay (van Aken et al., 2021).
  • Prolonged mechanical ventilation (PMV): a binary classification task to predict whether a patient will require mechanical ventilation for more than seven days (Huang et al., 2020).
  • Diagnoses (DIAG): an extreme multilabel classification task to predict the differential diagnoses associated with a patient, represented by simplified ICD-9 diagnosis codes (van Aken et al., 2021).
  • Procedures (PROC): an extreme multilabel classification task to predict the diagnostics or treatments administered to a patient, represented by simplified ICD-9 procedure codes (van Aken et al., 2021).
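To make the label formats concrete, here is a minimal sketch of how LOS targets could be mapped to the four time-bins and how DIAG/PROC targets could be encoded as multi-hot vectors over a code vocabulary. The exact bin edges (whether 3, 7, and 14 days fall in the lower or upper bin), the function names, and the toy ICD-9 vocabulary are illustrative assumptions; the actual labels come from the BEEP and CORe preprocessing pipelines.

```python
from typing import Iterable, List, Sequence

# Four LOS classes described above; the edge handling below is an assumption.
LOS_LABELS = ("<3 days", "3-7 days", "1-2 weeks", ">2 weeks")


def los_bin(days: float) -> int:
    """Map a length of stay in days to one of four class indices (assumed edges)."""
    if days < 3:
        return 0
    if days <= 7:
        return 1
    if days <= 14:
        return 2
    return 3


def multi_hot(codes: Iterable[str], vocabulary: Sequence[str]) -> List[int]:
    """Encode a patient's set of codes as a 0/1 vector over a fixed code vocabulary."""
    code_set = set(codes)
    return [1 if code in code_set else 0 for code in vocabulary]


# Toy vocabulary of simplified ICD-9 diagnosis codes (illustrative only).
vocab = ["250", "401", "428", "584"]
print([LOS_LABELS[los_bin(d)] for d in (1.5, 5, 10, 20)])
print(multi_hot(["401", "428"], vocab))  # [0, 1, 1, 0]
```

In the DIAG and PROC tasks the vocabulary contains many more codes, which is what makes them extreme multilabel problems: each patient's target is a long, sparse 0/1 vector.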

🤖 Training

Prepare the MIMIC-IV dataset

  1. Obtain and extract the MIMIC-IV clinical notes. The files of interest are discharge.csv.gz and radiology.csv.gz.
  2. Run python scripts/prepare_mimic_iv.py --dataset_dir PATH/TO/MIMIC-IV-DIR
  3. This script creates the training data (files with the .txt.gz extension) in the same directory as the raw datasets.
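The exact output format is defined by scripts/prepare_mimic_iv.py itself. As a rough sketch of this kind of preprocessing, the standard-library snippet below streams the notes out of a gzipped MIMIC-IV CSV and writes each note, with its whitespace collapsed, as one line of a .txt.gz file next to the input. It assumes the note body is stored in a column named text; the function name notes_to_txt_gz and the one-note-per-line output are hypothetical choices, not necessarily what the repository's script does.

```python
import csv
import gzip
from pathlib import Path

# Long notes can exceed csv's default field size limit (131,072 characters).
csv.field_size_limit(2**31 - 1)


def notes_to_txt_gz(csv_gz_path, out_path=None, text_column="text"):
    """Write each note's text (whitespace collapsed) as one line of a .txt.gz file.

    Hypothetical helper: assumes the input CSV has a column named `text_column`.
    Returns the output path and the number of notes written.
    """
    csv_gz_path = Path(csv_gz_path)
    if out_path is None:
        # e.g. discharge.csv.gz -> discharge.txt.gz, in the same directory
        out_path = csv_gz_path.with_name(csv_gz_path.name.replace(".csv.gz", ".txt.gz"))
    n_notes = 0
    with gzip.open(csv_gz_path, "rt", newline="", encoding="utf-8") as src:
        with gzip.open(out_path, "wt", encoding="utf-8") as dst:
            for row in csv.DictReader(src):
                # Collapse internal newlines and runs of spaces into single spaces.
                text = " ".join(row[text_column].split())
                if text:  # skip empty notes
                    dst.write(text + "\n")
                    n_notes += 1
    return out_path, n_notes
```

For example, looping over ("discharge.csv.gz", "radiology.csv.gz") and calling notes_to_txt_gz(Path(dataset_dir) / name) would produce discharge.txt.gz and radiology.txt.gz alongside the raw files.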

Prepare the downstream datasets

  1. All downstream datasets require the MIMIC-III clinical notes.
  2. Once you have obtained and extracted them:
    1. LOS, MOR, and PMV: follow the instructions provided in the BEEP repository.
    2. DIAG and PROC: follow the instructions provided in the CORe repository.

Domain adaptation fine-tuning

We use Accelerate to run model training. To launch training, specify the appropriate config file from the configs/mimic_pretrain_hpo_configs/ directory. For example:

accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/mimic_pretrain_hpo_configs/llama_lora.yaml
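To give a sense of why LoRA is parameter-efficient: LoRA freezes a pretrained weight W of shape d×k and learns a low-rank update ΔW = BA, with B of shape d×r and A of shape r×k, so each adapted matrix adds only r(d + k) trainable parameters. The sketch below applies this formula to LLaMA-7B dimensions (32 layers, hidden size 4096), assuming rank r = 8 with adapters on the query and value projections only. These are PEFT's common defaults for LLaMA and are not necessarily the settings in llama_lora.yaml.

```python
def lora_params_per_matrix(d: int, k: int, r: int) -> int:
    """Trainable parameters for one LoRA-adapted d x k weight: B (d x r) plus A (r x k)."""
    return r * (d + k)


def lora_trainable_params(n_layers: int, hidden: int, r: int, n_targets: int) -> int:
    """Total LoRA parameters when `n_targets` square (hidden x hidden) projections per layer are adapted."""
    return n_layers * n_targets * lora_params_per_matrix(hidden, hidden, r)


# Assumed LLaMA-7B dimensions; r=8 on q_proj and v_proj (2 matrices per layer).
trainable = lora_trainable_params(n_layers=32, hidden=4096, r=8, n_targets=2)
print(f"{trainable:,}")  # 4,194,304
```

Under these assumptions, roughly 4.2 million parameters are trained, about 0.06% of LLaMA-7B's roughly 6.7 billion parameters.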

Downstream fine-tuning

Baseline models

accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/<TASK_OF_INTEREST>/bioclinicalbert_baseline.yaml
accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/<TASK_OF_INTEREST>/bluebert_baseline.yaml
accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/<TASK_OF_INTEREST>/core_baseline.yaml
accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/<TASK_OF_INTEREST>/umlsbert_baseline.yaml

Clinical LLaMA-LoRA

accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/<TASK_OF_INTEREST>/clinical_llama_finetune_downstream_peft.yaml
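When running every downstream model on a task, it can help to build the launch commands programmatically. The helpers below (build_launch_command and run_all) are a hypothetical convenience wrapper, not part of the repository: they assemble the same accelerate launch invocation for each of the five downstream config files listed in this README. The task argument must be the name of a task directory under configs/context_length_512/downstream_hpo_configs/, which you fill in yourself.

```python
import shlex
import subprocess
from typing import List

CONFIG_ROOT = "configs/context_length_512/downstream_hpo_configs"

# Downstream config file names listed in this README.
DOWNSTREAM_CONFIGS = (
    "bioclinicalbert_baseline.yaml",
    "bluebert_baseline.yaml",
    "core_baseline.yaml",
    "umlsbert_baseline.yaml",
    "clinical_llama_finetune_downstream_peft.yaml",
)


def build_launch_command(task: str, config_name: str) -> List[str]:
    """Return the argv for `accelerate launch` on one task/config pair."""
    return [
        "accelerate", "launch", "--mixed_precision=fp16",
        "scripts/train.py",
        "--config_filepath", f"{CONFIG_ROOT}/{task}/{config_name}",
    ]


def run_all(task: str, dry_run: bool = True) -> None:
    """Print (dry run) or execute the launch command for every downstream config."""
    for config_name in DOWNSTREAM_CONFIGS:
        cmd = build_launch_command(task, config_name)
        if dry_run:
            print(shlex.join(cmd))
        else:
            # check=True stops the sweep as soon as a run exits with an error.
            subprocess.run(cmd, check=True)
```

With the default dry_run=True, run_all only prints the five commands, so you can inspect them before passing dry_run=False to run them sequentially.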
