Expose some mcmc states for sequential sampling strategy #861

Merged: 4 commits into pyro-ppl:master on Jan 8, 2021

Conversation

fehiepsi (Member) commented Jan 5, 2021

Resolves #534. This is also requested by @rexdouglass in #539

This is just one possible solution, mainly for discussion. I'm not sure whether this is a good API, so if reviewers have other ideas, please let me know.

martinjankowiak previously approved these changes Jan 6, 2021

fehiepsi (Member, Author) commented Jan 7, 2021

Thanks for your suggestions, Martin! I will need to discuss the API with Neeraj. In particular, I am not sure if #781 is related and can be resolved here.

    def init_state(self):
        """
        The initial state of the MCMC chain. If this attribute is None,
        :meth:`run` will call the `self.sampler.init(...)` method for initialization.

neerajprad (Member)

My concern with exposing this is that it might give the impression that users can use it to specify initial values of parameters or the mass matrix (which isn't true). In addition, there are other auxiliary variables, like z_grad and those mentioned in the warning, that are non-obvious and would leak internal details. Could you elaborate on the use case for this?

fehiepsi (Member, Author)

I agree that this warning is tricky to deliver to users. The motivation here is to start the chain on new data with the previously adapted mass matrix, step_size, and the last sample, as requested by @martinjankowiak and in #534.

Those who want such a warm-start feature can still achieve it by modifying the warmup state with state._replace(i=0, pe=1e6) and taking care of the other related diagnostic information (a wrong z_grad is fine because z_grad from the old data only drives the momentum on the first leapfrog step, but a wrong pe might lead to rejecting all proposals). So I think we should not expose it.

neerajprad (Member) commented Jan 7, 2021

The motivation here is to start the chain on new data with the previously adapted mass matrix, step_size, and the last sample

Do you mean starting from these values and running adaptation from there? IIUC there are three things that we can specify for initialization: initial values (which is already possible using init_to_value?), step size, and mass matrix. This method is meant to expose the latter two. Is that correct? Even then, I suppose the adaptation will re-learn fresh values for the latter two.
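
(For reference: supplying initial values is already possible through an init strategy. Below is a minimal sketch of that, not part of this PR; the toy model and the site name "theta" are placeholders.)

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_value

def model():
    # a toy model with a single latent site named "theta"
    numpyro.sample("theta", dist.Normal(0.0, 1.0))

# start the chain at theta = 0.5 instead of the default initialization
kernel = NUTS(model, init_strategy=init_to_value(values={"theta": 0.5}))
mcmc = MCMC(kernel, num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0))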

fehiepsi (Member, Author)

Yes, it will re-learn them, but starting from the previous values. This might be helpful for tricky posteriors that require spending a lot of time on the initial steps...

neerajprad (Member)

Hmm, how about things like the adaptation window index? Wouldn't that need to be reset?

Another choice might be to have something like an MCMCKernel.get_init_state(self, *args, **kwargs) method. For HMC, it could be get_init_state(self, init_params, mass_matrix=None, step_size=None), which can internally do all this book-keeping (e.g. setting i to 0). That might be a bit more verbose but easier to explain, since we can tell users to use the get_init_state method of the corresponding kernel without other caveats.

kernel = HMC(...)
mcmc = MCMC(kernel, ...)
mcmc.initial_state = kernel.get_init_state(init_params, mass_matrix)
mcmc.run(...)

fehiepsi (Member, Author)

Makes sense to me. We also need args and kwargs, so it is more like the kernel.init method (except that we won't try to find valid initial params). We discussed how to expose the inverse mass matrix in the API in #536 but found that it is a bit ambiguous. I think we should skip this feature unless there are more requests for it. What do you think?

how about things like the adaptation window index

We use hmc_state.i for the update so it should be fine.

neerajprad (Member)

I agree, I think it is fine to discuss a bit more before finalizing this. Power users can already do this if they need to, but we should provide an intuitive interface if we are going to advertise this.

        self._init_state = state

    @property
    def warmup_state(self):

neerajprad (Member)

If we are to expose this, what do you think about calling this post_warmup_state for clarity?

fehiepsi (Member, Author)

What do you mean by post? Does it mean posterior?

neerajprad (Member) commented Jan 7, 2021

I meant post as in after, as opposed to pre, e.g. the post-warmup draws in PyStan. Other suggestions are welcome too; I just think warmup_state is fine for internal usage but isn't descriptive enough to be exposed in the API.

fehiepsi (Member, Author)

I have no preference so I'll make the change. Thanks for your suggestion!
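
(For reference: a minimal sketch of the sequential warm-start usage discussed in this thread, assuming the post_warmup_state name settled on above and a last_state attribute holding the final chain state. The model and data arrays here are placeholders.)

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(data):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=data)

first_batch = random.normal(random.PRNGKey(1), (100,)) + 2.0
second_batch = random.normal(random.PRNGKey(2), (100,)) + 2.0

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), first_batch)  # warmup + sampling on the first batch

# reuse the adapted step size / mass matrix and the last sample on new data;
# setting post_warmup_state makes the next run() skip the warmup phase
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key, second_batch)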

neerajprad (Member)

I am not sure if #781 is related and can be resolved here.

To resolve this, I think we'll need to separately collect warmup samples and add a return_warmup argument, so that get_samples(return_warmup=True) returns these when requested. I am not sure if it is worth looking into in this PR (or at all, since I seem to be the only one with that issue 😄).

fehiepsi (Member, Author) commented Jan 7, 2021

seem to be the only one with that issue

I think I understand that issue better now. It is not about how to collect those warmup samples, but about having a more intuitive API to do that job. We can do it later if it is needed. :) Your suggestions there make sense to me.

neerajprad merged commit 87258a2 into pyro-ppl:master on Jan 8, 2021

Successfully merging this pull request may close these issues: Sequential Sampling Strategy.

3 participants