-
Notifications
You must be signed in to change notification settings - Fork 0
/
extract.py
62 lines (48 loc) · 2.06 KB
/
extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import logging
import os
import pandas as pd
import numpy as np
import psycopg2
import psycopg2.extras
# Module-wide logger. getLogger() is called without a name, so this is the
# root logger rather than a logger scoped to this module.
logger = logging.getLogger()
# String identifiers, presumably naming the two MIMIC dataset releases this
# project can work with — confirm against the callers.
MIMIC_2 = 'mimic2'
MIMIC_3 = 'mimic3'
def mimic3_diagnosis_icd(conn):
    """Pull the ICD9 diagnosis codes attached to MIMIC3 admissions.

    Joins DIAGNOSES_ICD with D_ICD_DIAGNOSES on ``icd9_code`` and delegates
    execution to ``_pull_from_db`` with the output name ``admission_codes.csv``.

    Args:
        conn: database connection handed through to ``_pull_from_db``.

    Returns:
        Whatever ``_pull_from_db`` returns for this query.
    """
    logger.info("Pulling ICD9 Diagnosis codes for MIMIC3 Admissions")

    # Labels follow the select-list order; ``seq_num`` is exposed as ``sequence``.
    output_columns = [
        'hadm_id',
        'subject_id',
        'sequence',
        'icd9_code',
        'short_title',
        'long_title',
    ]
    # Assembled from adjacent literals so the SQL text (which is also logged)
    # stays exactly the same, including the leading and trailing newline.
    sql = (
        "\n"
        "select d_icd.hadm_id, d_icd.subject_id, d_icd.seq_num,\n"
        "icd.icd9_code, icd.short_title, icd.long_title\n"
        "from DIAGNOSES_ICD as d_icd, D_ICD_DIAGNOSES as icd\n"
        "where d_icd.icd9_code = icd.icd9_code\n"
    )
    return _pull_from_db(conn, output_columns, sql, 'admission_codes.csv')
def mimic3_note_adult_only(conn):
    """Pull the age at admission for every admission with age above 18.

    Age is computed in SQL as the difference between the calendar year of
    ``admissions.admittime`` and the calendar year of ``patients.dob``; only
    rows where that difference is strictly greater than 18 are kept.

    Args:
        conn: database connection handed through to ``_pull_from_db``.

    Returns:
        Whatever ``_pull_from_db`` returns: a single ``age`` column.
    """
    query = (
        "\n"
        "select * from (\n"
        "select EXTRACT(YEAR from a.admittime) - EXTRACT(YEAR from p.dob) AS age "
        "from admissions as a, patients as p where a.subject_id = p.subject_id\n"
        ") as ages\n"
        "where ages.age > 18;\n"
    )
    # The previous call passed an extra positional argument (MIMIC_3) that
    # ``_pull_from_db`` does not accept, so it always raised TypeError. The
    # query yields exactly one column, ``age``, so label it as such instead
    # of the literal '*'.
    # NOTE(review): despite the function name, no note text is selected, and
    # 'notes.csv' is the same output file that mimic3_notes writes — confirm
    # the intended query and filename before relying on this output.
    return _pull_from_db(conn, ['age'], query, 'notes.csv')
def mimic3_notes(conn):
    """Pull every NOTEEVENTS row whose category is 'Discharge summary'.

    The selected fields double as the column labels passed to
    ``_pull_from_db``, which is given the output name ``notes.csv``.

    Args:
        conn: database connection handed through to ``_pull_from_db``.

    Returns:
        Whatever ``_pull_from_db`` returns for this query.
    """
    logger.info("Pulling all Discharge Summaries for MIMIC3")
    note_fields = [
        'subject_id',
        'hadm_id',
        'chartdate',
        'category',
        'description',
        'text',
    ]
    select_list = ', '.join(note_fields)
    sql = f"select {select_list} from NOTEEVENTS where category='Discharge summary'"
    return _pull_from_db(conn, note_fields, sql, 'notes.csv')
def _pull_from_db(conn, columns, query, filename):
    """Run ``query`` on ``conn``, write the rows to ``data/<filename>`` and return them.

    Args:
        conn: connection whose ``cursor()`` accepts a ``cursor_factory``
            keyword (a psycopg2 connection).
        columns: labels for the resulting DataFrame, one per selected column,
            in select-list order.
        query: SQL text, executed as given.
        filename: name of the CSV file created inside the ``data/`` directory.

    Returns:
        pandas.DataFrame holding the fetched rows under ``columns``; an empty
        frame with those columns when the query returns no rows.
    """
    cursor = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
    try:
        logger.info(f"Executing Query:\n{query}")
        cursor.execute(query)
        rows = cursor.fetchall()
    finally:
        # Release the cursor even when execute/fetchall raises.
        cursor.close()

    if rows:
        # NOTE(review): np.array picks one dtype for the whole result, so
        # mixed numeric/text rows may come back as strings. Kept as-is so the
        # returned frame and CSV stay the same for existing callers.
        data = pd.DataFrame(np.array(rows), columns=columns)
    else:
        # np.array([]) is one-dimensional, which cannot be matched against
        # more than one column label; build the empty frame directly.
        data = pd.DataFrame(columns=columns)

    os.makedirs('data/', exist_ok=True)
    data.to_csv(f'data/{filename}', index=False)
    logger.info(f'Written {filename} to working dir .')
    return data
def append_addendums(df):
    """Collapse all rows sharing a ``hadm_id`` into a single row.

    For each ``hadm_id`` group the first ``subject_id`` and the first
    ``chartdate`` are kept, and the ``text`` values are concatenated in row
    order, separated by a blank line. Columns other than these are dropped.

    Args:
        df: DataFrame with ``hadm_id``, ``subject_id``, ``chartdate`` and
            ``text`` columns.

    Returns:
        DataFrame with columns ``hadm_id``, ``subject_id``, ``chartdate`` and
        ``text``, one row per ``hadm_id``.
    """
    separator = '\n\n'

    def _first(values):
        # First entry of the group, in the group's row order.
        return values.iloc[0]

    def _joined(values):
        return separator.join(values)

    per_column = {
        'subject_id': _first,
        'chartdate': _first,
        'text': _joined,
    }
    by_admission = df.groupby(by=['hadm_id'], group_keys=False)
    merged = by_admission.agg(per_column)
    # hadm_id is the index after aggregation; turn it back into a column.
    return merged.reset_index()