This repository contains the code for domain-adaptation fine-tuning, downstream fine-tuning, and evaluation for "Parameter-Efficient Fine-Tuning LLaMA for the Clinical Domain" (in submission).
As we rely on a specific version of PEFT, we include it as a git submodule. Consequently, you have to clone this repo with the --recurse-submodules flag:
git clone --recurse-submodules https://github.com/aryopg/clinical_peft.git
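If you have already cloned the repository without that flag, you can fetch the submodule afterwards with the standard git command:
git submodule update --init --recursive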
This codebase requires multiple dependencies.
Dependencies
- pip
- numpy
- pandas
- pytorch
- transformers
- datasets
- huggingface-hub
- evaluate
- pydantic
- scikit-learn
- python-dotenv
- black
- isort
- PyYAML
- tqdm
- wandb
- jupyterlab
- matplotlib
- peft
We use conda as our package manager. The following commands will install the necessary dependencies for GPU training:
conda env create -f environment.yml
conda activate clinical_peft
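Once the environment is activated, you can optionally confirm that PyTorch can see your GPU (a generic sanity check, not a project script):
python -c "import torch; print(torch.cuda.is_available())"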
There are multiple environment variables required to run the training:
- WANDB_API_KEY: The authorisation key to access your WandB projects
- WANDB_PROJECT_NAME: The name you would like to use for this project
- WANDB_ENTITY: The WandB entity that will host the project
- HF_DOWNLOAD_TOKEN: Download token for Huggingface
- HF_UPLOAD_TOKEN: Upload token for Huggingface
- HF_USERNAME: Your HuggingFace username
We use the python-dotenv package to load these environment variables. To set them:
mkdir env
nano env/.env
Write down all of the aforementioned environment variables with the appropriate values inside that file (an example is shown below). Of course, you don't have to use nano, as long as the file path (env/.env) remains the same.
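For reference, a minimal env/.env might look like the following; all values are placeholders that you should replace with your own keys and names:
```
WANDB_API_KEY=your_wandb_api_key
WANDB_PROJECT_NAME=clinical_peft
WANDB_ENTITY=your_wandb_entity
HF_DOWNLOAD_TOKEN=your_hf_read_token
HF_UPLOAD_TOKEN=your_hf_write_token
HF_USERNAME=your_hf_username
```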
The domain-adaptation training data is a combination of MIMIC-IV de-identified discharge summaries (331,794) and radiology reports (2,321,355), resulting in a collection of 2,653,149 individual clinical notes. The downstream fine-tuning tasks are:
- Length of stay (LOS): a multiclass classification task to predict the length of a patient’s hospital stay, categorised into four time-bins: less than three days, three to seven days, one to two weeks, and more than two weeks (van Aken et al., 2021); an illustrative bin mapping is sketched after this list.
- In-hospital mortality (MOR): a binary classification task to predict whether a patient will survive during their hospital stay (van Aken et al., 2021).
- Prolonged mechanical ventilation (PMV): a binary classification task to predict whether a patient will require mechanical ventilation for more than seven days (Huang et al., 2020).
- Diagnoses (DIAG): an extreme multilabel classification task to predict the differential diagnoses associated with a patient, represented by simplified ICD-9 diagnosis codes (van Aken et al., 2021).
- Procedures (PROC): an extreme multilabel classification task to predict the diagnostics or treatments administered to a patient, represented by simplified ICD-9 procedure codes (van Aken et al., 2021).
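As an illustration only, the four LOS time-bins could be encoded as follows; the exact label indices and boundary handling used by this repo's preprocessing may differ:
```python
def los_bin(days: float) -> int:
    """Map a length of stay in days to one of the four LOS time-bins.
    The boundary handling here is an assumption for illustration."""
    if days < 3:
        return 0  # less than three days
    if days <= 7:
        return 1  # three to seven days
    if days <= 14:
        return 2  # one to two weeks
    return 3      # more than two weeks
```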
- Obtain and extract the MIMIC-IV clinical notes. The files of interest are discharge.csv.gz and radiology.csv.gz (an optional sanity check on these files is sketched after this list).
- Run python scripts/prepare_mimic_iv.py --dataset_dir PATH/TO/MIMIC-IV-DIR
- The previous script will create the training data with the extension .txt.gz in the same directory as the raw datasets.
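Before running the preparation script, you can optionally verify that the raw files match the note counts quoted above. This is a minimal sketch assuming the standard MIMIC-IV-Note layout; it is not one of this repo's scripts:
```python
import pandas as pd

# Paths are placeholders; point them at your extracted MIMIC-IV note files.
discharge = pd.read_csv("PATH/TO/MIMIC-IV-DIR/discharge.csv.gz")
radiology = pd.read_csv("PATH/TO/MIMIC-IV-DIR/radiology.csv.gz")

# Expected: 331,794 discharge summaries and 2,321,355 radiology reports.
print(len(discharge), len(radiology))
```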
- For all downstream datasets, you need to obtain the MIMIC-III clinical notes.
- Once you've obtained and extracted them, you can proceed with the downstream fine-tuning described below.
We use Accelerate to run the model training.
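If Accelerate has not been configured on your machine yet, you can first run its standard interactive setup (part of the stock accelerate CLI, not a repo script) to choose the number of GPUs and precision settings:
accelerate config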
To launch a training run, you need to specify the right config file. The domain-adaptation (MIMIC pretraining) configs can be found in the configs/mimic_pretrain_hpo_configs/ directory. For example:
accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/mimic_pretrain_hpo_configs/llama_lora.yaml
The downstream fine-tuning configs live under the configs/context_length_512/downstream_hpo_configs/ directory, with one subdirectory per task; replace ***<TASK_OF_INTEREST>*** below with the task you want to run. For example:
accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/***<TASK_OF_INTEREST>***/bioclinicalbert_baseline.yaml
accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/***<TASK_OF_INTEREST>***/bluebert_baseline.yaml
accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/***<TASK_OF_INTEREST>***/core_baseline.yaml
accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/***<TASK_OF_INTEREST>***/umlsbert_baseline.yaml
accelerate launch --mixed_precision=fp16 scripts/train.py --config_filepath configs/context_length_512/downstream_hpo_configs/***<TASK_OF_INTEREST>***/clinical_llama_finetune_downstream_peft.yaml
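After training, a saved LoRA adapter can be loaded back on top of the base model with the standard PEFT API. The snippet below is a sketch rather than one of this repo's scripts, and the checkpoint paths are placeholders:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Placeholders: point these at the base LLaMA checkpoint and your saved LoRA adapter.
base_model = AutoModelForCausalLM.from_pretrained("PATH/TO/BASE_LLAMA")
model = PeftModel.from_pretrained(base_model, "PATH/TO/SAVED_LORA_ADAPTER")
tokenizer = AutoTokenizer.from_pretrained("PATH/TO/BASE_LLAMA")
model.eval()
```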