From a6637e56b9209acea80adf489cfbb5b72dadaa97 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 2 Jul 2024 12:37:14 -0400 Subject: [PATCH 1/8] Set PyMC version to be >=5.16.0,<5.17.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 61e0afe0..8dc830f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"] [tool.poetry.dependencies] python = ">=3.10,<3.12" -pymc = ">=5.14.0,<5.15.0" +pymc = "~5.16.0" arviz = "^0.18.0" onnx = "^1.16.0" ssm-simulators = "^0.7.2" From 17bc74b8e63f68dd7b0526ed8c7d2540c7e3e963 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 2 Jul 2024 12:39:31 -0400 Subject: [PATCH 2/8] Fix one dimension issue in onnx.py --- src/hssm/distribution_utils/onnx/onnx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/hssm/distribution_utils/onnx/onnx.py b/src/hssm/distribution_utils/onnx/onnx.py index efd59b59..568a5bc4 100644 --- a/src/hssm/distribution_utils/onnx/onnx.py +++ b/src/hssm/distribution_utils/onnx/onnx.py @@ -81,7 +81,10 @@ def logp(*inputs) -> jnp.ndarray: else: data = inputs[0] dist_params = inputs[1:] - input_vector = jnp.concatenate((jnp.array(dist_params), data)) + param_vector = jnp.array(dist_params) + if param_vector.shape[-1] == 1: + param_vector = param_vector.squeeze(axis=-1) + input_vector = jnp.concatenate((param_vector, data)) return interpret_onnx(loaded_model.graph, input_vector)[0].squeeze() From fb5485ca698c6b086d249e2dfe383f617cbd3745 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 2 Jul 2024 12:42:11 -0400 Subject: [PATCH 3/8] Use signature argument for RandomVariable class and fix dimension issues --- src/hssm/distribution_utils/dist.py | 32 +++++++---------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/src/hssm/distribution_utils/dist.py b/src/hssm/distribution_utils/dist.py index d3ed7740..befe123d 100644 --- a/src/hssm/distribution_utils/dist.py +++ b/src/hssm/distribution_utils/dist.py @@ -164,22 +164,12 @@ class SSMRandomVariable(RandomVariable): name: str = "SSM_RV" # to get around the support checking in PyMC that would result in error - ndim_supp: int = 1 - - ndims_params: list[int] = [0 for _ in list_params] + signature: str = f"{','.join(['()']*len(list_params))}->(2)" dtype: str = "floatX" _print_name: tuple[str, str] = ("SSM", "\\operatorname{SSM}") _list_params = list_params _lapse = lapse - # PyTensor, as of version 2.12, enforces a check to ensure that - # at least one parameter has the same ndims as the support. - # This overrides that check and ensures that the dimension checks are correct. - # For more information, see this issue - # https://github.com/lnccbrown/HSSM/issues/36 - def _supp_shape_from_params(*args, **kwargs): - return (2,) - # pylint: disable=arguments-renamed,bad-option-value,W0221 # NOTE: `rng` now is a np.random.Generator instead of RandomState # since the latter is now deprecated from numpy @@ -282,6 +272,8 @@ def rng_fn( # All parameters are scalars theta = np.stack(arg_arrays) + if theta.ndim > 1: + theta = theta.squeeze(axis=-1) n_samples = size else: # Preprocess all parameters, reshape them into a matrix of dimension @@ -400,6 +392,7 @@ def make_distribution( A pymc.Distribution that uses the log-likelihood function. """ random_variable = make_ssm_rv(rv, list_params, lapse) if isinstance(rv, str) else rv + extra_fields = [] if extra_fields is None else extra_fields if lapse is not None: if list_params[-1] != "p_outlier": @@ -423,29 +416,18 @@ class SSMDistribution(pm.Distribution): # NOTE: rv_op is an INSTANCE of RandomVariable rv_op = random_variable() - params = list_params - _extra_fields = extra_fields + _params = list_params @classmethod def dist(cls, **kwargs): # pylint: disable=arguments-renamed dist_params = [ - pt.as_tensor_variable(pm.floatX(kwargs[param])) for param in cls.params + pt.as_tensor_variable(pm.floatX(kwargs[param])) for param in cls._params ] - if cls._extra_fields: - dist_params += [pm.floatX(field) for field in cls._extra_fields] - other_kwargs = {k: v for k, v in kwargs.items() if k not in cls.params} + other_kwargs = {k: v for k, v in kwargs.items() if k not in cls._params} return super().dist(dist_params, **other_kwargs) def logp(data, *dist_params): # pylint: disable=E0213 # AF-TODO: Apply clipping here - - num_params = len(list_params) - extra_fields = [] - - if num_params < len(dist_params): - extra_fields = dist_params[num_params:] - dist_params = dist_params[:num_params] - if list_params[-1] == "p_outlier": p_outlier = dist_params[-1] dist_params = dist_params[:-1] From 43abae59c9ad2ce821292ec8d076025e30497904 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 8 Jul 2024 10:58:38 -0400 Subject: [PATCH 4/8] fix dimension issues, add comments --- src/hssm/distribution_utils/dist.py | 22 +++++++++++++++++----- src/hssm/distribution_utils/onnx/onnx.py | 2 +- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/hssm/distribution_utils/dist.py b/src/hssm/distribution_utils/dist.py index befe123d..35aec58a 100644 --- a/src/hssm/distribution_utils/dist.py +++ b/src/hssm/distribution_utils/dist.py @@ -126,7 +126,9 @@ def ensure_positive_ndt(data, logp, list_params, dist_params): def make_ssm_rv( - model_name: str, list_params: list[str], lapse: bmb.Prior | None = None + model_name: str, + list_params: list[str], + lapse: bmb.Prior | None = None, ) -> Type[RandomVariable]: """Build a RandomVariable Op according to the list of parameters. @@ -163,7 +165,12 @@ class SSMRandomVariable(RandomVariable): """SSM random variable.""" name: str = "SSM_RV" - # to get around the support checking in PyMC that would result in error + # New in PyMC 5.16+: instead of using `ndims_supp`, we use `signature` to define + # the signature of the random variable. The string to the left of the `->` sign + # describes the input signature, which is `()` for each parameter, meaning each + # parameter is a scalar. The string to the right of the + # `->` sign describes the output signature, which is `(2)`, which means the + # random variable is a length-2 array. signature: str = f"{','.join(['()']*len(list_params))}->(2)" dtype: str = "floatX" _print_name: tuple[str, str] = ("SSM", "\\operatorname{SSM}") @@ -639,7 +646,12 @@ def likelihood_callable(data, *dist_params): """Compute the log-likelihoood of the model.""" # Assuming the first column of the data is always rt data = pt.as_tensor_variable(data) - dist_params = [pt.as_tensor_variable(param) for param in dist_params] + + # New in PyMC 5.16+: PyMC uses the signature of the RandomVariable to determine + # the dimensions of the inputs to the likelihood function. It automatically adds + # one additional dimension to our input variable if it is a scalar. We need to + # squeeze this dimension out. + dist_params = [pt.squeeze(param) for param in dist_params] n_missing = pt.sum(pt.eq(data[:, 0], -999.0)).astype(int) if n_missing == 0: @@ -648,7 +660,7 @@ def likelihood_callable(data, *dist_params): observed_data = data[n_missing:, :] dist_params_observed = [ - param if param.ndim == 0 else param[n_missing:] for param in dist_params + param[n_missing:] if param.ndim >= 1 else param for param in dist_params ] if has_deadline: @@ -657,7 +669,7 @@ def likelihood_callable(data, *dist_params): logp_observed = callable(observed_data, *dist_params_observed) dist_params_missing = [ - param if param.ndim == 0 else param[:n_missing] for param in dist_params + param[:n_missing] if param.ndim >= 1 else param for param in dist_params ] if params_only: diff --git a/src/hssm/distribution_utils/onnx/onnx.py b/src/hssm/distribution_utils/onnx/onnx.py index 568a5bc4..6bacfe07 100644 --- a/src/hssm/distribution_utils/onnx/onnx.py +++ b/src/hssm/distribution_utils/onnx/onnx.py @@ -81,7 +81,7 @@ def logp(*inputs) -> jnp.ndarray: else: data = inputs[0] dist_params = inputs[1:] - param_vector = jnp.array(dist_params) + param_vector = jnp.array([inp.squeeze() for inp in dist_params]) if param_vector.shape[-1] == 1: param_vector = param_vector.squeeze(axis=-1) input_vector = jnp.concatenate((param_vector, data)) From 694c83a4ae7c3784b02dc2d6601a67f996332e5f Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 10 Jul 2024 08:46:38 -0400 Subject: [PATCH 5/8] fix bambi version for now --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8dc830f4..29aac326 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ arviz = "^0.18.0" onnx = "^1.16.0" ssm-simulators = "^0.7.2" huggingface-hub = "^0.23.0" -bambi = "^0.13.0" +bambi = "~0.13.0" numpyro = "^0.15.0" hddm-wfpt = "^0.1.4" seaborn = "^0.13.2" From aa798c71d0bba72133dde2cc25300b027b0af92f Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 10 Jul 2024 11:09:20 -0400 Subject: [PATCH 6/8] implemented a temp fix to graphing --- src/hssm/utils.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/hssm/utils.py b/src/hssm/utils.py index b6c15e54..95183581 100644 --- a/src/hssm/utils.py +++ b/src/hssm/utils.py @@ -20,7 +20,7 @@ import xarray as xr from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm from huggingface_hub import hf_hub_download -from pymc.model_graph import ModelGraph +from pymc.model_graph import DEFAULT_NODE_FORMATTERS, ModelGraph from pytensor import function from .param import Param @@ -207,7 +207,12 @@ def make_graph( # must be preceded by 'cluster' to get a box around it with graph.subgraph(name="cluster" + plate_label) as sub: for var_name in all_var_names: - self._make_node(var_name, sub, formatting=formatting) + self._make_node( + var_name, + sub, + formatting=formatting, + node_formatters=DEFAULT_NODE_FORMATTERS, + ) # plate label goes bottom right sub.attr( label=plate_label, @@ -218,7 +223,12 @@ def make_graph( else: for var_name in all_var_names: - self._make_node(var_name, graph, formatting=formatting) + self._make_node( + var_name, + graph, + formatting=formatting, + node_formatters=DEFAULT_NODE_FORMATTERS, + ) if self.parent.is_regression: # Insert the parent parameter that's not included in the graph @@ -244,7 +254,7 @@ def make_graph( if ( self.parent.is_regression and parent.startswith(f"{self.parent.name}_") - and child == self.get_parent_names + and child == response_str ): # Modify the edges so that they point to the # parent parameter From dde2e7d496e6b5c7885d8d7a51dad21b471b92e7 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Fri, 12 Jul 2024 09:06:07 -0400 Subject: [PATCH 7/8] getting ready for HSSM 0.2.3 --- docs/changelog.md | 8 ++++++++ pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/changelog.md b/docs/changelog.md index 10311a75..87ff63f2 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,14 @@ ## 0.2.x +### 0.2.3 + +This is a maintenance release of HSSM, mainly to add a version constraint on `bambi` in light of the many breaking changes that version `0.1.4` introduces. This version also improved compatibility with `PyMC>=5.15` and incorporated minor bug fixes: + +1. We incorporated a temporary fix to graphing which broke after `PyMC>=5.15`. +2. We deprecated `ndim` and `ndim_supp` definition in `SSMRandomVariable` in `PyMC>-5.16`. +3. We fixed a bug that prevents new traces from being returned if `model.sample()` is called again. + ### 0.2.2 HSSM is now on Conda! We now recommend installing HSSM through `conda install -c conda-forge hssm`. For advanced users, we also support installing the GPU version of JAX through `pip install hssm[cuda12]`. diff --git a/pyproject.toml b/pyproject.toml index 29aac326..f1340fd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "HSSM" -version = "0.2.2" +version = "0.2.3" description = "Bayesian inference for hierarchical sequential sampling models." authors = [ "Alexander Fengler ", From 815ad34cce3c42e09a0e5291274c4789266b43c2 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Fri, 12 Jul 2024 09:20:25 -0400 Subject: [PATCH 8/8] get ready for version 0.2.3 --- docs/overrides/main.html | 2 +- src/hssm/hssm.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/overrides/main.html b/docs/overrides/main.html index 69fa1145..fffdca2e 100644 --- a/docs/overrides/main.html +++ b/docs/overrides/main.html @@ -5,7 +5,7 @@ Navigate the site here! - v0.2.2 is released! + v0.2.3 is released! {% include ".icons/material/head-question.svg" %} diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 6c51756c..5934cbe6 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -1505,7 +1505,9 @@ def _make_model_distribution(self) -> type[pm.Distribution]: bounds=self.bounds, lapse=self.lapse, extra_fields=( - None if not self.extra_fields else deepcopy(self.extra_fields) + None + if not self.extra_fields + else [deepcopy(self.data[field].values) for field in self.extra_fields] ), )