diff --git a/src/hssm/distribution_utils/onnx/onnx.py b/src/hssm/distribution_utils/onnx/onnx.py
index 6bacfe07..6a9cc3b7 100644
--- a/src/hssm/distribution_utils/onnx/onnx.py
+++ b/src/hssm/distribution_utils/onnx/onnx.py
@@ -57,6 +57,7 @@ def make_jax_logp_funcs_from_onnx(
         )
 
     scalars_only = all(not is_reg for is_reg in params_is_reg)
+    print("scalars only: ", scalars_only)
 
     def logp(*inputs) -> jnp.ndarray:
         """Compute the log-likelihood.
@@ -76,11 +77,14 @@ def logp(*inputs) -> jnp.ndarray:
             The element-wise log-likelihoods.
         """
         # Makes a matrix to feed to the LAN model
+        print("scalars only: ", scalars_only)
+        print("params only: ", params_only)
         if params_only:
             input_vector = jnp.array(inputs)
         else:
             data = inputs[0]
             dist_params = inputs[1:]
+            print([inp.shape for inp in dist_params])
             param_vector = jnp.array([inp.squeeze() for inp in dist_params])
             if param_vector.shape[-1] == 1:
                 param_vector = param_vector.squeeze(axis=-1)
@@ -89,6 +93,7 @@ def logp(*inputs) -> jnp.ndarray:
         return interpret_onnx(loaded_model.graph, input_vector)[0].squeeze()
 
     if params_only and scalars_only:
+        print("passing scalars only case")
         logp_vec = lambda *inputs: logp(*inputs).reshape((1,))
         return jit(logp_vec), jit(grad(logp)), logp_vec
 
diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py
index 13db4ea6..58294c3c 100644
--- a/src/hssm/hssm.py
+++ b/src/hssm/hssm.py
@@ -569,9 +569,11 @@ def sample(
         )
 
         self._inference_obj = self.model.fit(
-            inference_method="mcmc"
-            if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]
-            else sampler,
+            inference_method=(
+                "mcmc"
+                if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]
+                else sampler
+            ),
             init=init,
             include_response_params=include_response_params,
             **kwargs,
@@ -1549,14 +1551,8 @@ def _get_deterministic_var_names(self, idata) -> list[str]:
             if param.is_regression
         ]
 
-        # Handle specific case where parent is not explictly in traces
-        if ("~" + self._parent in var_names) and (
-            self._parent not in idata.posterior.data_vars
-        ):
-            var_names.remove("~" + self._parent)
-
-        if f"{self.response_str}_mean" in idata["posterior"].data_vars:
-            var_names.append(f"~{self.response_str}_mean")
+        if f"{self._parent}_mean" in idata["posterior"].data_vars:
+            var_names.append(f"~{self._parent}_mean")
 
         return var_names
 
diff --git a/src/hssm/param.py b/src/hssm/param.py
index 931a7dac..07062e53 100644
--- a/src/hssm/param.py
+++ b/src/hssm/param.py
@@ -144,7 +144,10 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]):
         """
         self._ensure_not_converted(context="prior")
 
-        if not self.is_regression:
+        # If no regression, or the parameter is the parent and does not have a
+        # formula attached (in which case it still gets a trial wise deterministic)
+        # do nothing
+        if not self.is_regression or (self.is_parent and self.formula is None):
             return
 
         override_priors = {}
@@ -213,7 +216,10 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An
         self._ensure_not_converted(context="prior")
         assert self.name is not None
 
-        if not self.is_regression:
+        # If no regression, or the parameter is the parent and does not have a
+        # formula attached (in which case it still gets a trial wise deterministic)
+        # do nothing
+        if not self.is_regression or (self.is_parent and self.formula is None):
             return
 
         override_priors = {}
@@ -380,7 +386,7 @@ def is_regression(self) -> bool:
         bool
             A boolean that indicates if a regression is specified.
         """
-        return self.formula is not None
+        return self.formula is not None or self._is_parent
 
     @property
     def is_parent(self) -> bool:
diff --git a/tests/slow/test_mcmc.py b/tests/slow/test_mcmc.py
index 2ddc6274..9eceb887 100644
--- a/tests/slow/test_mcmc.py
+++ b/tests/slow/test_mcmc.py
@@ -74,17 +74,17 @@ def opn(fixture_path):
         ("analytical", None, "mcmc", None, True),
         ("analytical", None, "mcmc", "slice", True),
         ("analytical", None, "nuts_numpyro", None, True),
-        ("analytical", None, "nuts_numpyro", "slice", TypeError),
+        ("analytical", None, "nuts_numpyro", "slice", ValueError),
         ("approx_differentiable", "pytensor", None, None, True),  # Defaults should work
         ("approx_differentiable", "pytensor", "mcmc", None, True),
         ("approx_differentiable", "pytensor", "mcmc", "slice", True),
         ("approx_differentiable", "pytensor", "nuts_numpyro", None, True),
-        ("approx_differentiable", "pytensor", "nuts_numpyro", "slice", TypeError),
+        ("approx_differentiable", "pytensor", "nuts_numpyro", "slice", ValueError),
         ("approx_differentiable", "jax", None, None, True),  # Defaults should work
         ("approx_differentiable", "jax", "mcmc", None, True),
         ("approx_differentiable", "jax", "mcmc", "slice", True),
         ("approx_differentiable", "jax", "nuts_numpyro", None, True),
-        ("approx_differentiable", "jax", "nuts_numpyro", "slice", TypeError),
+        ("approx_differentiable", "jax", "nuts_numpyro", "slice", ValueError),
         ("blackbox", None, None, None, True),  # Defaults should work
         ("blackbox", None, "mcmc", None, True),
         ("blackbox", None, "mcmc", "slice", True),
@@ -137,7 +137,9 @@ def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected):
 
     # Only runs once
     if loglik_kind == "analytical" and sampler is None:
-        assert not model._get_deterministic_var_names(model.traces)
+        assert f"~{model._parent}_mean" in model._get_deterministic_var_names(
+            model.traces
+        )
         # test summary:
         summary = model.summary()
         assert summary.shape[0] == 4