
Commit

Merge pull request #484 from lnccbrown/479-allow-pymc-=-515
Fix compatibility with PyMC>=5.16
digicosmos86 authored Jul 12, 2024
2 parents 3558223 + 815ad34 commit 2b8fadb
Showing 7 changed files with 57 additions and 40 deletions.
8 changes: 8 additions & 0 deletions docs/changelog.md
@@ -2,6 +2,14 @@

## 0.2.x

+### 0.2.3
+
+This is a maintenance release of HSSM, mainly to add a version constraint on `bambi` in light of the many breaking changes that version `0.14.0` introduces. This version also improved compatibility with `PyMC>=5.15` and incorporated minor bug fixes:
+
+1. We incorporated a temporary fix for model graphing, which broke after `PyMC>=5.15`.
+2. We replaced the deprecated `ndim_supp` and `ndims_params` definitions in `SSMRandomVariable` with a `signature`, as required by `PyMC>=5.16`.
+3. We fixed a bug that prevented new traces from being returned when `model.sample()` is called again.

### 0.2.2

HSSM is now on Conda! We now recommend installing HSSM through `conda install -c conda-forge hssm`. For advanced users, we also support installing the GPU version of JAX through `pip install hssm[cuda12]`.
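A minimal, illustrative sketch of the re-sampling fix described in item 3 of the `0.2.3` entry above; the dataset name and sampler settings are assumptions, not part of this commit:

```python
import hssm

# Hypothetical setup -- any HSSM model exercises the same code path.
data = hssm.load_data("cavanagh_theta")
model = hssm.HSSM(data=data, model="ddm")

trace_1 = model.sample(draws=500, chains=2)
# Before this fix, calling sample() again could return the cached first
# trace; with the fix, each call returns a fresh InferenceData object.
trace_2 = model.sample(draws=500, chains=2)
assert trace_2 is not trace_1
```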
2 changes: 1 addition & 1 deletion docs/overrides/main.html
@@ -5,7 +5,7 @@
</span>
Navigate the site here!
</span>
-<span class="right-margin"> v0.2.2 is released! </span>
+<span class="right-margin"> v0.2.3 is released! </span>
<span>
<span class="twemoji">
{% include ".icons/material/head-question.svg" %}
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "HSSM"
-version = "0.2.2"
+version = "0.2.3"
description = "Bayesian inference for hierarchical sequential sampling models."
authors = [
"Alexander Fengler <[email protected]>",
@@ -16,12 +16,12 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
-pymc = ">=5.14.0,<5.15.0"
+pymc = "~5.16.0"
arviz = "^0.18.0"
onnx = "^1.16.0"
ssm-simulators = "^0.7.2"
huggingface-hub = "^0.23.0"
-bambi = "^0.13.0"
+bambi = "~0.13.0"
numpyro = "^0.15.0"
hddm-wfpt = "^0.1.4"
seaborn = "^0.13.2"
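For reference, Poetry's tilde constraint `~5.16.0` expands to `>=5.16.0,<5.17.0`, and `~0.13.0` to `>=0.13.0,<0.14.0`, which keeps `bambi` below the breaking `0.14` series mentioned in the changelog. A quick, illustrative check of those ranges with the `packaging` library (not part of the commit):

```python
from packaging.specifiers import SpecifierSet

# Poetry's "~5.16.0" is equivalent to this PEP 440 range:
pymc_range = SpecifierSet(">=5.16.0,<5.17.0")
print("5.16.2" in pymc_range)  # True  -- patch releases are allowed
print("5.17.0" in pymc_range)  # False -- the next minor is excluded

# Likewise, "~0.13.0" holds bambi below 0.14:
bambi_range = SpecifierSet(">=0.13.0,<0.14.0")
print("0.14.0" in bambi_range)  # False
```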
54 changes: 24 additions & 30 deletions src/hssm/distribution_utils/dist.py
@@ -126,7 +126,9 @@ def ensure_positive_ndt(data, logp, list_params, dist_params):


def make_ssm_rv(
-model_name: str, list_params: list[str], lapse: bmb.Prior | None = None
+model_name: str,
+list_params: list[str],
+lapse: bmb.Prior | None = None,
) -> Type[RandomVariable]:
"""Build a RandomVariable Op according to the list of parameters.
@@ -163,23 +165,18 @@ class SSMRandomVariable(RandomVariable):
"""SSM random variable."""

name: str = "SSM_RV"
-# to get around the support checking in PyMC that would result in error
-ndim_supp: int = 1

-ndims_params: list[int] = [0 for _ in list_params]
+# New in PyMC 5.16+: instead of using `ndim_supp`, we use `signature` to define
+# the signature of the random variable. The string to the left of the `->` sign
+# describes the input signature, which is `()` for each parameter, meaning each
+# parameter is a scalar. The string to the right of the `->` sign describes the
+# output signature, which is `(2)`, meaning the random variable is a length-2
+# array.
+signature: str = f"{','.join(['()']*len(list_params))}->(2)"
dtype: str = "floatX"
_print_name: tuple[str, str] = ("SSM", "\\operatorname{SSM}")
_list_params = list_params
_lapse = lapse

-# PyTensor, as of version 2.12, enforces a check to ensure that
-# at least one parameter has the same ndims as the support.
-# This overrides that check and ensures that the dimension checks are correct.
-# For more information, see this issue
-# https://github.com/lnccbrown/HSSM/issues/36
-def _supp_shape_from_params(*args, **kwargs):
-    return (2,)

# pylint: disable=arguments-renamed,bad-option-value,W0221
# NOTE: `rng` now is a np.random.Generator instead of RandomState
# since the latter is now deprecated from numpy
@@ -282,6 +279,8 @@ def rng_fn(
# All parameters are scalars

theta = np.stack(arg_arrays)
+if theta.ndim > 1:
+    theta = theta.squeeze(axis=-1)
n_samples = size
else:
# Preprocess all parameters, reshape them into a matrix of dimension
@@ -400,6 +399,7 @@ def make_distribution(
A pymc.Distribution that uses the log-likelihood function.
"""
random_variable = make_ssm_rv(rv, list_params, lapse) if isinstance(rv, str) else rv
+extra_fields = [] if extra_fields is None else extra_fields

if lapse is not None:
if list_params[-1] != "p_outlier":
@@ -423,29 +423,18 @@ class SSMDistribution(pm.Distribution):

# NOTE: rv_op is an INSTANCE of RandomVariable
rv_op = random_variable()
-params = list_params
_extra_fields = extra_fields
+_params = list_params

@classmethod
def dist(cls, **kwargs): # pylint: disable=arguments-renamed
dist_params = [
-pt.as_tensor_variable(pm.floatX(kwargs[param])) for param in cls.params
+pt.as_tensor_variable(pm.floatX(kwargs[param])) for param in cls._params
]
if cls._extra_fields:
dist_params += [pm.floatX(field) for field in cls._extra_fields]
-other_kwargs = {k: v for k, v in kwargs.items() if k not in cls.params}
+other_kwargs = {k: v for k, v in kwargs.items() if k not in cls._params}
return super().dist(dist_params, **other_kwargs)

def logp(data, *dist_params): # pylint: disable=E0213
# AF-TODO: Apply clipping here

num_params = len(list_params)
extra_fields = []

if num_params < len(dist_params):
extra_fields = dist_params[num_params:]
dist_params = dist_params[:num_params]

if list_params[-1] == "p_outlier":
p_outlier = dist_params[-1]
dist_params = dist_params[:-1]
@@ -657,7 +646,12 @@ def likelihood_callable(data, *dist_params):
"""Compute the log-likelihood of the model."""
# Assuming the first column of the data is always rt
data = pt.as_tensor_variable(data)
-dist_params = [pt.as_tensor_variable(param) for param in dist_params]

+# New in PyMC 5.16+: PyMC uses the signature of the RandomVariable to determine
+# the dimensions of the inputs to the likelihood function. It automatically adds
+# one additional dimension to our input variable if it is a scalar. We need to
+# squeeze this dimension out.
+dist_params = [pt.squeeze(param) for param in dist_params]

n_missing = pt.sum(pt.eq(data[:, 0], -999.0)).astype(int)
if n_missing == 0:
@@ -666,7 +660,7 @@
observed_data = data[n_missing:, :]

dist_params_observed = [
-param if param.ndim == 0 else param[n_missing:] for param in dist_params
+param[n_missing:] if param.ndim >= 1 else param for param in dist_params
]

if has_deadline:
@@ -675,7 +669,7 @@
logp_observed = callable(observed_data, *dist_params_observed)

dist_params_missing = [
-param if param.ndim == 0 else param[:n_missing] for param in dist_params
+param[:n_missing] if param.ndim >= 1 else param for param in dist_params
]

if params_only:
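Two of the changes above are easier to see with concrete shapes. First, the new `signature` string: for a hypothetical four-parameter model, the expression in `SSMRandomVariable` produces four scalar input slots and one length-2 output slot:

```python
# Assumed parameter list, for illustration only.
list_params = ["v", "a", "z", "t"]
signature = f"{','.join(['()'] * len(list_params))}->(2)"
print(signature)  # (),(),(),()->(2)
```

Second, the `theta.squeeze` guard in `rng_fn` (and the matching `pt.squeeze` in `likelihood_callable`): under `PyMC>=5.16`, a scalar parameter can arrive with a trailing length-1 axis, which `np.stack` would otherwise carry into `theta`. A sketch with assumed shapes:

```python
import numpy as np

# Three scalar parameters, each arriving as a length-1 array.
arg_arrays = [np.asarray([0.5]), np.asarray([1.2]), np.asarray([0.3])]
theta = np.stack(arg_arrays)        # shape (3, 1) -- one stray axis
if theta.ndim > 1:
    theta = theta.squeeze(axis=-1)  # back to shape (3,)
print(theta.shape)                  # (3,)
```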
5 changes: 4 additions & 1 deletion src/hssm/distribution_utils/onnx/onnx.py
@@ -81,7 +81,10 @@ def logp(*inputs) -> jnp.ndarray:
else:
data = inputs[0]
dist_params = inputs[1:]
-input_vector = jnp.concatenate((jnp.array(dist_params), data))
+param_vector = jnp.array([inp.squeeze() for inp in dist_params])
+if param_vector.shape[-1] == 1:
+    param_vector = param_vector.squeeze(axis=-1)
+input_vector = jnp.concatenate((param_vector, data))

return interpret_onnx(loaded_model.graph, input_vector)[0].squeeze()

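The same stray-dimension issue appears on the JAX side: each element of `dist_params` may carry a length-1 axis, so the parameters are squeezed before being concatenated with the data into the network's input vector. A sketch with assumed shapes (not part of the commit):

```python
import jax.numpy as jnp

dist_params = [jnp.array([0.5]), jnp.array([1.2]), jnp.array([0.3])]
data = jnp.array([0.8, 1.0])  # e.g. one (rt, response) row

param_vector = jnp.array([p.squeeze() for p in dist_params])  # shape (3,)
input_vector = jnp.concatenate((param_vector, data))          # shape (5,)
print(input_vector.shape)
```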
4 changes: 3 additions & 1 deletion src/hssm/hssm.py
@@ -1505,7 +1505,9 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
bounds=self.bounds,
lapse=self.lapse,
extra_fields=(
-None if not self.extra_fields else deepcopy(self.extra_fields)
+None
+if not self.extra_fields
+else [deepcopy(self.data[field].values) for field in self.extra_fields]
),
)

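The change above hands the distribution the deep-copied data columns themselves rather than the list of field names. A minimal sketch of the difference, assuming a hypothetical DataFrame with one extra field `theta`:

```python
import pandas as pd
from copy import deepcopy

data = pd.DataFrame({"rt": [0.8, 1.1], "response": [1, -1], "theta": [0.2, 0.5]})
extra_fields = ["theta"]

old_value = deepcopy(extra_fields)                            # ['theta'] -- names only
new_value = [deepcopy(data[f].values) for f in extra_fields]  # [array([0.2, 0.5])]
```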
18 changes: 14 additions & 4 deletions src/hssm/utils.py
@@ -20,7 +20,7 @@
import xarray as xr
from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm
from huggingface_hub import hf_hub_download
-from pymc.model_graph import ModelGraph
+from pymc.model_graph import DEFAULT_NODE_FORMATTERS, ModelGraph
from pytensor import function

from .param import Param
@@ -207,7 +207,12 @@ def make_graph(
# must be preceded by 'cluster' to get a box around it
with graph.subgraph(name="cluster" + plate_label) as sub:
for var_name in all_var_names:
-self._make_node(var_name, sub, formatting=formatting)
+self._make_node(
+    var_name,
+    sub,
+    formatting=formatting,
+    node_formatters=DEFAULT_NODE_FORMATTERS,
+)
# plate label goes bottom right
sub.attr(
label=plate_label,
@@ -218,7 +223,12 @@

else:
for var_name in all_var_names:
-self._make_node(var_name, graph, formatting=formatting)
+self._make_node(
+    var_name,
+    graph,
+    formatting=formatting,
+    node_formatters=DEFAULT_NODE_FORMATTERS,
+)

if self.parent.is_regression:
# Insert the parent parameter that's not included in the graph
@@ -244,7 +254,7 @@
if (
self.parent.is_regression
and parent.startswith(f"{self.parent.name}_")
-and child == self.get_parent_names
+and child == response_str
):
# Modify the edges so that they point to the
# parent parameter
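Two notes on this file. Passing `node_formatters=DEFAULT_NODE_FORMATTERS` matches the refactored `pymc.model_graph` API; this is the temporary graphing fix mentioned in the changelog. The edge fix at the bottom replaces a comparison against the bound method `self.get_parent_names` (never equal to a string, so the branch never fired) with the response string `response_str`. An illustrative sketch of how the patched path is exercised (model settings assumed):

```python
import hssm

data = hssm.load_data("cavanagh_theta")
model = hssm.HSSM(data=data, model="ddm")
graph = model.graph()  # renders via the patched make_graph
```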
