diff --git a/docs/changelog.md b/docs/changelog.md
index 10311a75..87ff63f2 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -2,6 +2,14 @@
## 0.2.x
+### 0.2.3
+
+This is a maintenance release of HSSM, mainly to add a version constraint on `bambi` in light of the many breaking changes that version `0.14.0` introduces. It also improves compatibility with `PyMC>=5.15` and incorporates minor bug fixes:
+
+1. We incorporated a temporary fix for model graphing, which broke under `PyMC>=5.15`.
+2. We replaced the `ndim_supp` and `ndims_params` definitions in `SSMRandomVariable` with the `signature` attribute, following their deprecation in `PyMC>=5.16`.
+3. We fixed a bug where calling `model.sample()` again did not return a new trace (see the sketch below).
+
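+For example, repeated sampling now returns a fresh trace each time. A minimal sketch (the dataset and sampler arguments here are illustrative, not prescriptive):
+
+```python
+import hssm
+
+data = hssm.load_data("cavanagh_theta")  # any dataset with `rt` and `response` columns
+model = hssm.HSSM(data=data, model="ddm")
+
+trace_1 = model.sample(draws=500)
+trace_2 = model.sample(draws=500)  # previously, the first trace was returned again
+assert trace_1 is not trace_2
+```
+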
### 0.2.2
HSSM is now on Conda! We now recommend installing HSSM through `conda install -c conda-forge hssm`. For advanced users, we also support installing the GPU version of JAX through `pip install hssm[cuda12]`.
diff --git a/docs/overrides/main.html b/docs/overrides/main.html
index 69fa1145..fffdca2e 100644
--- a/docs/overrides/main.html
+++ b/docs/overrides/main.html
@@ -5,7 +5,7 @@
Navigate the site here!
- v0.2.2 is released!
+ v0.2.3 is released!
{% include ".icons/material/head-question.svg" %}
diff --git a/pyproject.toml b/pyproject.toml
index 61e0afe0..f1340fd3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "HSSM"
-version = "0.2.2"
+version = "0.2.3"
description = "Bayesian inference for hierarchical sequential sampling models."
authors = [
"Alexander Fengler ",
@@ -16,12 +16,12 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]
[tool.poetry.dependencies]
python = ">=3.10,<3.12"
-pymc = ">=5.14.0,<5.15.0"
+pymc = "~5.16.0"
arviz = "^0.18.0"
onnx = "^1.16.0"
ssm-simulators = "^0.7.2"
huggingface-hub = "^0.23.0"
-bambi = "^0.13.0"
+bambi = "~0.13.0"
numpyro = "^0.15.0"
hddm-wfpt = "^0.1.4"
seaborn = "^0.13.2"
diff --git a/src/hssm/distribution_utils/dist.py b/src/hssm/distribution_utils/dist.py
index d3ed7740..35aec58a 100644
--- a/src/hssm/distribution_utils/dist.py
+++ b/src/hssm/distribution_utils/dist.py
@@ -126,7 +126,9 @@ def ensure_positive_ndt(data, logp, list_params, dist_params):
def make_ssm_rv(
- model_name: str, list_params: list[str], lapse: bmb.Prior | None = None
+ model_name: str,
+ list_params: list[str],
+ lapse: bmb.Prior | None = None,
) -> Type[RandomVariable]:
"""Build a RandomVariable Op according to the list of parameters.
@@ -163,23 +165,18 @@ class SSMRandomVariable(RandomVariable):
"""SSM random variable."""
name: str = "SSM_RV"
- # to get around the support checking in PyMC that would result in error
- ndim_supp: int = 1
-
- ndims_params: list[int] = [0 for _ in list_params]
+ # New in PyMC 5.16+: instead of `ndim_supp` and `ndims_params`, we use
+ # `signature` to define the signature of the random variable. The string to
+ # the left of the `->` sign describes the input signature: `()` for each
+ # parameter, meaning each parameter is a scalar. The string to the right
+ # describes the output signature: `(2)` means the random variable is a
+ # length-2 array.
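+ # For example, with four scalar parameters the resulting signature string
+ # would be "(),(),(),()->(2)".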
+ signature: str = f"{','.join(['()']*len(list_params))}->(2)"
dtype: str = "floatX"
_print_name: tuple[str, str] = ("SSM", "\\operatorname{SSM}")
_list_params = list_params
_lapse = lapse
- # PyTensor, as of version 2.12, enforces a check to ensure that
- # at least one parameter has the same ndims as the support.
- # This overrides that check and ensures that the dimension checks are correct.
- # For more information, see this issue
- # https://github.com/lnccbrown/HSSM/issues/36
- def _supp_shape_from_params(*args, **kwargs):
- return (2,)
-
# pylint: disable=arguments-renamed,bad-option-value,W0221
# NOTE: `rng` is now a np.random.Generator instead of RandomState,
# since the latter is deprecated in NumPy
@@ -282,6 +279,8 @@ def rng_fn(
# All parameters are scalars
theta = np.stack(arg_arrays)
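+ # New in PyMC 5.16+: scalar parameters may arrive with an extra trailing
+ # dimension added based on the signature; squeeze it out if present.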
+ if theta.ndim > 1:
+ theta = theta.squeeze(axis=-1)
n_samples = size
else:
# Preprocess all parameters, reshape them into a matrix of dimension
@@ -400,6 +399,7 @@ def make_distribution(
A pymc.Distribution that uses the log-likelihood function.
"""
random_variable = make_ssm_rv(rv, list_params, lapse) if isinstance(rv, str) else rv
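+ # Normalize `extra_fields` to a list so downstream code can iterate over it.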
+ extra_fields = [] if extra_fields is None else extra_fields
if lapse is not None:
if list_params[-1] != "p_outlier":
@@ -423,29 +423,18 @@ class SSMDistribution(pm.Distribution):
# NOTE: rv_op is an INSTANCE of RandomVariable
rv_op = random_variable()
- params = list_params
- _extra_fields = extra_fields
+ _params = list_params
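+ # Note: extra fields are no longer appended to `dist_params` here; they are
+ # now passed to `make_distribution` as arrays extracted from the data (see
+ # `_make_model_distribution` in `hssm.py`).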
@classmethod
def dist(cls, **kwargs): # pylint: disable=arguments-renamed
dist_params = [
- pt.as_tensor_variable(pm.floatX(kwargs[param])) for param in cls.params
+ pt.as_tensor_variable(pm.floatX(kwargs[param])) for param in cls._params
]
- if cls._extra_fields:
- dist_params += [pm.floatX(field) for field in cls._extra_fields]
- other_kwargs = {k: v for k, v in kwargs.items() if k not in cls.params}
+ other_kwargs = {k: v for k, v in kwargs.items() if k not in cls._params}
return super().dist(dist_params, **other_kwargs)
def logp(data, *dist_params): # pylint: disable=E0213
# AF-TODO: Apply clipping here
-
- num_params = len(list_params)
- extra_fields = []
-
- if num_params < len(dist_params):
- extra_fields = dist_params[num_params:]
- dist_params = dist_params[:num_params]
-
if list_params[-1] == "p_outlier":
p_outlier = dist_params[-1]
dist_params = dist_params[:-1]
@@ -657,7 +646,12 @@ def likelihood_callable(data, *dist_params):
"""Compute the log-likelihoood of the model."""
# Assuming the first column of the data is always rt
data = pt.as_tensor_variable(data)
- dist_params = [pt.as_tensor_variable(param) for param in dist_params]
+
+ # New in PyMC 5.16+: PyMC uses the signature of the RandomVariable to determine
+ # the dimensions of the inputs to the likelihood function. It automatically
+ # adds one additional dimension to each scalar input variable, which we need
+ # to squeeze out.
+ dist_params = [pt.squeeze(param) for param in dist_params]
n_missing = pt.sum(pt.eq(data[:, 0], -999.0)).astype(int)
if n_missing == 0:
@@ -666,7 +660,7 @@ def likelihood_callable(data, *dist_params):
observed_data = data[n_missing:, :]
dist_params_observed = [
- param if param.ndim == 0 else param[n_missing:] for param in dist_params
+ param[n_missing:] if param.ndim >= 1 else param for param in dist_params
]
if has_deadline:
@@ -675,7 +669,7 @@ def likelihood_callable(data, *dist_params):
logp_observed = callable(observed_data, *dist_params_observed)
dist_params_missing = [
- param if param.ndim == 0 else param[:n_missing] for param in dist_params
+ param[:n_missing] if param.ndim >= 1 else param for param in dist_params
]
if params_only:
diff --git a/src/hssm/distribution_utils/onnx/onnx.py b/src/hssm/distribution_utils/onnx/onnx.py
index efd59b59..6bacfe07 100644
--- a/src/hssm/distribution_utils/onnx/onnx.py
+++ b/src/hssm/distribution_utils/onnx/onnx.py
@@ -81,7 +81,10 @@ def logp(*inputs) -> jnp.ndarray:
else:
data = inputs[0]
dist_params = inputs[1:]
- input_vector = jnp.concatenate((jnp.array(dist_params), data))
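+ # New in PyMC 5.16+: each parameter may arrive with an extra dimension;
+ # squeeze each one to a scalar before assembling the input vector.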
+ param_vector = jnp.array([inp.squeeze() for inp in dist_params])
+ if param_vector.shape[-1] == 1:
+ param_vector = param_vector.squeeze(axis=-1)
+ input_vector = jnp.concatenate((param_vector, data))
return interpret_onnx(loaded_model.graph, input_vector)[0].squeeze()
diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py
index 6c51756c..5934cbe6 100644
--- a/src/hssm/hssm.py
+++ b/src/hssm/hssm.py
@@ -1505,7 +1505,9 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
bounds=self.bounds,
lapse=self.lapse,
extra_fields=(
- None if not self.extra_fields else deepcopy(self.extra_fields)
+ None
+ if not self.extra_fields
+ else [deepcopy(self.data[field].values) for field in self.extra_fields]
),
)
diff --git a/src/hssm/utils.py b/src/hssm/utils.py
index b6c15e54..95183581 100644
--- a/src/hssm/utils.py
+++ b/src/hssm/utils.py
@@ -20,7 +20,7 @@
import xarray as xr
from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm
from huggingface_hub import hf_hub_download
-from pymc.model_graph import ModelGraph
+from pymc.model_graph import DEFAULT_NODE_FORMATTERS, ModelGraph
from pytensor import function
from .param import Param
@@ -207,7 +207,12 @@ def make_graph(
# must be preceded by 'cluster' to get a box around it
with graph.subgraph(name="cluster" + plate_label) as sub:
for var_name in all_var_names:
- self._make_node(var_name, sub, formatting=formatting)
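+ # Newer PyMC versions (>=5.15) expect a `node_formatters` argument here;
+ # pass PyMC's default formatters.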
+ self._make_node(
+ var_name,
+ sub,
+ formatting=formatting,
+ node_formatters=DEFAULT_NODE_FORMATTERS,
+ )
# plate label goes bottom right
sub.attr(
label=plate_label,
@@ -218,7 +223,12 @@ def make_graph(
else:
for var_name in all_var_names:
- self._make_node(var_name, graph, formatting=formatting)
+ self._make_node(
+ var_name,
+ graph,
+ formatting=formatting,
+ node_formatters=DEFAULT_NODE_FORMATTERS,
+ )
if self.parent.is_regression:
# Insert the parent parameter that's not included in the graph
@@ -244,7 +254,7 @@ def make_graph(
if (
self.parent.is_regression
and parent.startswith(f"{self.parent.name}_")
- and child == self.get_parent_names
+ and child == response_str
):
# Modify the edges so that they point to the
# parent parameter