From 0a2f6fe97c6d3becc055260ed5bea04eab88fcfa Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 14 Nov 2024 06:14:15 -0500 Subject: [PATCH] Allow to collect nested dict keys in mcmc (#1905) * allow to collect nested dict key * add docstring to nested attrgetter * make sure that the behavior is consistent for attrgetter --- numpyro/infer/mcmc.py | 11 ++++++++--- numpyro/util.py | 25 +++++++++++++++++++++++++ test/infer/test_mcmc.py | 9 +++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 81efb1cbf..d27310931 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -20,6 +20,7 @@ fori_collect, identity, is_prng_key, + nested_attrgetter, ) __all__ = [ @@ -192,7 +193,7 @@ def _collect_fn(collect_fields, remove_sites): @cached_by(_collect_fn, collect_fields, remove_sites) def collect(x): if collect_fields: - fields = attrgetter(*collect_fields)(x[0]) + fields = nested_attrgetter(*collect_fields)(x[0]) if remove_sites != (): fields = [fields] if len(collect_fields) == 1 else list(fields) @@ -585,7 +586,10 @@ def warmup( :param extra_fields: Extra fields (aside from :meth:`~numpyro.infer.mcmc.MCMCKernel.default_fields`) from the state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to collect during the MCMC run. Exclude sample sites from collection with "~`sampler.sample_field`.`sample_site`". - e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler. + e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler. To + collect samples of a site "a" in the unconstrained space, we can specify the variable here, e.g. + `extra_fields=("z.a",)`. + :type extra_fields: tuple or list :param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults to `False`. @@ -622,7 +626,8 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): during the MCMC run. 
def nested_attrgetter(*collect_fields):
    """
    Like :func:`operator.attrgetter`, but allows for accessing dictionary keys
    using the dot notation (e.g., 'x.c.d').

    :param str collect_fields: one or more dotted paths. Each component of a
        path is resolved against the current object either as a mapping key
        or as an attribute (see :func:`_get_nested_attr`).
    :return: a callable taking a single object. It returns the resolved value
        when exactly one field was given, and a tuple of resolved values (in
        the order of `collect_fields`) when several were given.
    :raises TypeError: if no field is given, mirroring ``attrgetter()``.
    """
    # Fail at construction time (like `attrgetter()`), rather than with an
    # opaque IndexError on `results[0]` the first time the getter is called.
    if not collect_fields:
        raise TypeError("nested_attrgetter expected at least 1 field, got 0")

    def getter(obj):
        results = tuple(_get_nested_attr(obj, field) for field in collect_fields)
        # Same return convention as attrgetter: bare value for one field,
        # tuple for several.
        return results if len(collect_fields) > 1 else results[0]

    return getter


def _get_nested_attr(obj, field):
    """
    Helper function to resolve a dotted path of attributes and dictionary keys.

    The path is split on ``"."`` and each component is applied to the result
    of the previous one:

    * if the current object is a mapping that contains the component as a
      key, the stored value is used. Keys take priority over attributes so
      that a key such as ``"items"`` or ``"keys"`` is not shadowed by the
      mapping method of the same name;
    * otherwise the component is looked up as an attribute, falling back to
      ``obj[component]`` for subscriptable objects without that attribute.

    :param obj: the object to start the lookup from.
    :param str field: dotted path, e.g. ``"adapt_state.step_size"`` or ``"z.a"``.
    :return: the object found at the end of the path.
    :raises AttributeError: if a component is neither an attribute nor a valid
        subscript of the current object (same exception type as ``attrgetter``).
    :raises KeyError: if the current object is a mapping that has neither an
        attribute nor a key with the component's name.
    """
    # Function-scope import keeps this helper self-contained.
    from collections.abc import Mapping

    for attr in field.split("."):
        if isinstance(obj, Mapping) and attr in obj:
            obj = obj[attr]
            continue
        try:
            obj = getattr(obj, attr)
        except AttributeError as err:
            try:
                obj = obj[attr]
            except (TypeError, IndexError):
                # The object cannot be subscripted with a string at all, so
                # the subscript error is noise: report the missing attribute.
                raise err from None
    return obj