Skip to content

Commit

Permalink
Merge pull request #408 from lnccbrown/406-plot_posterior_predictive-…
Browse files Browse the repository at this point in the history
…fails-for-full_ddm-model

change order between sz and sv in list_params
  • Loading branch information
AlexanderFengler authored May 8, 2024
2 parents e8f16b1 + 6134905 commit 8d56fe9
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
- name: install pymc and poetry
run: |
mamba info
mamba install -c conda-forge pymc=5.12 poetry
mamba install -c conda-forge pymc poetry
- name: install hssm
run: |
Expand Down
2 changes: 1 addition & 1 deletion src/hssm/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class DefaultConfig(TypedDict):
},
"full_ddm": {
"response": ["rt", "response"],
"list_params": ["v", "a", "z", "t", "sv", "sz", "st"],
"list_params": ["v", "a", "z", "t", "sz", "sv", "st"],
"description": "The full Drift Diffusion Model (DDM)",
"likelihoods": {
"blackbox": {
Expand Down
26 changes: 22 additions & 4 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@


class HSSM:
"""The Hierarchical Sequential Sampling Model (HSSM) class.
"""The basic Hierarchical Sequential Sampling Model (HSSM) class.
Parameters
----------
Expand Down Expand Up @@ -497,6 +497,17 @@ def sample(
inference_method=sampler, init=init, **kwargs
)

# The parent was previously not part of deterministics --> compute it via
# posterior_predictive (works because it acts as the 'mu' parameter
# in the GLM as far as bambi is concerned)
if self._inference_obj is not None:
if self._parent not in self._inference_obj.posterior.data_vars.keys():
self.model.predict(self._inference_obj, kind="mean", inplace=True)
# rename 'rt,response_mean' to 'v' so in the traces everything
# looks the way it should
self._inference_obj.rename_vars(
{"rt,response_mean": self._parent}, inplace=True
)
return self.traces

def sample_posterior_predictive(
Expand Down Expand Up @@ -526,13 +537,13 @@ def sample_posterior_predictive(
If `True` will make predictions including the group specific effects.
Otherwise, predictions are made with common effects only (i.e. group-
specific are set to zero), by default True.
kind
kind: optional
Indicates the type of prediction required. Can be `"mean"` or `"pps"`. The
first returns draws from the posterior distribution of the mean, while the
latter returns the draws from the posterior predictive distribution
(i.e. the posterior probability distribution for a new observation).
Defaults to `"pps"`.
n_samples
n_samples: optional
The number of samples to draw from the posterior predictive distribution
from each chain.
When it's an integer >= 1, the number of samples to be extracted from the
Expand Down Expand Up @@ -1308,11 +1319,18 @@ def _get_deterministic_var_names(self, idata) -> list[str]:
var_names = [
f"~{param_name}"
for param_name, param in self.params.items()
if param.is_regression and not param.is_parent
if param.is_regression
]

# Handle specific case where parent is not explictly in traces
if ("~" + self._parent in var_names) and (
self._parent not in idata.posterior.data_vars
):
var_names.remove("~" + self._parent)

if f"{self.response_str}_mean" in idata["posterior"].data_vars:
var_names.append(f"~{self.response_str}_mean")

return var_names

def _handle_missing_data_and_deadline(self):
Expand Down
2 changes: 1 addition & 1 deletion src/hssm/likelihoods/blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def logp_ddm_sdv_bbox(data: np.ndarray, v, a, z, t, sv) -> np.ndarray:


@hddm_to_hssm
def logp_full_ddm(data: np.ndarray, v, a, z, t, sv, sz, st):
def logp_full_ddm(data: np.ndarray, v, a, z, t, sz, sv, st):
"""Compute blackbox log-likelihoods for full_ddm models."""
return wfpt.wiener_logp_array(
x=data,
Expand Down
8 changes: 0 additions & 8 deletions src/hssm/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,3 @@ def get_hddm_default_prior(
"sz": {"dist": "Gamma", "mu": HDDM_MU["sz"], "sigma": HDDM_SIGMA["sz"]},
"st": {"dist": "Gamma", "mu": HDDM_MU["st"], "sigma": HDDM_SIGMA["st"]},
}

# INITVAL_SETTINGS_LOGIT: dict[Any, Any] = {
# "t" : {"initval": -4.0},
# }

# INITVAL_SETTINGS_NOLINK: dict[Any, Any] = {
# "t" : {"initval": 0.05},
# }
25 changes: 22 additions & 3 deletions tests/slow/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from hssm.utils import _rearrange_data

hssm.set_floatX("float32")
hssm.set_floatX("float32", jax=True)

# AF-TODO: Include more tests that use different link functions!

Expand Down Expand Up @@ -124,6 +124,8 @@ def run_sample(model, sampler, step, expected):

@pytest.mark.parametrize(parameter_names, parameter_grid)
def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected):
print("PYMC VERSION: ")
print(pm.__version__)
print("TEST INPUTS WERE: ")
print("REPORTING FROM SIMPLE MODELS TEST")
print(loglik_kind, backend, sampler, step, expected)
Expand All @@ -147,6 +149,8 @@ def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected):

@pytest.mark.parametrize(parameter_names, parameter_grid)
def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected):
print("PYMC VERSION: ")
print(pm.__version__)
print("TEST INPUTS WERE: ")
print("REPORTING FROM REG MODELS TEST")
print(loglik_kind, backend, sampler, step, expected)
Expand All @@ -169,7 +173,7 @@ def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected)

# Only runs once
if loglik_kind == "analytical" and sampler is None:
assert not model._get_deterministic_var_names(model.traces)
assert model._get_deterministic_var_names(model.traces) == ["~v"]
# test summary:
summary = model.summary()
assert summary.shape[0] == 6
Expand All @@ -181,6 +185,8 @@ def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected)

@pytest.mark.parametrize(parameter_names, parameter_grid)
def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expected):
print("PYMC VERSION: ")
print(pm.__version__)
print("TEST INPUTS WERE: ")
print("REPORTING FROM REG MODELS V_A TEST")
print(loglik_kind, backend, sampler, step, expected)
Expand Down Expand Up @@ -218,7 +224,12 @@ def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expec

