Skip to content

Commit

Permalink
Skip printing summary if empty.
Browse files Browse the repository at this point in the history
  • Loading branch information
hessammehr committed Nov 20, 2024
1 parent e0d02e5 commit 1aa7c5d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
4 changes: 4 additions & 0 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def summary(

summary_dict = {}
for name, value in samples.items():
if len(value) == 0:
continue
value = device_get(value)
value_flat = np.reshape(value, (-1,) + value.shape[2:])
mean = value_flat.mean(axis=0)
Expand Down Expand Up @@ -307,6 +309,8 @@ def print_summary(
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
}
summary_dict = summary(samples, prob, group_by_chain=True)
if not summary_dict:
return

row_names = {
k: k + "[" + ",".join(map(lambda x: str(x - 1), v.shape[2:])) + "]"
Expand Down
11 changes: 10 additions & 1 deletion test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,4 +1222,13 @@ def model2():
for model, shape in shapes.items():
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples)
mcmc.run(random.PRNGKey(0))
assert mcmc.get_samples()["x"].shape == (num_samples,) + shape
assert mcmc.get_samples()["x"].shape == (num_samples,) + shape

def test_empty_summary():
def model():
pass

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0))

mcmc.print_summary()

0 comments on commit 1aa7c5d

Please sign in to comment.