Skip to content

Commit

Permalink
Merge pull request #519 from lnccbrown/bambi-014-fix-plotting
Browse files Browse the repository at this point in the history
Fix plotting
  • Loading branch information
digicosmos86 authored Jul 31, 2024
2 parents 8b721e0 + cf2e083 commit e52ac63
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/hssm/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _xarray_to_df(
We make the following assumptions:
1. The inference data always has a posterior predictive group with a `rt,response`
variable.
2. This variable always has four dimensions: `chain`, `draw`, `rt,response_obs`,
2. This variable always has four dimensions: `chain`, `draw`, `__obs__`,
and `rt,response_dim`.
Parameters
Expand All @@ -46,10 +46,10 @@ def _xarray_to_df(

# Convert the posterior samples to a dataframe
stacked = (
sampled_posterior.stack({"obs": ["chain", "draw", f"{response_str}_obs"]})
sampled_posterior.stack({"obs": ["chain", "draw", "__obs__"]})
.transpose()
.to_pandas()
.rename_axis(index={f"{response_str}_obs": "obs_n"})
.rename_axis(index={"__obs__": "obs_n"})
.sort_index(axis=0, level=["chain", "draw", "obs_n"])
)

Expand Down Expand Up @@ -141,7 +141,7 @@ def _get_plotting_df(
posterior.insert(0, "observed", "predicted")
return posterior

if extra_dims and idata_posterior[f"{response_str}_obs"].size != data.shape[0]:
if extra_dims and idata_posterior["__obs__"].size != data.shape[0]:
raise ValueError(
"The number of observations in the data and the number of posterior "
+ "samples are not equal."
Expand Down
Binary file modified tests/fixtures/cavanagh_idata.nc
Binary file not shown.
Binary file modified tests/fixtures/cavanagh_idata_pps.nc
Binary file not shown.

0 comments on commit e52ac63

Please sign in to comment.