This repository contains the official PyTorch implementation of the following paper:
Identifying Sepsis Subphenotypes via Time-Aware Multi-Modal Auto-Encoder (KDD 2020)
Changchang Yin, Ruoqi Liu, Dongdong Zhang, Ping Zhang
Abstract: Sepsis is a heterogeneous clinical syndrome that is the leading cause of mortality in hospital intensive care units (ICUs). Identification of sepsis subphenotypes may allow for more precise treatments and lead to more targeted clinical interventions. Recently, sepsis subtyping on electronic health records (EHRs) has attracted interest from healthcare researchers. However, most sepsis subtyping studies ignore the temporality of EHR data and suffer from missing values. In this paper, we propose a new sepsis subtyping framework to address the two issues. Our subtyping framework consists of a novel Time-Aware Multi-modal auto-Encoder (TAME) model which introduces a time-aware attention mechanism and incorporates multi-modal inputs (e.g., demographics, diagnoses, medications, lab tests and vital signs) to impute missing values, a dynamic time warping (DTW) method to measure patients' temporal similarity based on the imputed EHR data, and a weighted k-means algorithm to cluster patients. Comprehensive experiments on real-world datasets show TAME outperforms the baselines on imputation accuracy. After analyzing TAME-imputed EHR data, we identify four novel subphenotypes of sepsis patients, paving the way for improved personalization of sepsis management.
TAME takes multi-modal data (demographics, diagnoses, medications, lab tests and vital signs) as input and imputes the missing values in the lab tests and vital signs.
Given the imputed data, DTW is used to compute temporal patient similarity. Based on the resulting similarity matrix, weighted k-means (wk-means) groups sepsis patients into subphenotypes.
TAME
|
|--code
|
|--file * The preprocessing code writes its intermediate files here.
|
|--data * Put the downloaded datasets here.
| |
| |--DACMI * Put the downloaded DACMI dataset here.
| | |
| | |--train_groundtruth
| | |
| | |--train_with_missing
| |
| |--MIMIC
| |
| |--initial_mimiciii * Put the downloaded MIMIC-III dataset here.
| | |
| | |--PATIENTS.csv
| | |
| | |--ICUSTAYS.csv
| | |
| | |--DIAGNOSES_ICD.csv
| | |
| | |--PRESCRIPTIONS.csv
| | |
| | |--LABEVENTS.csv
| | |
| | |--pivoted_lab.csv * Generate pivoted_*.csv with https://github.com/MIT-LCP/mimic-code.
| | |
| | |--pivoted_vital.csv
| | |
| | |--pivoted_sofa.csv
| |
| |--train_groundtruth
| |
| |--train_with_missing
|
|
|--result * The imputation results and clustering results are here.
    |
    |--DACMI
    |
    |--MIMIC
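Before running the preprocessing scripts, it can help to confirm that the MIMIC-III files listed in the tree above are in place. The following is a minimal sketch, not part of the repository: the helper name `missing_mimic_files` and the default `root="TAME"` are assumptions made for illustration, and the file list is copied from the tree.

```python
import os

# File names copied from the directory tree above.
EXPECTED_MIMIC_FILES = [
    "PATIENTS.csv", "ICUSTAYS.csv", "DIAGNOSES_ICD.csv", "PRESCRIPTIONS.csv",
    "LABEVENTS.csv", "pivoted_lab.csv", "pivoted_vital.csv", "pivoted_sofa.csv",
]


def missing_mimic_files(root="TAME"):
    """Return the expected MIMIC-III files that are absent under `root`.

    `root` is the repository directory (an assumed default, adjust as needed).
    """
    mimic_dir = os.path.join(root, "data", "MIMIC", "initial_mimiciii")
    return [name for name in EXPECTED_MIMIC_FILES
            if not os.path.isfile(os.path.join(mimic_dir, name))]


if __name__ == "__main__":
    missing = missing_mimic_files()
    if missing:
        print("Missing files: " + ", ".join(missing))
    else:
        print("All expected MIMIC-III files are present.")
```

Run it from the directory that contains `TAME`, or pass the repository path as `root`.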
Ubuntu 16.04, Python 2.7
Install PyTorch 1.3.0
- Download the DACMI dataset and put the data in TAME/data/DACMI/.
- Preprocess the DACMI data.
cd code/preprocessing
python generate_value_distribution.py --dataset DACMI
- Download the MIMIC-III dataset and put the data in TAME/data/MIMIC/initial_mimiciii/.
- Generate the pivoted files (pivoted_lab.csv, pivoted_vital.csv, pivoted_sofa.csv) with MIT-LCP/mimic-code, and put them in TAME/data/MIMIC/initial_mimiciii/.
- The SQL for generating the pivoted files can be found in the MIT-LCP/mimic-code repository (https://github.com/MIT-LCP/mimic-code).
- Preprocess MIMIC-III data.
cd code/preprocessing
python preprocess_mimic_data.py --dataset MIMIC
python generate_sepsis_variables.py --dataset MIMIC
python generate_value_distribution.py --dataset MIMIC
- Train TAME model.
cd code/TAME
python main.py --dataset DACMI
cd code/TAME
python main.py --dataset MIMIC
- Generate the imputation results.
cd code/TAME
python main.py --dataset DACMI --phase test --resume ../../data/DACMI/models/best.ckpt
cd code/TAME
python main.py --dataset MIMIC --phase test --resume ../../data/MIMIC/models/best.ckpt
Imputation results on the DACMI dataset.
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Method | PCL | PK | PLCO2 | PNA | HCT | HGB | MCV | PLT | WBC | RDW | PBUN | PCRE | PGLU | Mean
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Mean | 0.295 | 0.277 | 0.301 | 0.293 | 0.287 | 0.292 | 0.309 | 0.319 | 0.299 | 0.318 | 0.313 | 0.306 | 0.282 | 0.299
KNN | 0.220 | 0.249 | 0.241 | 0.228 | 0.220 | 0.221 | 0.267 | 0.248 | 0.251 | 0.253 | 0.238 | 0.244 | 0.264 | 0.242
3DMICE | 0.200 | 0.263 | 0.231 | 0.214 | 0.150 | 0.149 | 0.229 | 0.256 | 0.246 | 0.185 | 0.234 | 0.277 | 0.224 | 0.220
T-LGBM | 0.135 | 0.226 | 0.179 | 0.156 | 0.100 | 0.092 | 0.229 | 0.158 | 0.199 | 0.202 | 0.134 | 0.183 | 0.240 | 0.172
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
BRNN | 0.155 | 0.230 | 0.196 | 0.174 | 0.090 | 0.087 | 0.245 | 0.175 | 0.211 | 0.208 | 0.154 | 0.210 | 0.252 | 0.184
CATSI | 0.174 | 0.243 | 0.203 | 0.196 | 0.144 | 0.135 | 0.253 | 0.186 | 0.227 | 0.213 | 0.157 | 0.206 | 0.260 | 0.200
DETROIT | 0.138 | 0.219 | 0.172 | 0.155 | 0.093 | 0.087 | 0.234 | 0.152 | 0.199 | 0.201 | 0.137 | 0.181 | 0.262 | 0.172
BRITS | 0.142 | 0.208 | 0.176 | 0.154 | 0.121 | 0.115 | 0.244 | 0.165 | 0.206 | 0.218 | 0.167 | 0.192 | 0.268 | 0.183
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
TAME-T | 0.102 | 0.187 | 0.144 | 0.128 | 0.079 | 0.074 | 0.227 | 0.144 | 0.199 | 0.211 | 0.128 | 0.201 | 0.228 | 0.158
TAME-V | 0.121 | 0.193 | 0.164 | 0.145 | 0.081 | 0.076 | 0.238 | 0.168 | 0.203 | 0.201 | 0.141 | 0.210 | 0.222 | 0.166
TAME | 0.100 | 0.179 | 0.155 | 0.125 | 0.073 | 0.077 | 0.218 | 0.136 | 0.198 | 0.180 | 0.121 | 0.185 | 0.221 | 0.151
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Imputation results on the MIMIC-III dataset (all 27 variables).
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Model | aniongap |bicarbonate | creatinine | chloride | glucose | hemoglobin | lactate | platelet | ptt | inr | pt | sodium | bun | wbc | spo2 | C-reactive | heartrate | hematocrit | sysbp | tempc | diasbp | gcs | resprate | bands | meanbp | Magnesium |urineoutput | Mean
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Mean | 0.29 | 0.24 | 0.25 | 0.22 | 0.30 | 0.27 | 0.42 | 0.27 | 0.46 | 0.32 | 0.41 | 0.24 | 0.23 | 0.26 | 0.32 | 1.37 | 0.40 | 0.31 | 0.31 | 0.38 | 0.32 | 0.37 | 0.31 | 0.86 | 0.22 | 0.42 | 0.34 | 0.37
KNN | 0.28 | 0.22 | 0.22 | 0.22 | 0.30 | 0.25 | 0.44 | 0.26 | 0.38 | 0.31 | 0.29 | 0.24 | 0.22 | 0.25 | 0.30 | 1.31 | 0.37 | 0.28 | 0.24 | 0.37 | 0.24 | 0.38 | 0.27 | 0.80 | 0.17 | 0.41 | 0.33 | 0.34
3DMICE | 0.22 | 0.19 | 0.22 | 0.18 | 0.27 | 0.18 | 0.42 | 0.25 | 0.40 | 0.25 | 0.29 | 0.20 | 0.22 | 0.25 | 0.27 | 1.20 | 0.34 | 0.28 | 0.24 | 0.36 | 0.20 | 0.33 | 0.27 | 0.79 | 0.15 | 0.38 | 0.30 | 0.32
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
BRNN | 0.15 | 0.17 | 0.20 | 0.13 | 0.29 | 0.12 | 0.40 | 0.20 | 0.41 | 0.16 | 0.24 | 0.18 | 0.17 | 0.26 | 0.25 | 1.28 | 0.31 | 0.25 | 0.18 | 0.26 | 0.17 | 0.23 | 0.24 | 0.96 | 0.13 | 0.36 | 0.27 | 0.30
CATSI | 0.12 | 0.12 | 0.22 | 0.13 | 0.29 | 0.14 | 0.41 | 0.22 | 0.42 | 0.20 | 0.25 | 0.18 | 0.20 | 0.23 | 0.25 | 1.13 | 0.34 | 0.25 | 0.18 | 0.24 | 0.16 | 0.22 | 0.24 | 0.85 | 0.13 | 0.33 | 0.27 | 0.29
DETROIT | 0.11 | 0.09 | 0.28 | 0.09 | 0.27 | 0.13 | 0.38 | 0.22 | 0.46 | 0.17 | 0.24 | 0.10 | 0.17 | 0.22 | 0.26 | 1.10 | 0.31 | 0.26 | 0.18 | 0.24 | 0.16 | 0.23 | 0.24 | 0.78 | 0.13 | 0.33 | 0.25 | 0.27
BRITS | 0.12 | 0.08 | 0.23 | 0.12 | 0.27 | 0.12 | 0.39 | 0.20 | 0.41 | 0.18 | 0.24 | 0.16 | 0.20 | 0.20 | 0.26 | 1.22 | 0.32 | 0.23 | 0.19 | 0.23 | 0.15 | 0.20 | 0.24 | 0.84 | 0.13 | 0.34 | 0.26 | 0.28
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
TAME-T | 0.13 | 0.11 | 0.24 | 0.10 | 0.25 | 0.11 | 0.34 | 0.19 | 0.36 | 0.21 | 0.24 | 0.11 | 0.17 | 0.20 | 0.24 | 1.29 | 0.24 | 0.17 | 0.15 | 0.25 | 0.13 | 0.16 | 0.23 | 0.71 | 0.12 | 0.30 | 0.24 | 0.26
TAME-V | 0.16 | 0.13 | 0.23 | 0.12 | 0.26 | 0.11 | 0.36 | 0.20 | 0.38 | 0.19 | 0.22 | 0.14 | 0.17 | 0.20 | 0.23 | 1.29 | 0.23 | 0.18 | 0.13 | 0.21 | 0.12 | 0.15 | 0.22 | 0.70 | 0.11 | 0.31 | 0.24 | 0.26
TAME-M | 0.13 | 0.11 | 0.24 | 0.10 | 0.25 | 0.11 | 0.34 | 0.19 | 0.36 | 0.21 | 0.24 | 0.11 | 0.17 | 0.23 | 0.24 | 1.33 | 0.24 | 0.20 | 0.14 | 0.25 | 0.14 | 0.16 | 0.24 | 0.73 | 0.14 | 0.32 | 0.25 | 0.27
TAME | 0.11 | 0.09 | 0.19 | 0.08 | 0.26 | 0.09 | 0.35 | 0.18 | 0.38 | 0.15 | 0.20 | 0.10 | 0.14 | 0.21 | 0.22 | 1.16 | 0.23 | 0.19 | 0.13 | 0.24 | 0.12 | 0.16 | 0.20 | 0.73 | 0.12 | 0.31 | 0.23 | 0.25
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
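To make the "Mean" baseline row concrete, the sketch below fills each missing entry of a single variable with the mean of its observed values, then scores the result only at the positions that were missing. This is an illustration under assumptions: the README does not describe the baseline's exact implementation or name the error metric used in the tables, so plain RMSE stands in here, and `mean_impute` and `masked_rmse` are hypothetical helper names.

```python
import math


def mean_impute(values):
    """Replace None entries with the mean of the observed entries."""
    observed = [v for v in values if v is not None]
    if not observed:
        return list(values)  # nothing observed: leave the column unchanged
    mean = sum(observed) / float(len(observed))
    return [mean if v is None else v for v in values]


def masked_rmse(truth, imputed, with_missing):
    """RMSE over the positions that were missing (None) in `with_missing`.

    RMSE is an assumed stand-in; the tables may use a different metric.
    """
    errors = [(t - p) ** 2
              for t, p, m in zip(truth, imputed, with_missing)
              if m is None]
    if not errors:
        return 0.0
    return math.sqrt(sum(errors) / float(len(errors)))


# Toy values for one variable, invented for illustration only.
ground_truth = [1.0, 2.0, 3.0, 4.0]
with_missing = [1.0, None, 3.0, None]
imputed = mean_impute(with_missing)
error = masked_rmse(ground_truth, imputed, with_missing)
```

Scoring only the masked positions mirrors the split between the `train_groundtruth` and `train_with_missing` directories: observed values are copied through unchanged, so they carry no information about imputation quality.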
Compute temporal patient similarity with DTW (on MIMIC-III dataset).
cd code/DTW
python main.py
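For readers unfamiliar with DTW, the following is a minimal sketch of the classic dynamic-programming formulation applied to two patients' imputed sequences. It is not the code in `code/DTW`: the function names are hypothetical, the per-step cost is assumed to be Euclidean distance, and no windowing or normalization is applied.

```python
import math


def step_distance(x, y):
    """Euclidean distance between two time steps (equal-length vectors)."""
    return math.sqrt(sum((xi - yi) ** 2 for xi, yi in zip(x, y)))


def dtw_distance(a, b):
    """DTW distance between two sequences of time steps.

    Each sequence is a list of time steps; each time step is a list of
    variable values. The sequences may have different lengths.
    """
    n, m = len(a), len(b)
    inf = float("inf")
    # cost[i][j] = best alignment cost of a[:i] with b[:j]
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = step_distance(a[i - 1], b[j - 1])
            cost[i][j] = d + min(cost[i - 1][j],      # a advances, b repeats
                                 cost[i][j - 1],      # b advances, a repeats
                                 cost[i - 1][j - 1])  # both advance
    return cost[n][m]


def pairwise_dtw(sequences):
    """Symmetric matrix of DTW distances between all pairs of sequences."""
    k = len(sequences)
    dist = [[0.0] * k for _ in range(k)]
    for i in range(k):
        for j in range(i + 1, k):
            dist[i][j] = dist[j][i] = dtw_distance(sequences[i], sequences[j])
    return dist
```

Because DTW lets one sequence repeat a time step while the other advances, two patients whose measurements follow the same trajectory at different speeds receive a small distance. The quadratic cost per pair is the main reason this step can be slow on a large cohort.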
Cluster patients with weighted k-means (on MIMIC-III dataset).
cd code/wkmeans
python main.py
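As a rough illustration of weighted k-means, the sketch below runs Lloyd's algorithm with a weight per sample, so that heavier samples pull their cluster center more strongly. This README does not spell out how the weights enter the method, so the per-sample weighting, the plain feature-vector input, the deterministic initialization, and the function name `weighted_kmeans` are all assumptions and may differ from `code/wkmeans`.

```python
def weighted_kmeans(points, weights, k, n_iter=20):
    """Lloyd's algorithm with per-sample weights.

    points  : list of equal-length feature vectors
    weights : one non-negative weight per point (assumed weighting scheme)
    k       : number of clusters
    Returns (labels, centers).
    """
    # Deterministic initialization for the sketch: the first k points.
    centers = [list(p) for p in points[:k]]
    labels = [0] * len(points)
    for _ in range(n_iter):
        # Assignment step: nearest center by squared Euclidean distance.
        for idx, p in enumerate(points):
            dists = [sum((pi - ci) ** 2 for pi, ci in zip(p, c))
                     for c in centers]
            labels[idx] = dists.index(min(dists))
        # Update step: each center becomes the weighted mean of its members.
        for c in range(k):
            members = [i for i, lab in enumerate(labels) if lab == c]
            total = float(sum(weights[i] for i in members))
            if total == 0.0:
                continue  # empty or zero-weight cluster: keep the old center
            dim = len(centers[c])
            centers[c] = [sum(weights[i] * points[i][d] for i in members) / total
                          for d in range(dim)]
    return labels, centers


# Toy one-dimensional data, invented for illustration only.
toy_points = [[0.0], [10.0], [1.0], [11.0]]
toy_weights = [1.0, 1.0, 3.0, 1.0]
toy_labels, toy_centers = weighted_kmeans(toy_points, toy_weights, k=2)
```

In the toy data, the third point carries three times the weight of the first, so the first cluster's center lands at 0.75 rather than at the unweighted midpoint 0.5.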
We ran the clustering with different values of K and computed the mean p-value across the variables. As shown in the figure, K=4 is the best choice because it is the elbow point.
ICU mortality rates in the four phenotypes.
The variables' severity heatmap across the four phenotypes.