-
Notifications
You must be signed in to change notification settings - Fork 0
/
streamlit.py
172 lines (151 loc) · 6.13 KB
/
streamlit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# streamlit test, version1
import pandas as pd
import numpy as np
from pickle import load
import streamlit as st
from PIL import Image
import numpy as np
import pandas as pd
from xgboost import XGBClassifier
from matplotlib import pyplot as plt
from numpy import sqrt
from numpy import argmax
import joblib
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score
import sklearn.metrics as metrics
from sklearn.metrics import plot_roc_curve
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import confusion_matrix
from sklearn.utils import resample
from pickle import load
import shap
# Artifacts loaded once at import time; the relative paths resolve against
# the current working directory.
# NOTE(review): joblib.load unpickles these files - only deploy trusted
# artifacts next to this app.
model60, scaler60, expl60 = (
    joblib.load('ua60_model'),       # XGBoost model for the eGFR60 classifier
    joblib.load('scaler60.pkl'),     # scaler for the 'age', 'he_uph', 'he_usg' columns
    joblib.load('explainer60.pkl'),  # explainer exposing .shap_values(), used by shap()
)

# Images displayed by main(): 'non.jpg' is the reference shown for an
# abnormal result with low urine protein, 'pro.jpg' the reference for an
# abnormal result with high urine protein, and 'normal.jpg' /
# 'abnormal.jpg' are the result images themselves.
im_non, im_pro, im_normal, im_abnormal = (
    Image.open(path)
    for path in ('non.jpg', 'pro.jpg', 'normal.jpg', 'abnormal.jpg')
)
# custom def : SHAP contributions for a single (already encoded) input row
def shap(
    sample_case,
    scaler=None,
    explainer=None,
):
    """Return per-feature SHAP contributions for the first row of *sample_case*.

    NOTE: this function shadows the ``shap`` package imported at the top of
    the file; the name is kept because ``main`` calls it.

    Parameters
    ----------
    sample_case : pandas.DataFrame
        Encoded input (see ``data_mapping``) containing at least the columns
        'male', 'he_usg', 'he_uph', 'he_ubld', 'he_uglu', 'he_upro', 'age'.
        Extra columns are ignored. The frame itself is not modified.
    scaler : object with ``transform``, optional
        Applied to the 'age', 'he_uph', 'he_usg' columns. ``None`` (the
        default) means the module-level ``scaler60``.
    explainer : object with ``shap_values``, optional
        Called on the first scaled row. ``None`` (the default) means the
        module-level ``expl60``.

    Returns
    -------
    pandas.DataFrame
        A single column 'shap_value(probability)' indexed by human-readable
        feature names, in the same order as the model's feature columns.
    """
    # Resolve defaults at call time rather than binding the globals when
    # the function is defined.
    if scaler is None:
        scaler = scaler60
    if explainer is None:
        explainer = expl60

    # Columns passed to the scaler, in this order.
    std_cols = ['age', 'he_uph', 'he_usg']
    # Model feature columns, in the order the explainer receives them.
    feature_cols = ['male', 'he_usg', 'he_uph', 'he_ubld', 'he_uglu', 'he_upro', 'age']

    # Explicit copy: the scaled assignment below then works on an
    # independent frame and never on (a view of) the caller's data.
    features = sample_case.loc[:, feature_cols].copy()
    features[std_cols] = scaler.transform(features[std_cols])

    # Use the explainer that was supplied, not a hard-wired global.
    values = explainer.shap_values(features.iloc[0])

    return pd.DataFrame(
        {'shap_value(probability)': values},
        index=['sex', 'urine specific gravity', 'urine pH', 'urine blood',
               'urine glucose', 'urine protein', 'age'],
    )
# custom def : standardization and prediction
def model_prediction(
    sample_case,
    scaler=None,
    model=None,
):
    """Standardize the urinalysis (UA) features and return the class-1 probability.

    UA5 type model. Feature columns read from *sample_case*:
        male    = sex (encoded, see ``data_mapping``)
        he_usg  = Urine specific gravity
        he_uph  = Urine pH
        he_ubld = Urine blood
        he_uglu = Urine glucose
        he_upro = Urine protein
        age     = age

    Parameters
    ----------
    sample_case : pandas.DataFrame
        Encoded input containing the columns above. Extra columns are
        ignored, and the frame itself is not modified.
    scaler : object with ``transform``, optional
        Applied to 'age', 'he_uph', 'he_usg'. ``None`` (the default) means
        the module-level ``scaler60`` (loaded from 'scaler60.pkl').
    model : object with ``predict_proba``, optional
        ``None`` (the default) means the module-level ``model60`` (loaded
        from 'ua60_model').

    Returns
    -------
    ``np.float64`` applied to column 1 of ``model.predict_proba``, i.e. one
    probability per row of *sample_case*.
    """
    # Reuse the artifacts already loaded at module level instead of calling
    # joblib.load again for these defaults.
    if scaler is None:
        scaler = scaler60
    if model is None:
        model = model60

    # Columns passed to the scaler, in this order.
    std_cols = ['age', 'he_uph', 'he_usg']
    # Model feature columns, in the order the model receives them.
    feature_cols = ['male', 'he_usg', 'he_uph', 'he_ubld', 'he_uglu', 'he_upro', 'age']

    # Explicit copy so the scaled assignment never targets the caller's frame.
    features = sample_case.loc[:, feature_cols].copy()
    features[std_cols] = scaler.transform(features[std_cols])

    # Column 1 of predict_proba is the positive-class probability.
    prob = model.predict_proba(features)[:, 1]
    return np.float64(prob)
