Skip to content

Commit

Permalink
CHANGED Stein BNN example to use prediction with noise (i.e. y_bnn ->…
Browse files Browse the repository at this point in the history
… y) (#1900)
  • Loading branch information
OlaRonning authored Nov 8, 2024
1 parent 76c6d96 commit 7078260
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions examples/stein_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,15 @@ def model(x, y=None, hidden_dim=50, sub_size=100):

# precision prior on observations
prec_obs = sample("prec_obs", Gamma(1.0, 0.1))

with plate("data", x.shape[0], subsample_size=sub_size, dim=-1):
batch_x = subsample(x, event_dim=1)
if y is not None:
batch_y = subsample(y, event_dim=0)
else:
batch_y = y

loc_y = deterministic("y_pred", nn.relu(batch_x @ w1 + b1) @ w2 + b2)
loc_y = deterministic("y_bnn", nn.relu(batch_x @ w1 + b1) @ w2 + b2)

sample(
"y",
Expand Down Expand Up @@ -156,28 +157,27 @@ def main(args):
data.xte, xtr_mean, xtr_std
) # Use train data statistics when accessing generalization.
n = xte.shape[0]
y_preds = pred(pred_key, xte, sub_size=n, hidden_dim=args.hidden_dim)["y_pred"]

mean_pred = y_preds.mean(0)
rmse = jnp.sqrt(jnp.mean((mean_pred - data.yte) ** 2))
pred_y = pred(pred_key, xte, sub_size=n, hidden_dim=args.hidden_dim)["y"]
rmse = jnp.sqrt(jnp.mean((pred_y.mean(0) - data.yte) ** 2))

print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}")
print(rf"RMSE: {rmse:.2f}")

# compute mean prediction and confidence interval around median
percentiles = jnp.percentile(y_preds, jnp.array([5.0, 95.0]), axis=0)
# Compute mean prediction and confidence interval around median
percentiles = jnp.percentile(pred_y, jnp.array([5.0, 95.0]), axis=0)

# make plots
# Make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
ran = np.arange(mean_pred.shape[0])
ran = np.arange(pred_y.shape[1])
ax.add_collection(
LineCollection(
zip(zip(ran, percentiles[0]), zip(ran, percentiles[1])), colors="lightblue"
)
)
ax.plot(data.yte, "kx", label="y true")
ax.plot(mean_pred, "ko", label="y pred")
ax.set(xlabel="example", ylabel="y", title="Mean predictions with 90% CI")
ax.plot(pred_y.mean(0), "ko", label="y pred")
ax.set(xlabel="example", ylabel="y", title="Mean Predictions with 90% CI")
ax.legend()
fig.savefig("stein_bnn.pdf")

Expand Down

0 comments on commit 7078260

Please sign in to comment.