Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Latent exclude deterministic #1901

Merged
merged 3 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,32 +790,31 @@ def _predictive(

def single_prediction(val):
rng_key, samples = val

def _samples_wo_deterministic(msg):
return samples.get(msg["name"]) if msg["type"] != "deterministic" else None

substituted_model = (
substitute(masked_model, substitute_fn=_samples_wo_deterministic)
if exclude_deterministic
else substitute(masked_model, samples)
)

if infer_discrete:
from numpyro.contrib.funsor import config_enumerate
from numpyro.contrib.funsor.discrete import _sample_posterior

model_trace = prototype_trace
temperature = 1
pred_samples = _sample_posterior(
config_enumerate(condition(model, samples)),
config_enumerate(substituted_model),
first_available_dim,
temperature,
rng_key,
*model_args,
**model_kwargs,
)
else:

def _samples_wo_deterministic(msg):
return (
samples.get(msg["name"]) if msg["type"] != "deterministic" else None
)

substituted_model = (
substitute(masked_model, substitute_fn=_samples_wo_deterministic)
if exclude_deterministic
else substitute(masked_model, samples)
)
model_trace = trace(seed(substituted_model, rng_key)).get_trace(
*model_args, **model_kwargs
)
Expand Down
41 changes: 41 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,26 @@ def model(X, y=None):
return model, X, y


def categorical_probs():
probs0 = 0.5
nbatch0, nbatch1 = 2, 1
probs = jnp.ones((nbatch0, nbatch1, 3)) * probs0

def model(probs):
probs = numpyro.deterministic("probs", probs)

plate = numpyro.plate("plate", size=probs.shape[-1], dim=-1)

with plate:
numpyro.sample(
"counts_categorical",
dist.Categorical(probs=probs),
infer={"enumerate": "parallel"},
)

return model, probs


@pytest.mark.parametrize("parallel", [True, False])
def test_predictive(parallel):
model, data, true_probs = beta_bernoulli()
Expand Down Expand Up @@ -113,6 +133,27 @@ def test_predictive_with_deterministic(parallel):
assert predictive_samples["obs"].shape == (100,) + X[:n_preds].shape


@pytest.mark.parametrize(
argnames="parallel", argvalues=[True, False], ids=["parallel", "sequential"]
)
def test_discrete_predictive_with_deterministic(parallel):
"""Tests that the predictive samples include deterministic sites for discrete models."""
model, probs = categorical_probs()

predictive = Predictive(
model=model,
posterior_samples=dict(probs=probs),
infer_discrete=True,
batch_ndims=2,
parallel=parallel,
exclude_deterministic=False,
)

predictive_samples = predictive(random.PRNGKey(1), probs=probs)
assert predictive_samples.keys() == {"counts_categorical"}
assert predictive_samples["counts_categorical"].shape == probs.shape


def test_predictive_with_guide():
data = jnp.array([1] * 8 + [0] * 2)

Expand Down
Loading