First of all, thank you for publicly sharing your work.
I am a senior CS bachelor's student, and I am using the SWAG estimate of the posterior P(theta | dataset) as part of a theoretical framework supporting the empirical results I have obtained so far in my thesis. I have a couple of questions about the SWAG class.
If I understood correctly, the SWAG class provides a way to sample from the posterior and to compute the log probability of the samples. After digging deeper into the code, I see that the output of the "compute_logprob()" method depends on the values "mean_list", "var_list" and "covar_mat_root_list" returned by the "generate_mean_var_covar()" method.
Going through the "generate_mean_var_covar()" method, these values are extracted from the "mean" and "sq_mean" attributes of each sub-module.
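To make the dependency concrete: in the diagonal part of SWAG each parameter gets the variance sq_mean - mean**2, and the log probability of a parameter set is a sum of independent Gaussian terms, so the per-module lists can simply be summed. The sketch below is a minimal illustration under stated assumptions: `logprob_per_module`, the module names and the `var_clamp` floor are hypothetical, and it ignores the low-rank term built from "covar_mat_root_list", which the real "compute_logprob()" may also use.

```python
import math


def logprob_per_module(params, stats, var_clamp=1e-30):
    """Diagonal-Gaussian log density of `params`, summed over sub-modules.

    Hypothetical helper: `params` and `stats` are keyed by made-up module names;
    stats[name] holds the running "mean" and "sq_mean" lists for that module.
    Only the diagonal variance is used (no low-rank covariance term).
    """
    total = 0.0
    for name, s in stats.items():
        for x, m, q in zip(params[name], s["mean"], s["sq_mean"]):
            # Per-parameter variance, floored so it can never be zero or negative.
            v = max(q - m * m, var_clamp)
            total += -0.5 * (math.log(2 * math.pi * v) + (x - m) ** 2 / v)
    return total


stats = {
    "fc1.weight": {"mean": [0.5, -0.1], "sq_mean": [0.35, 0.02]},
    "fc1.bias": {"mean": [0.0], "sq_mean": [0.04]},
}
at_mean = {name: list(s["mean"]) for name, s in stats.items()}
shifted = {"fc1.weight": [0.6, -0.1], "fc1.bias": [0.0]}
print(logprob_per_module(at_mean, stats), logprob_per_module(shifted, stats))
```

The log density is largest at the running mean and drops as the evaluated parameters move away from it.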
So, to get different outputs from the "compute_logprob()" method, the values of "mean" and "sq_mean" in the sub-modules would need to change. However, the only method that changes these values is "collect_model()". Hence, I conjectured that I should proceed as follows:
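For reference, "mean" and "sq_mean" are running averages of the weights and of the squared weights over the collected snapshots, which is the standard SWAG moment update. The helper below is a hypothetical sketch of that update on plain Python lists; treat the exact form used inside "collect_model()" as an assumption.

```python
def update_moments(mean, sq_mean, n_models, weights):
    """Return (mean, sq_mean, n_models) after folding in one weight snapshot.

    Hypothetical helper: running averages of w and w**2, the kind of update
    attributed to collect_model() in the post.
    """
    new_mean = [(m * n_models + w) / (n_models + 1) for m, w in zip(mean, weights)]
    new_sq_mean = [(s * n_models + w * w) / (n_models + 1) for s, w in zip(sq_mean, weights)]
    return new_mean, new_sq_mean, n_models + 1


# Two end-of-epoch snapshots of a two-parameter model.
mean, sq_mean, n = [0.0, 0.0], [0.0, 0.0], 0
for snapshot in ([1.0, 2.0], [3.0, 4.0]):
    mean, sq_mean, n = update_moments(mean, sq_mean, n, snapshot)
var = [s - m * m for m, s in zip(mean, sq_mean)]
# mean == [2.0, 3.0], sq_mean == [5.0, 10.0], var == [1.0, 1.0]
```

Once collection stops, these moments are frozen; sampling afterwards only draws around them.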
1. Define a base model, then define a SWAG model with the same base class.
2. While training the base model, call swag_model.collect_model(base_model) at the end of each epoch, since this updates the SWAG statistics (the mean and the covariance matrix).
3. After training, use swag_model to sample from the posterior distribution:
   - swag_model.sample(): samples a set of parameters from the posterior.
   - swag_model.compute_logprob(): computes the log probability of the current set of parameters (the ones sampled by the call above).
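The steps above can be sketched end to end on a toy problem. Everything here is hypothetical and deliberately simplified: `ToySWAGDiag` is a diagonal-only stand-in, not the repository's SWAG class; "training" is noisy SGD on the made-up loss 0.5 * ||w - target||^2; and skipping the first `swa_start` epochs before collecting is an assumed burn-in choice.

```python
import math
import random


class ToySWAGDiag:
    """Hypothetical, diagonal-only stand-in for a SWAG model (not the repo's class)."""

    def __init__(self, dim, var_clamp=1e-30):
        self.mean = [0.0] * dim         # running average of w
        self.sq_mean = [0.0] * dim      # running average of w**2
        self.n_models = 0
        self.var_clamp = var_clamp
        self.current = list(self.mean)  # parameters currently "loaded" in the model

    def collect_model(self, weights):
        # Fold one snapshot into the running first and second moments.
        n = self.n_models
        self.mean = [(m * n + w) / (n + 1) for m, w in zip(self.mean, weights)]
        self.sq_mean = [(s * n + w * w) / (n + 1) for s, w in zip(self.sq_mean, weights)]
        self.n_models = n + 1

    def variance(self):
        return [max(s - m * m, self.var_clamp) for m, s in zip(self.mean, self.sq_mean)]

    def sample(self, rng):
        # Draw w ~ N(mean, diag(var)) and load it as the current parameters.
        self.current = [rng.gauss(m, math.sqrt(v)) for m, v in zip(self.mean, self.variance())]
        return self.current

    def compute_logprob(self):
        # Log density of the currently loaded parameters under the Gaussian.
        return sum(-0.5 * (math.log(2 * math.pi * v) + (x - m) ** 2 / v)
                   for x, m, v in zip(self.current, self.mean, self.variance()))


def train_and_collect(epochs=20, swa_start=10, steps_per_epoch=10, lr=0.1, seed=0):
    """Noisy SGD on the toy loss 0.5 * ||w - target||^2, collecting once per epoch."""
    rng = random.Random(seed)
    target = [1.0, -2.0]
    w = [0.0, 0.0]                      # "base model" parameters
    swag = ToySWAGDiag(dim=len(w))
    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            # Gradient of the toy loss plus Gaussian noise to mimic minibatch SGD.
            grad = [wi - ti + rng.gauss(0.0, 0.5) for wi, ti in zip(w, target)]
            w = [wi - lr * gi for wi, gi in zip(w, grad)]
        if epoch >= swa_start:
            swag.collect_model(w)       # end of epoch: update mean and sq_mean
    return swag, rng


swag, rng = train_and_collect()
w_sample = swag.sample(rng)
print(w_sample, swag.compute_logprob())
```

The design mirrors the order described in the list: moments are only touched during training, while sampling and log-probability evaluation happen afterwards against frozen statistics.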
I would greatly appreciate it if you could confirm or correct my understanding of your implementation. Thanks a lot in advance!
Thank you for the prompt response. To clarify, I would like to build on your work and sample from the posterior distribution P(parameter | data). I proceed as follows:
1. Create a base model and a SWAG model separately. While training the base model, call swag_model.collect_model(base_model) at the end of each epoch.
2. After training:
   - Call swag_model.sample(): if I am not mistaken, this sets the parameters of the SWAG model to w, where w ~ P(param | data).
   - Call swag_model.compute_logprob(): this returns an estimate of log P(w | data).
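This also bears on the earlier concern that "compute_logprob()" only changes when "mean" and "sq_mean" change. If, as the list above assumes, "compute_logprob()" is evaluated at whatever parameters "sample()" just loaded, its output varies from sample to sample even with frozen statistics. A minimal illustration with hypothetical helpers `sample_diag` and `logprob_diag` (diagonal Gaussian only, made-up moments):

```python
import math
import random


def sample_diag(mean, var, rng):
    """Draw w ~ N(mean, diag(var)); a hypothetical stand-in for sample()."""
    return [rng.gauss(m, math.sqrt(v)) for m, v in zip(mean, var)]


def logprob_diag(w, mean, var):
    """Log density of w under N(mean, diag(var)); a stand-in for compute_logprob()."""
    return sum(-0.5 * (math.log(2 * math.pi * v) + (x - m) ** 2 / v)
               for x, m, v in zip(w, mean, var))


# Moments frozen after training: no further collect_model() calls.
mean, var = [1.0, -2.0], [0.01, 0.04]
rng = random.Random(0)
logprobs = [logprob_diag(sample_diag(mean, var, rng), mean, var) for _ in range(3)]
print(logprobs)  # three different values, although mean and var never changed
```

Note that this is the log density under the Gaussian approximation, which is why it is only an estimate of log P(w | data).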
Could you please confirm or correct my understanding of the code?
Thanks a lot in advance.