Skip to content

Commit

Permalink
Allow to collect nested dict keys in mcmc (#1905)
Browse files Browse the repository at this point in the history
* allow to collect nested dict key

* add docstring to nested attrgetter

* make sure that the behavior is consistent for attrgetter
  • Loading branch information
fehiepsi authored Nov 14, 2024
1 parent d55d209 commit 0a2f6fe
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
11 changes: 8 additions & 3 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
fori_collect,
identity,
is_prng_key,
nested_attrgetter,
)

__all__ = [
Expand Down Expand Up @@ -192,7 +193,7 @@ def _collect_fn(collect_fields, remove_sites):
@cached_by(_collect_fn, collect_fields, remove_sites)
def collect(x):
if collect_fields:
fields = attrgetter(*collect_fields)(x[0])
fields = nested_attrgetter(*collect_fields)(x[0])

if remove_sites != ():
fields = [fields] if len(collect_fields) == 1 else list(fields)
Expand Down Expand Up @@ -585,7 +586,10 @@ def warmup(
:param extra_fields: Extra fields (aside from :meth:`~numpyro.infer.mcmc.MCMCKernel.default_fields`)
from the state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to collect during
the MCMC run. Exclude sample sites from collection with "~`sampler.sample_field`.`sample_site`".
e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler.
e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler. To
collect samples of a site "a" in the unconstrained space, we can specify the variable here, e.g.
`extra_fields=("z.a",)`.
:type extra_fields: tuple or list
:param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults
to `False`.
Expand Down Expand Up @@ -622,7 +626,8 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
during the MCMC run. Note that subfields can be accessed using dots, e.g.
`"adapt_state.step_size"` can be used to collect step sizes at each step. Exclude sample sites from
collection with "~`sampler.sample_field`.`sample_site`". e.g. "~z.a" will prevent site "a" from
being collected if you're using the NUTS sampler.
being collected if you're using the NUTS sampler. To collect samples of a site "a" in the
unconstrained space, we can specify the variable here, e.g. `extra_fields=("z.a",)`.
:type extra_fields: tuple or list of str
:param init_params: Initial parameters to begin sampling. The type must be consistent
with the input type to `potential_fn` provided to the kernel. If the kernel is
Expand Down
25 changes: 25 additions & 0 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,28 @@ def find_stack_level() -> int:
else:
break
return n


def nested_attrgetter(*collect_fields):
"""
Like attrgetter, but allows for accessing dictionary keys
using the dot notation (e.g., 'x.c.d').
"""

def getter(obj):
results = tuple(_get_nested_attr(obj, field) for field in collect_fields)
return results if len(collect_fields) > 1 else results[0]

return getter


def _get_nested_attr(obj, field):
"""
Helper function to recursively access attributes and dictionary keys.
"""
for attr in field.split("."):
try:
obj = getattr(obj, attr)
except AttributeError:
obj = obj[attr]
return obj
9 changes: 9 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,3 +1199,12 @@ def model():
samps = mcmc.get_samples()

assert all([site[3:] not in samps for site in remove_sites])


def test_extra_fields_include_unconstrained_samples():
def model():
numpyro.sample("x", dist.HalfNormal(1))

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), extra_fields=("z.x",))
assert_allclose(mcmc.get_samples()["x"], jnp.exp(mcmc.get_extra_fields()["z.x"]))

0 comments on commit 0a2f6fe

Please sign in to comment.