Skip to content

Commit

Permalink
Merge postprocess_fn into the fori_collect loop (#1910)
Browse files Browse the repository at this point in the history
* merge postprocess_fn into the loop

* cover the case of an empty model
  • Loading branch information
fehiepsi authored Nov 20, 2024
1 parent 5c2eafa commit 0e7bd20
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 25 deletions.
41 changes: 17 additions & 24 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,27 +189,30 @@ def _sample_fn_nojit_args(state, sampler, args, kwargs):
return (sampler.sample(state[0], args, kwargs),)


def _collect_fn(collect_fields, remove_sites):
@cached_by(_collect_fn, collect_fields, remove_sites)
def collect(x):
def _collect_and_postprocess(postprocess_fn, collect_fields, remove_sites):
@cached_by(_collect_and_postprocess, postprocess_fn, collect_fields, remove_sites)
def collect_and_postprocess(x):
if collect_fields:
fields = nested_attrgetter(*collect_fields)(x[0])
fields = [fields] if len(collect_fields) == 1 else list(fields)
site_values = jax.tree.flatten(fields[0])[0]
if len(site_values) > 0:
fields[0] = postprocess_fn(fields[0], *x[1:])

if remove_sites != ():
fields = [fields] if len(collect_fields) == 1 else list(fields)
assert isinstance(fields[0], dict)

sample_sites = fields[0].copy()
for site in remove_sites:
sample_sites.pop(site)
fields[0] = sample_sites
fields = fields[0] if len(collect_fields) == 1 else fields

fields = fields[0] if len(collect_fields) == 1 else fields
return fields
else:
return x[0]

return collect
return collect_and_postprocess


# XXX: Is there a better hash key that we can use?
Expand Down Expand Up @@ -397,28 +400,28 @@ def _get_cached_fns(self):
fns, key = None, None
if fns is None:

def laxmap_postprocess_fn(states, args, kwargs):
def _postprocess_fn(state, args, kwargs):
if self.postprocess_fn is None:
body_fn = self.sampler.postprocess_fn(args, kwargs)
else:
body_fn = self.postprocess_fn
if self.chain_method == "vectorized" and self.num_chains > 1:
body_fn = vmap(body_fn)

return lax.map(body_fn, states)
return body_fn(state)

if self._jit_model_args:
sample_fn = partial(_sample_fn_jit_args, sampler=self.sampler)
postprocess_fn = jit(laxmap_postprocess_fn)
postprocess_fn = _postprocess_fn
else:
sample_fn = partial(
_sample_fn_nojit_args,
sampler=self.sampler,
args=self._args,
kwargs=self._kwargs,
)
postprocess_fn = jit(
partial(laxmap_postprocess_fn, args=self._args, kwargs=self._kwargs)
postprocess_fn = partial(
_postprocess_fn, args=self._args, kwargs=self._kwargs
)

fns = sample_fn, postprocess_fn
Expand Down Expand Up @@ -470,7 +473,9 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
upper_idx,
sample_fn,
init_val,
transform=_collect_fn(collect_fields, remove_sites),
transform=_collect_and_postprocess(
postprocess_fn, collect_fields, remove_sites
),
progbar=self.progress_bar,
return_last_val=True,
thinning=self.thinning,
Expand All @@ -487,18 +492,6 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
if len(collect_fields) == 1:
states = (states,)
states = dict(zip(collect_fields, states))
# Apply constraints if number of samples is non-zero
site_values = jax.tree.flatten(states[self._sample_field])[0]
# XXX: lax.map still works if some arrays have 0 size
# so we only need to filter out the case site_value.shape[0] == 0
# (which happens when lower_idx==upper_idx)
if len(site_values) > 0 and jnp.shape(site_values[0])[0] > 0:
if self._jit_model_args:
states[self._sample_field] = postprocess_fn(
states[self._sample_field], args, kwargs
)
else:
states[self._sample_field] = postprocess_fn(states[self._sample_field])
return states, last_state

def _set_collection_params(
Expand Down
5 changes: 4 additions & 1 deletion test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def test_mcmc_one_chain(deterministic, find_heuristic_step_size):

num_traces_for_heuristic = 2 if find_heuristic_step_size else 0
if deterministic:
assert GLOBAL["count"] == 4 + num_traces_for_heuristic
# We have two extra calls to the model to get deterministic values:
# 1. transform the init state
# 2. transform state during the loop
assert GLOBAL["count"] == 5 + num_traces_for_heuristic
else:
assert GLOBAL["count"] == 3 + num_traces_for_heuristic

Expand Down

0 comments on commit 0e7bd20

Please sign in to comment.