def data_mapping(df):
    """Encode the categorical sidebar selections in *df* as integers.

    'male' maps 'female' -> 0 and 'male' -> 1. The dipstick columns
    'he_ubld', 'he_upro' and 'he_uglu' map the grades
    '-', '+/-', '1+', '2+', '3+', '4+' to 0..5. A value missing from a
    mapping becomes NaN (``Series.map`` semantics).

    The columns are replaced in place, and the same DataFrame object is
    returned.
    """
    # Ordinal dipstick scale: position in this tuple is the encoded value.
    grade_scale = ("-", "+/-", "1+", "2+", "3+", "4+")
    dipstick_codes = {grade: code for code, grade in enumerate(grade_scale)}

    # Attribute-style get/set is kept deliberately, so absent columns
    # behave exactly as with `df.col` access.
    setattr(df, 'male', getattr(df, 'male').map({'female': 0, 'male': 1}))
    for column in ('he_ubld', 'he_upro', 'he_uglu'):
        setattr(df, column, getattr(df, column).map(dipstick_codes))
    return df
def _collect_inputs():
    """Render the sidebar widgets and return the selections as a one-row DataFrame."""
    # NOTE(review): the 4th positional slider argument is the initial value,
    # so the age slider starts at 1 - confirm this is intended.
    age = st.sidebar.slider("age", 0, 100, 1)
    male = st.sidebar.selectbox("sex", ("female", "male"))
    he_usg = st.sidebar.selectbox("urine specific gravity", (1.000, 1.005, 1.010, 1.015, 1.020, 1.025, 1.030))
    he_uph = st.sidebar.selectbox("urine pH", (5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0))
    he_ubld = st.sidebar.selectbox("urine blood", ("-", "+/-", "1+", "2+", "3+", "4+"))
    he_uglu = st.sidebar.selectbox("urine glucose", ("-", "+/-", "1+", "2+", "3+", "4+"))
    he_upro = st.sidebar.selectbox("urine protein", ("-", "+/-", "1+", "2+", "3+", "4+"))
    features = {"male": male,
                "he_usg": he_usg,
                "he_uph": he_uph,
                "he_ubld": he_ubld,
                "he_uglu": he_uglu,
                "he_upro": he_upro,
                "age": age}
    return pd.DataFrame(features, index=[0])


def _render_prediction(sample_case):
    """Encode, score and explain *sample_case*, then display the outcome."""
    sample_case_map = data_mapping(sample_case)
    result = model_prediction(sample_case_map)
    # The page always scores a single row; reduce whatever model_prediction
    # returns (scalar or length-1 array) to a plain float for display and
    # for the threshold comparison.
    prob = float(np.ravel(result)[0])
    shap_bar = shap(sample_case_map)
    st.success('probability : {}'.format(prob))

    # The decision threshold depends on urine protein: encoded grade <= 1
    # ("-" or "+/-") uses 0.44, anything higher uses 0.77.
    low_protein = sample_case_map['he_upro'].item() <= 1
    threshold = 0.44 if low_protein else 0.77
    st.success("threshold : {}".format(threshold))

    if prob > threshold:
        st.image(im_abnormal)
        # NOTE(review): the two abnormal cases place the reference image on
        # different sides of the SHAP chart; order kept from the original.
        if low_protein:
            st.image(im_non, caption='reference')
            st.bar_chart(data=shap_bar)
        else:
            st.bar_chart(data=shap_bar)
            st.image(im_pro, caption='reference')
    else:
        st.image(im_normal)
        st.bar_chart(data=shap_bar)


def main():
    """Render the eGFR60 classifier page; on "Predict", show the model result."""
    # giving the webpage a title
    st.title("check your function of kidney")
    # Static HTML banner (no user input is interpolated), hence the
    # unsafe_allow_html flag.
    html_temp = """
    <div style ="background-color:grey;padding:13px">
    <h1 style ="color:black;text-align:center;">eGFR60 classifier ML App </h1>
    </div>
    """
    st.markdown(html_temp, unsafe_allow_html = True)

    # Sidebar sliders / select boxes for the model inputs.
    sample_case = _collect_inputs()

    if st.button("Predict"):
        _render_prediction(sample_case)


if __name__ == '__main__':
    main()