# Only runs once
if loglik_kind == "analytical" and sampler is None:
assert model._get_deterministic_var_names(model.traces) == ["~a"]
assert len(model._get_deterministic_var_names(model.traces)) == len(
["~a", "~v"]
)
assert set(model._get_deterministic_var_names(model.traces)) == set(
["~a", "~v"]
)
# test summary:
summary = model.summary()
assert summary.shape[0] == 8
Expand Down Expand Up @@ -253,6 +264,8 @@ def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expec
def test_simple_models_missing_data(
data_ddm_missing, loglik_kind, backend, sampler, step, expected, cpn
):
print("PYMC VERSION: ")
print(pm.__version__)
print("TEST INPUTS WERE: ")
print("REPORTING FROM SIMPLE MODELS MISSING DATA TEST")
print(loglik_kind, backend, sampler, step, expected)
Expand All @@ -271,6 +284,8 @@ def test_simple_models_missing_data(
def test_reg_models_missing_data(
data_ddm_reg_missing, loglik_kind, backend, sampler, step, expected, cpn
):
print("PYMC VERSION: ")
print(pm.__version__)
print("TEST INPUTS WERE: ")
print("REPORTING FROM REG MODELS MISSING DATA TEST")
print(loglik_kind, backend, sampler, step, expected)
Expand Down Expand Up @@ -298,6 +313,8 @@ def test_reg_models_missing_data(
def test_simple_models_deadline(
data_ddm_deadline, loglik_kind, backend, sampler, step, expected, opn
):
print("PYMC VERSION: ")
print(pm.__version__)
print("TEST INPUTS WERE: ")
print("REPORTING FROM SIMPLE MODELS DEADLINE TEST")
print(loglik_kind, backend, sampler, step, expected)
Expand All @@ -315,6 +332,8 @@ def test_simple_models_deadline(
def test_reg_models_deadline(
data_ddm_reg_deadline, loglik_kind, backend, sampler, step, expected, opn
):
print("PYMC VERSION: ")
print(pm.__version__)
print("TEST INPUTS WERE: ")
print("REPORTING FROM REG MODELS DEADLINE TEST")
print(loglik_kind, backend, sampler, step, expected)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from hssm.utils import download_hf
from hssm.likelihoods import DDM, logp_ddm

hssm.set_floatX("float32")
hssm.set_floatX("float32", jax=True)

param_v = {
"name": "v",
Expand Down

0 comments on commit 8d56fe9

Please sign in to comment.