Skip to content

Commit

Permalink
Merge pull request #522 from lnccbrown/bambi-014-fix-inference-data
Browse files Browse the repository at this point in the history
Fix InferenceData
  • Loading branch information
digicosmos86 authored Aug 5, 2024
2 parents 1a7cadc + b0301f8 commit 6ad7dc5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 18 deletions.
5 changes: 5 additions & 0 deletions src/hssm/distribution_utils/onnx/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def make_jax_logp_funcs_from_onnx(
)

scalars_only = all(not is_reg for is_reg in params_is_reg)
print("scalars only: ", scalars_only)

def logp(*inputs) -> jnp.ndarray:
"""Compute the log-likelihood.
Expand All @@ -76,11 +77,14 @@ def logp(*inputs) -> jnp.ndarray:
The element-wise log-likelihoods.
"""
# Makes a matrix to feed to the LAN model
print("scalars only: ", scalars_only)
print("params only: ", params_only)
if params_only:
input_vector = jnp.array(inputs)
else:
data = inputs[0]
dist_params = inputs[1:]
print([inp.shape for inp in dist_params])
param_vector = jnp.array([inp.squeeze() for inp in dist_params])
if param_vector.shape[-1] == 1:
param_vector = param_vector.squeeze(axis=-1)
Expand All @@ -89,6 +93,7 @@ def logp(*inputs) -> jnp.ndarray:
return interpret_onnx(loaded_model.graph, input_vector)[0].squeeze()

if params_only and scalars_only:
print("passing scalars only case")
logp_vec = lambda *inputs: logp(*inputs).reshape((1,))
return jit(logp_vec), jit(grad(logp)), logp_vec

Expand Down
18 changes: 7 additions & 11 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,9 +569,11 @@ def sample(
)

self._inference_obj = self.model.fit(
inference_method="mcmc"
if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]
else sampler,
inference_method=(
"mcmc"
if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]
else sampler
),
init=init,
include_response_params=include_response_params,
**kwargs,
Expand Down Expand Up @@ -1549,14 +1551,8 @@ def _get_deterministic_var_names(self, idata) -> list[str]:
if param.is_regression
]

# Handle specific case where parent is not explictly in traces
if ("~" + self._parent in var_names) and (
self._parent not in idata.posterior.data_vars
):
var_names.remove("~" + self._parent)

if f"{self.response_str}_mean" in idata["posterior"].data_vars:
var_names.append(f"~{self.response_str}_mean")
if f"{self._parent}_mean" in idata["posterior"].data_vars:
var_names.append(f"~{self._parent}_mean")

return var_names

Expand Down
12 changes: 9 additions & 3 deletions src/hssm/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]):
"""
self._ensure_not_converted(context="prior")

if not self.is_regression:
# If no regression, or the parameter is the parent and does not have a
# formula attached (in which case it still gets a trial wise deterministic)
# do nothing
if not self.is_regression or (self.is_parent and self.formula is None):
return

override_priors = {}
Expand Down Expand Up @@ -213,7 +216,10 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An
self._ensure_not_converted(context="prior")
assert self.name is not None

if not self.is_regression:
# If no regression, or the parameter is the parent and does not have a
# formula attached (in which case it still gets a trial wise deterministic)
# do nothing
if not self.is_regression or (self.is_parent and self.formula is None):
return

override_priors = {}
Expand Down Expand Up @@ -380,7 +386,7 @@ def is_regression(self) -> bool:
bool
A boolean that indicates if a regression is specified.
"""
return self.formula is not None
return self.formula is not None or self._is_parent

@property
def is_parent(self) -> bool:
Expand Down
10 changes: 6 additions & 4 deletions tests/slow/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,17 @@ def opn(fixture_path):
("analytical", None, "mcmc", None, True),
("analytical", None, "mcmc", "slice", True),
("analytical", None, "nuts_numpyro", None, True),
("analytical", None, "nuts_numpyro", "slice", TypeError),
("analytical", None, "nuts_numpyro", "slice", ValueError),
("approx_differentiable", "pytensor", None, None, True), # Defaults should work
("approx_differentiable", "pytensor", "mcmc", None, True),
("approx_differentiable", "pytensor", "mcmc", "slice", True),
("approx_differentiable", "pytensor", "nuts_numpyro", None, True),
("approx_differentiable", "pytensor", "nuts_numpyro", "slice", TypeError),
("approx_differentiable", "pytensor", "nuts_numpyro", "slice", ValueError),
("approx_differentiable", "jax", None, None, True), # Defaults should work
("approx_differentiable", "jax", "mcmc", None, True),
("approx_differentiable", "jax", "mcmc", "slice", True),
("approx_differentiable", "jax", "nuts_numpyro", None, True),
("approx_differentiable", "jax", "nuts_numpyro", "slice", TypeError),
("approx_differentiable", "jax", "nuts_numpyro", "slice", ValueError),
("blackbox", None, None, None, True), # Defaults should work
("blackbox", None, "mcmc", None, True),
("blackbox", None, "mcmc", "slice", True),
Expand Down Expand Up @@ -137,7 +137,9 @@ def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected):

# Only runs once
if loglik_kind == "analytical" and sampler is None:
assert not model._get_deterministic_var_names(model.traces)
assert f"~{model._parent}_mean" in model._get_deterministic_var_names(
model.traces
)
# test summary:
summary = model.summary()
assert summary.shape[0] == 4
Expand Down

0 comments on commit 6ad7dc5

Please sign in to comment.