Add SHAP explainer for LSTM in fusion model
sophmrtn committed Jul 3, 2024
1 parent 83267b9 commit 07e1bed
Showing 1 changed file with 94 additions and 72 deletions.
166 changes: 94 additions & 72 deletions src/evaluate.py
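
Note on the pattern this commit introduces: fusion models are now routed through shap.DeepExplainer, applied to each LSTM time-series embedding branch, instead of exiting with "not implemented". A minimal, self-contained sketch of that pattern follows; the TimeSeriesEncoder module, tensor shapes, and random data are illustrative assumptions, not the repository's actual MMModel.embed_timeseries branches, and SHAP's deep explainer can be fragile on recurrent layers.

# Illustrative sketch only (not from this repository): applying shap.DeepExplainer
# to a toy LSTM encoder over time-series input, mirroring the fusion branch below.
import numpy as np
import shap
import torch
import torch.nn as nn


class TimeSeriesEncoder(nn.Module):
    """Hypothetical stand-in for one embedding branch: (batch, time, feats) -> (batch, classes)."""

    def __init__(self, n_features, hidden_size=16, n_classes=2):
        super().__init__()
        self.lstm = nn.LSTM(n_features, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        out, _ = self.lstm(x)         # (batch, time, hidden)
        return self.head(out[:, -1])  # logits from the last timepoint


torch.manual_seed(0)
n_samples, n_timepoints, n_features = 32, 24, 6
x_test = torch.randn(n_samples, n_timepoints, n_features)
model = TimeSeriesEncoder(n_features).eval()

# Background batch plus explained batch, with the additivity check relaxed as in
# the commit; some shap versions warn about unrecognised recurrent modules, in
# which case shap.GradientExplainer is a common drop-in fallback.
explainer = shap.DeepExplainer(model, x_test)
shap_values = explainer.shap_values(x_test, check_additivity=False)

# Normalise to an array of shape (samples, time, features, classes); older shap
# versions return a list with one array per class.
sv = np.stack(shap_values, axis=-1) if isinstance(shap_values, list) else np.asarray(shap_values)

# Average over classes and inspect one timepoint, as evaluate.py does with
# shap_values.mean(axis=3)[:, t, :].
t = 0
shap.summary_plot(sv.mean(axis=-1)[:, t, :], features=x_test[:, t, :].numpy())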
@@ -1,6 +1,5 @@
 import argparse
 import os
-import sys
 import time
 
 import matplotlib.pyplot as plt
@@ -10,7 +9,6 @@
 import toml
 from datasets import CollateTimeSeries, MIMIC4Dataset
 from fairlearn.metrics import (
-    MetricFrame,
     count,
     demographic_parity_ratio,
     equalized_odds_difference,
@@ -108,7 +106,7 @@
         y_test = np.array(y_test)
 
         model = load_pickle(args.model_path)
-        print("Evaluating on validation data...")
+        print("Evaluating on test data...")
         y_hat = model.predict(x_test)
         prob = model.predict_proba(x_test)[:, 1]
@@ -119,8 +117,11 @@
             collate_fn=CollateTimeSeries(),
         )
 
-        model = MMModel.load_from_checkpoint(checkpoint_path=args.model_path)
-        print("Evaluating on validation data...")
+        st_first = False if "mag-ts" in args.model_path else True
+        model = MMModel.load_from_checkpoint(
+            checkpoint_path=args.model_path, st_first=st_first
+        )
+        print("Evaluating on test data...")
 
         trainer = Trainer(accelerator="gpu")
         output = trainer.predict(model, dataloaders=test_dataloader)
@@ -138,69 +139,90 @@
 
     ### Explain
     if args.explain:
-        if model_type != "rf":
-            print("Explanations for fusion models have not been implemented.")
-            sys.exit()
-
-        # Visualise important features
-        features = test_set.get_feature_list()
-        importances = model.feature_importances_
-        indices = np.argsort(importances)
-
-        plt.figure(figsize=(20, 10))
-        plt.title("Feature Importances")
-        plt.barh(range(len(indices)), importances[indices], color="b", align="center")
-        plt.yticks(range(len(indices)), [features[i] for i in indices])
-        plt.xlabel("Relative Importance")
-        plt.show()
-
-        # Create shap plot
-        explainer = shap.TreeExplainer(model)
-
-        # Plot single waterfall plot
-
-        # Correct classification (TP)
-        tps = np.argwhere(np.logical_and(y_hat == 1, y_test == 1))
-
-        if len(tps) > 0:
-            tp = tps[0][0]
-
-            plt.figure(figsize=(12, 4))
-            plt.title(
-                f"Truth: {int(y_test[tp])}, Predict: {int(model.predict(x_test[tp].reshape(1,-1)))}, Prob: {round(model.predict_proba(x_test[tp].reshape(1,-1))[:,1][0], 2)}"
-            )
-            shap.bar_plot(
-                explainer(x_test[tp])[:, 1].values,
-                feature_names=features,
-                max_display=20,
-            )
-            plt.show()
-
-        # Incorrect (FN)
-        fns = np.argwhere(np.logical_and(y_hat == 0, y_test == 1))
-
-        if len(fns) > 0:
-            fn = fns[0][0]
-
-            plt.figure(figsize=(12, 4))
-            plt.title(
-                f"Truth: {int(y_test[fn])}, Predict: {int(model.predict(x_test[fn].reshape(1,-1)))}, Prob: {round(model.predict_proba(x_test[fn].reshape(1,-1))[:,1][0], 2)}"
-            )
-            shap.bar_plot(
-                explainer(x_test[fn])[:, 1].values,
-                feature_names=features,
-                max_display=20,
-            )
-            plt.show()
-
-        # Plot summary over all test
-        start = time.time()
-        shap_values = explainer(x_test)
-        print(time.time() - start)
-
-        plt.figure()
-        shap.summary_plot(shap_values[:, :, 1], feature_names=features, max_display=20)
-        plt.show()
+        if model_type == "rf":
+            # Visualise important features
+            features = test_set.get_feature_list()
+            importances = model.feature_importances_
+            indices = np.argsort(importances)
+
+            plt.figure(figsize=(20, 10))
+            plt.title("Feature Importances")
+            plt.barh(
+                range(len(indices)), importances[indices], color="b", align="center"
+            )
+            plt.yticks(range(len(indices)), [features[i] for i in indices])
+            plt.xlabel("Relative Importance")
+            plt.show()
+
+            # Create shap plot
+            explainer = shap.TreeExplainer(model, x_test)
+
+            # Plot single waterfall plot
+            # Correct classification (TP)
+            tps = np.argwhere(np.logical_and(y_hat == 1, y_test == 1))
+
+            if len(tps) > 0:
+                tp = tps[0][0]
+
+                plt.figure(figsize=(12, 4))
+                plt.title(
+                    f"Truth: {int(y_test[tp])}, Predict: {int(y_hat[tp])}, Prob: {round(prob[tp], 2)}"
+                )
+                shap.bar_plot(
+                    explainer(x_test[tp])[:, 1].values,
+                    feature_names=features,
+                    max_display=20,
+                )
+                plt.show()
+
+            # Incorrect (FN)
+            fns = np.argwhere(np.logical_and(y_hat == 0, y_test == 1))
+
+            if len(fns) > 0:
+                fn = fns[0][0]
+
+                plt.figure(figsize=(12, 4))
+                plt.title(
+                    f"Truth: {int(y_test[fn])}, Predict: {int(y_hat[fn])}, Prob: {round(prob[fn], 2)}"
+                )
+                shap.bar_plot(
+                    explainer(x_test[fn])[:, 1].values,
+                    feature_names=features,
+                    max_display=20,
+                )
+                plt.show()
+
+            # Plot summary over all test subjects
+            start = time.time()
+            shap_values = explainer(x_test, check_additivity=False)
+            print(time.time() - start)
+
+            plt.figure()
+            shap.summary_plot(
+                shap_values[:, :, 1], feature_names=features, max_display=20
+            )
+            plt.show()
+
+        elif model_type == "fusion":
+            # get first collated batch (fixed size and num of samples)
+            batch = next(iter(test_dataloader))
+
+            for i in range(2):
+                features = test_set.get_feature_list(f"dynamic_{i}")
+
+                x_test = batch[2][i]
+                explainer = shap.DeepExplainer(model.embed_timeseries[i], x_test)
+
+                # Plot summary over all test subjects for single timepoint (t=0)
+                shap_values = explainer.shap_values(x_test, check_additivity=False)
+
+                plt.figure()
+                shap.summary_plot(
+                    shap_values.mean(axis=3)[:, 5, :],
+                    feature_names=features,
+                    features=x_test[:, 5, :],
+                )
+                plt.show()
 
     if args.fairness:
         #### Fairness evaluation using fairlearn API
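
For reference, the random-forest explain path above follows the standard TreeExplainer workflow. A minimal sketch on synthetic data is below; the dataset, feature names, and display limits are made-up placeholders, not values from this project.

# Synthetic stand-in for the "rf" branch above; dataset and feature names are made up.
import shap
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=400, n_features=8, random_state=0)
feature_names = [f"feat_{i}" for i in range(X.shape[1])]
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_train, y_train)

# Passing data as background (TreeExplainer(model, x_test) in the diff) switches
# the explainer to interventional feature perturbation.
explainer = shap.TreeExplainer(model, X_test)
explanation = explainer(X_test, check_additivity=False)

# Per-sample bar plot for the positive class, then a global summary plot,
# mirroring shap.bar_plot(...) and shap.summary_plot(...) above.
shap.bar_plot(explanation[0][:, 1].values, feature_names=feature_names, max_display=8)
shap.summary_plot(explanation[:, :, 1], feature_names=feature_names, max_display=8)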
@@ -241,22 +263,22 @@
             "count": count,
         }
 
-        metric_frame = MetricFrame(
-            metrics=metrics,
-            y_true=y_test,
-            y_pred=y_hat,
-            sensitive_features=metadata,
-        )
+        # metric_frame = MetricFrame(
+        #     metrics=metrics,
+        #     y_true=y_test,
+        #     y_pred=y_hat,
+        #     sensitive_features=metadata,
+        # )
 
-        metric_frame.by_group.plot.bar(
-            subplots=True,
-            layout=[3, 2],
-            colormap="Pastel2",
-            legend=False,
-            figsize=[12, 8],
-            title="Fairness evaluation",
-            xlabel=pf,
-        )
+        # metric_frame.by_group.plot.bar(
+        #     subplots=True,
+        #     layout=[3, 2],
+        #     colormap="Pastel2",
+        #     legend=False,
+        #     figsize=[12, 8],
+        #     title="Fairness evaluation",
+        #     xlabel=pf,
+        # )
 
         # fairness
         eor = equalized_odds_ratio(y_test, y_hat, sensitive_features=metadata)
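For the fairness hunk above, a toy illustration of the fairlearn calls it relies on, using synthetic labels and a made-up sensitive feature in place of the script's patient metadata:

# Toy illustration of the fairlearn metrics used in evaluate.py; data is synthetic.
import numpy as np
from fairlearn.metrics import MetricFrame, count, equalized_odds_ratio, selection_rate
from sklearn.metrics import accuracy_score

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)
y_pred = rng.integers(0, 2, size=200)
sex = rng.choice(["F", "M"], size=200)  # stand-in sensitive feature

# Per-group breakdown, as in the (currently commented-out) MetricFrame block.
frame = MetricFrame(
    metrics={"accuracy": accuracy_score, "selection_rate": selection_rate, "count": count},
    y_true=y_true,
    y_pred=y_pred,
    sensitive_features=sex,
)
print(frame.by_group)

# Scalar fairness summary that the script keeps active.
eor = equalized_odds_ratio(y_true, y_pred, sensitive_features=sex)
print("equalized odds ratio:", eor)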
