Tutorial for neural adaptive smc #11
Comments
Hi @fehiepsi, I'd like to take a stab at this, but could use a little help getting started. Given the model from 5.1 in the paper

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def ssm(xs=None, T_max=1000):
    z_0 = numpyro.sample("z_0", dist.Normal(0, 5))
    z_t_m1 = z_0
    for t in range(1, T_max):
        z_t_loc = z_t_m1 / 2 + 25 * z_t_m1 / (1 + z_t_m1 ** 2) + 8 * jnp.cos(1.2 * t)
        z_t = numpyro.sample(f"z_{t}", dist.Normal(z_t_loc, jnp.sqrt(10)))
        x_t = numpyro.sample(f"x_{t}", dist.Normal(z_t ** 2 / 20, 1), obs=xs[t - 1] if xs is not None else None)
        z_t_m1 = z_t
    return x_t
```

I figure what needs to be implemented is an LSTM-based mixture density network which parametrizes q(z_t | z_1:t-1, x_1:t) (or q(v_t | z_1:t-1, x_1:t, f(z_t-1, t)), since that works better according to the paper). Then make a list of proposals, one for each z_t, each of which is sampled and used to update the full dimensional variable using `zs.at[t].set(sample)`? Would the targets simply be the model above, conditioned on x_1:t? I will try to code something up, but some guidance would be very helpful!
Great to hear that you are interested in this issue, @deoxyribose! The main theme of coix is to define subprograms and then combine them. Each subprogram is written in a PPL, e.g. numpyro. Your model is already in the "combined" form. You can factor it into subprograms: `init_proposal`, `proposal_t`, and `target_t`. Here `target_t` is the body of your for loop and `proposal_t` is your LSTM-based model: `target_t` defines the joint distribution p(z_t, x_t | ...), while `proposal_t` defines q(z_t | ...). Let's walk through this step first. Please let me know if you have any questions.

The next step is to combine those programs. You can use the algorithm in `coix.algo.nasmc` or, even better, combine them in your own way. But let's discuss this later.
Thanks @fehiepsi! I've tried to do what you suggest here: https://github.com/deoxyribose/nasmc/blob/master/nasmc.ipynb, but I don't have a good handle on what it's supposed to look like yet. I'm not sure whether `init_proposal` and the z_0 sampling should be separate from `ssm_proposal` and `ssm_target` respectively. In any case, I don't know how to get past the current error message, but I figure the code probably reveals some misconceptions which you could clear up :)

Edit 10/10/24: At present, I can run training if jit compilation is turned off, but judging by the metrics, it's not very stable and eventually crashes. I think I need to break the problem down into smaller tests than just running training, but I'm not sure what those could be.
Sorry for the late response, @deoxyribose! I'll look into your notebook tomorrow.
Hi @deoxyribose I addressed several issues in your notebook to match the paper; it seems to work: https://gist.github.com/fehiepsi/b7def6a77bf9ca150cf2f17f2ba1a2b5
Let me know if something is unclear.
Thanks so much @fehiepsi! I got busy, but will have plenty of time to finish this in 2 weeks, if not before.
Neural Adaptive SMC (Gu et al.) is a nice framework that allows us to train proposals for non-linear state space models. We can use the forward KL in a nested variational inference scheme because both derivations give similar gradient estimates.
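For reference, a sketch of the forward-KL gradient estimate being referred to, assuming $K$ particles $z_{1:T}^{(k)}$ with normalized SMC weights $\tilde{w}_T^{(k)}$ (this notation is introduced here for illustration and may differ from the paper's):

$$
-\nabla_\phi \, \mathrm{KL}\big(p(z_{1:T} \mid x_{1:T}) \,\|\, q_\phi(z_{1:T} \mid x_{1:T})\big)
= \mathbb{E}_{p(z_{1:T} \mid x_{1:T})}\big[\nabla_\phi \log q_\phi(z_{1:T} \mid x_{1:T})\big]
\approx \sum_{k=1}^{K} \tilde{w}_T^{(k)} \, \nabla_\phi \log q_\phi\big(z_{1:T}^{(k)} \mid x_{1:T}\big)
$$

The first equality holds because the posterior does not depend on $\phi$; the approximation replaces the intractable posterior expectation with the weighted particle set produced by SMC.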
For state space models, we typically don't have a reverse kernel because the state dimension grows over time. This example would be a good illustration of how to deal with growing-dimensional variables in JAX. The trick is to allocate a full-dimensional variable up front and perform an index update at each SMC step.
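A minimal sketch of the index-update trick in plain JAX, assuming the transition from section 5.1 of the paper. It only rolls out the prior for a single trajectory (no particles, weights, or resampling), and the names `rollout` and `step` are hypothetical. The point is that `zs` keeps a fixed shape `(T,)` throughout, so the loop can be traced by `jax.lax.scan` and jit-compiled even though the "active" part of the state grows with `t`.

```python
import jax
import jax.numpy as jnp


def rollout(key, T=5):
    # Full-dimensional buffer allocated up front; entries after the
    # current time step are just unused zeros.
    zs = jnp.zeros(T)

    def step(carry, t):
        zs, z_prev, key = carry
        key, sub = jax.random.split(key)
        loc = z_prev / 2 + 25 * z_prev / (1 + z_prev**2) + 8 * jnp.cos(1.2 * t)
        z_t = loc + jnp.sqrt(10.0) * jax.random.normal(sub)
        # Functional index update: returns a new array with slot t filled in.
        zs = zs.at[t].set(z_t)
        return (zs, z_t, key), None

    key, sub = jax.random.split(key)
    z_0 = 5.0 * jax.random.normal(sub)
    zs = zs.at[0].set(z_0)
    (zs, _, _), _ = jax.lax.scan(step, (zs, z_0, key), jnp.arange(1, T))
    return zs


zs = jax.jit(rollout, static_argnames="T")(jax.random.PRNGKey(0), T=5)
print(zs.shape)
```

In an actual SMC step the same pattern would apply per particle, with the buffer carrying an extra leading particle dimension.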