Getting Started
MatlabStan has the following dependencies:
- CmdStan: 2.0.1 or greater
- MatlabProcessManager: 0.4.0 or greater
Installing CmdStan is covered in detail for different platforms in the CmdStan Manual. Note that CmdStan must be built, so it requires make and a C++ compiler.
MatlabProcessManager consists of two Matlab files; install it by adding them to your Matlab path.
To install MatlabStan:
- Obtain a copy here or clone the repo.
- Add the resulting folder to your Matlab path. +mstan is a package folder that does not need to be added to the path, although its parent folder does.
- Edit the file stan_home.m in the +mstan directory to point to the parent folder of your CmdStan installation.
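The path setup described above can be sketched as follows. The folder locations are hypothetical placeholders, not paths from this guide; substitute wherever you unpacked MatlabStan and MatlabProcessManager. The guard means that a folder which does not exist is skipped rather than added.

```matlab
% Hypothetical install locations; replace with your own paths.
matlabstan_dir = fullfile(getenv('HOME'), 'MatlabStan');
mpm_dir        = fullfile(getenv('HOME'), 'MatlabProcessManager');

% Add each folder only if it exists. The +mstan package folder is found
% through its parent folder, so it is not added separately.
dirs = {matlabstan_dir, mpm_dir};
for k = 1:numel(dirs)
    if exist(dirs{k}, 'dir')
        addpath(dirs{k});
    end
end

% An empty result means the stan function is not yet visible on the path.
stan_path = which('stan');
```

Call savepath afterwards if you want the change to persist across Matlab sessions.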
Aki Vehtari's package for Pareto smoothed importance sampling leave-one-out (PSIS-LOO) cross-validation is included in the psis directory. Add this to your Matlab path as well if you want to use it.
Installing Steve Eddins's linewrap function is useful for dealing with unwrapped messages. His xUnit test framework is required if you want to run the unit tests.
This is a classic example from Section 5.5 of Gelman et al. (2003). The following can be compared to the RStan and PyStan versions.
schools_code = {
'data {'
' int<lower=0> J; // number of schools '
' real y[J]; // estimated treatment effects'
' real<lower=0> sigma[J]; // s.e. of effect estimates '
'}'
'parameters {'
' real mu; '
' real<lower=0> tau;'
' real eta[J];'
'}'
'transformed parameters {'
' real theta[J];'
' for (j in 1:J)'
' theta[j] <- mu + tau * eta[j];'
'}'
'model {'
' eta ~ normal(0, 1);'
' y ~ normal(theta, sigma);'
'}'
};
schools_dat = struct('J',8,...
'y',[28 8 -3 7 -1 1 18 12],...
'sigma',[15 10 16 11 9 11 10 18]);
fit = stan('model_code',schools_code,'data',schools_dat);
print(fit);
eta = fit.extract('permuted',true).eta;
mean(eta)
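The transformed parameters block in the model above defines theta[j] = mu + tau * eta[j]. As a rough sketch, the same quantity can be recomputed in Matlab from extracted draws. The arrays below are synthetic stand-ins, and the shapes (an n-by-1 vector for scalar parameters, an n-by-J matrix for eta) are an assumption about what extract returns, not something this guide confirms.

```matlab
% Synthetic stand-ins for posterior draws (assumed shapes, not real output).
n = 5; J = 8;
mu  = linspace(7, 9, n)';                 % n-by-1
tau = linspace(1, 5, n)';                 % n-by-1, positive
eta = repmat(linspace(-1, 1, J), n, 1);   % n-by-J

% theta(i,j) = mu(i) + tau(i) * eta(i,j), expanded across the J columns.
theta = bsxfun(@plus, mu, bsxfun(@times, tau, eta));

% Column means give one posterior mean per school.
theta_mean = mean(theta, 1);              % 1-by-J
```

bsxfun is used rather than implicit expansion so that the sketch also runs on older Matlab releases.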
Stan models can also be defined using a file. For example, download the file eight_schools.stan into your working directory and use the following call:
fit1 = stan('file','eight_schools.stan','data',schools_dat,'iter',1000,'chains',4);
Once a model is fitted, we can reuse the result as an input to stan with other data or settings. This saves us from having to compile the C++ code again (see also here). For example, if we want to sample more iterations:
fit2 = stan('fit',fit1,'data',schools_dat,'iter',10000,'chains',4);
The stan function returns a StanFit object, which contains samples from the posterior distribution. StanFit objects possess a number of methods, including print, traceplot and extract. For example, a summary of the posterior samples as well as the log-posterior (which has the name lp__) is obtained using
print(fit2);
which should look something like this:
Inference for Stan model: eight_schools_model
4 chains: each with iter=(5000,5000,5000,5000); warmup=(0,0,0,0); thin=(1,1,1,1); 20000 iterations saved.
Warmup took (0.16, 0.18, 0.22, 0.20) seconds, 0.77 seconds total
Sampling took (0.29, 0.30, 0.30, 0.25) seconds, 1.1 seconds total
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
lp__ -4.8e+00 4.0e-02 2.6 -9.4e+00 -4.6e+00 -0.92 4364 3828 1.0e+00
accept_stat__ 7.2e-01 7.0e-02 0.30 3.4e-02 8.5e-01 1.0 19 16 1.1e+00
stepsize__ 4.7e-01 5.3e-02 0.075 3.7e-01 4.9e-01 0.58 2.0 1.8 1.5e+13
treedepth__ 1.7e+00 1.8e-01 0.61 0.0e+00 2.0e+00 2.0 11 10.0 1.1e+00
n_divergent__ 8.7e-03 1.7e-03 0.093 0.0e+00 0.0e+00 0.00 3095 2715 1.0e+00
mu 8.0e+00 9.4e-02 5.1 -9.5e-02 7.9e+00 16 2970 2605 1.0e+00
tau 6.7e+00 1.0e-01 5.5 5.3e-01 5.5e+00 17 2873 2520 1.0e+00
eta[1] 4.0e-01 9.1e-03 0.93 -1.2e+00 4.1e-01 1.9 10496 9206 1.0e+00
eta[2] -2.7e-03 8.9e-03 0.87 -1.4e+00 -7.3e-03 1.4 9545 8372 1.0e+00
eta[3] -2.0e-01 9.0e-03 0.93 -1.7e+00 -2.2e-01 1.4 10578 9279 1.0e+00
eta[4] -3.2e-02 8.4e-03 0.89 -1.5e+00 -3.3e-02 1.4 11090 9727 1.0e+00
eta[5] -3.5e-01 8.8e-03 0.87 -1.8e+00 -3.7e-01 1.1 9694 8503 1.0e+00
eta[6] -2.2e-01 9.1e-03 0.90 -1.6e+00 -2.4e-01 1.3 9764 8564 1.0e+00
eta[7] 3.5e-01 8.7e-03 0.87 -1.1e+00 3.7e-01 1.7 10018 8787 1.0e+00
eta[8] 5.5e-02 8.9e-03 0.93 -1.5e+00 5.6e-02 1.6 10972 9624 1.0e+00
theta[1] 1.1e+01 1.0e-01 8.3 2.2e-01 1.0e+01 27 6320 5544 1.0e+00
theta[2] 7.9e+00 6.3e-02 6.3 -2.2e+00 7.8e+00 18 10076 8838 1.0e+00
theta[3] 6.1e+00 8.9e-02 7.8 -7.3e+00 6.6e+00 18 7582 6651 1.0e+00
theta[4] 7.7e+00 6.6e-02 6.6 -3.2e+00 7.7e+00 19 9987 8760 1.0e+00
theta[5] 5.1e+00 6.4e-02 6.4 -6.2e+00 5.6e+00 15 10122 8878 1.0e+00
theta[6] 6.1e+00 6.9e-02 6.8 -5.7e+00 6.4e+00 17 9765 8566 1.0e+00
theta[7] 1.1e+01 7.7e-02 6.8 8.4e-01 1.0e+01 23 7782 6826 1.0e+00
theta[8] 8.6e+00 1.0e-01 8.2 -3.9e+00 8.3e+00 22 6297 5523 1.0e+00
Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at
convergence, R_hat=1).
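To make the R_hat footnote concrete, here is a minimal sketch of a potential scale reduction factor on split chains, following the textbook formula. It is not the code CmdStan runs, and the draws are synthetic random numbers rather than Stan output.

```matlab
% Synthetic draws: n iterations by m chains (not real Stan output).
n = 1000; m = 4;
draws = randn(n, m);

% Split each chain in half, giving 2*m chains of length 'half'.
half  = floor(n / 2);
split = [draws(1:half, :), draws(half+1:2*half, :)];

W = mean(var(split, 0, 1));        % mean within-chain variance
B = half * var(mean(split, 1));    % between-chain variance of chain means
var_plus = (half - 1) / half * W + B / half;   % pooled variance estimate
R_hat = sqrt(var_plus / W);        % near 1 when the chains agree
```

If one chain is shifted away from the others, B grows relative to W and R_hat rises above 1. This is why the very large values reported for stepsize__ in the table are not a concern: the step size is fixed within each chain after warmup, so its within-chain variance is essentially zero.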
The extract method returns a struct or struct array for parameters of interest:
% return a struct with all parameters when none specifically requested
la = fit.extract('permuted',true);
mu = la.mu
% return an array with requested parameter
mu2 = fit.extract('pars','mu').mu;
% returns individual chains (each array element is a chain)
a = fit.extract('permuted',false);
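As an illustration of working with per-chain output, the sketch below builds a hypothetical struct array shaped like the one described in the last comment (one element per chain, one field per parameter) and summarises it. The values are made up, and the exact layout of the array returned by extract is an assumption.

```matlab
% Hypothetical stand-in for per-chain output: a 1-by-4 struct array,
% one element per chain, each holding that chain's draws of mu.
chains = struct('mu', {[1; 2; 3], [2; 3; 4], [0; 1; 2], [3; 4; 5]});

% Mean of mu within each chain, useful for spotting a chain that disagrees.
chain_means = arrayfun(@(c) mean(c.mu), chains);   % 1-by-4

% Stack all chains into one column vector for pooled summaries.
pooled = vertcat(chains.mu);                       % 12-by-1
pooled_mean = mean(pooled);
```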
Trace plotting is fairly basic at the moment:
fit.traceplot;
This is a classic hierarchical normal model; a description and the corresponding BUGS model can be found here. The Stan model is available as rats.stan. Fitting the model
y = [151, 145, 147, 155, 135, 159, 141, 159, 177, 134, ...
160, 143, 154, 171, 163, 160, 142, 156, 157, 152, 154, 139, 146, ...
157, 132, 160, 169, 157, 137, 153, 199, 199, 214, 200, 188, 210, ...
189, 201, 236, 182, 208, 188, 200, 221, 216, 207, 187, 203, 212, ...
203, 205, 190, 191, 211, 185, 207, 216, 205, 180, 200, 246, 249, ...
263, 237, 230, 252, 231, 248, 285, 220, 261, 220, 244, 270, 242, ...
248, 234, 243, 259, 246, 253, 225, 229, 250, 237, 257, 261, 248, ...
219, 244, 283, 293, 312, 272, 280, 298, 275, 297, 350, 260, 313, ...
273, 289, 326, 281, 288, 280, 283, 307, 286, 298, 267, 272, 285, ...
286, 303, 295, 289, 258, 286, 320, 354, 328, 297, 323, 331, 305, ...
338, 376, 296, 352, 314, 325, 358, 312, 324, 316, 317, 336, 321, ...
334, 302, 302, 323, 331, 345, 333, 316, 291, 324];
y = reshape(y,30,5);
x = [8 15 22 29 36];
rats_dat = struct('N',size(y,1),'TT',size(y,2),'x',x,'y',y,'xbar',mean(x));
rats_fit = stan('file','rats.stan','data',rats_dat,'verbose',true);
print(rats_fit);
should produce output that looks something like this (compare to the RStan run here):
Inference for Stan model: rats_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.
Warmup took (5.0, 6.6, 0.86, 0.67) seconds, 13 seconds total
Sampling took (0.28, 1.3, 0.38, 0.41) seconds, 2.4 seconds total
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
lp__ -439 2.3e-01 7.0 -451 -438 -428 930 389 1.0e+00
accept_stat__ 0.81 4.1e-02 0.18 0.45 0.86 1.0 19 8.0 1.1e+00
stepsize__ 0.43 1.4e-01 0.20 0.078 0.54 0.58 2.0 0.84 4.7e+14
treedepth__ 2.5 1.6e-02 0.99 2.0 2.0 5.0 4000 1672 4.8e+00
n_divergent__ 0.00 0.0e+00 0.00 0.00 0.00 0.00 4000 1672 nan
alpha[1] 240 4.3e-02 2.7 235 240 244 4000 1672 1.0e+00
alpha[2] 248 4.3e-02 2.7 243 248 252 4000 1672 1.0e+00
alpha[3] 252 4.2e-02 2.7 248 252 257 4000 1672 1.0e+00
alpha[4] 233 4.3e-02 2.7 228 233 237 4000 1672 1.0e+00
alpha[5] 232 4.2e-02 2.6 227 232 236 4000 1672 1.0e+00
alpha[6] 250 4.2e-02 2.6 245 250 254 4000 1672 1.0e+00
alpha[7] 229 4.2e-02 2.7 224 229 233 4000 1672 1.0e+00
alpha[8] 248 4.2e-02 2.7 244 248 253 4000 1672 1.0e+00
alpha[9] 283 4.4e-02 2.8 279 283 288 4000 1672 1.0e+00
alpha[10] 219 4.5e-02 2.8 215 219 224 4000 1672 1.0e+00
alpha[11] 258 4.1e-02 2.6 254 258 263 4000 1672 1.0e+00
alpha[12] 228 4.3e-02 2.7 224 228 233 4000 1672 1.0e+00
alpha[13] 242 4.4e-02 2.8 238 242 247 4000 1672 1.0e+00
alpha[14] 268 4.3e-02 2.7 264 268 273 4000 1672 1.0e+00
alpha[15] 243 4.2e-02 2.7 238 243 247 4000 1672 1.0e+00
alpha[16] 245 4.2e-02 2.7 241 245 250 4000 1672 1.0e+00
alpha[17] 232 4.2e-02 2.7 228 232 236 4000 1672 1.0e+00
alpha[18] 240 4.2e-02 2.6 236 240 245 4000 1672 1.0e+00
alpha[19] 254 4.2e-02 2.6 249 254 258 4000 1672 1.0e+00
alpha[20] 242 4.2e-02 2.7 237 242 246 4000 1672 1.0e+00
alpha[21] 249 4.2e-02 2.6 244 249 253 4000 1672 1.0e+00
alpha[22] 225 4.2e-02 2.6 221 225 230 4000 1672 1.0e+00
alpha[23] 228 4.3e-02 2.7 224 229 233 4000 1672 1.0e+00
alpha[24] 245 4.2e-02 2.6 241 245 249 4000 1672 1.0e+00
alpha[25] 235 4.3e-02 2.7 230 235 239 4000 1672 1.0e+00
alpha[26] 254 4.3e-02 2.7 249 254 258 4000 1672 1.0e+00
alpha[27] 254 4.2e-02 2.7 250 254 259 4000 1672 1.0e+00
alpha[28] 243 4.2e-02 2.7 238 243 247 4000 1672 1.0e+00
alpha[29] 218 4.4e-02 2.8 213 218 222 4000 1672 1.0e+00
alpha[30] 241 4.2e-02 2.7 237 241 246 4000 1672 1.0e+00
beta[1] 6.1 3.8e-03 0.24 5.7 6.1 6.5 4000 1672 1.0e+00
beta[2] 7.1 4.1e-03 0.26 6.6 7.1 7.5 4000 1672 1.0e+00
beta[3] 6.5 3.9e-03 0.25 6.1 6.5 6.9 4000 1672 1.0e+00
beta[4] 5.3 4.2e-03 0.26 4.9 5.3 5.8 4000 1672 1.0e+00
beta[5] 6.6 3.8e-03 0.24 6.2 6.6 7.0 4000 1672 1.0e+00
beta[6] 6.2 3.7e-03 0.23 5.8 6.2 6.6 4000 1672 1.0e+00
beta[7] 6.0 3.9e-03 0.25 5.6 6.0 6.4 4000 1672 1.0e+00
beta[8] 6.4 3.9e-03 0.24 6.0 6.4 6.8 4000 1672 1.0e+00
beta[9] 7.1 4.1e-03 0.26 6.6 7.1 7.5 4000 1672 1.0e+00
beta[10] 5.8 3.8e-03 0.24 5.5 5.8 6.2 4000 1672 1.0e+00
beta[11] 6.8 4.0e-03 0.25 6.4 6.8 7.2 4000 1672 1.0e+00
beta[12] 6.1 3.9e-03 0.25 5.7 6.1 6.5 4000 1672 1.0e+00
beta[13] 6.2 3.8e-03 0.24 5.8 6.2 6.6 4000 1672 1.0e+00
beta[14] 6.7 3.9e-03 0.25 6.3 6.7 7.1 4000 1672 1.0e+00
beta[15] 5.4 4.0e-03 0.25 5.0 5.4 5.8 4000 1672 1.0e+00
beta[16] 5.9 4.0e-03 0.26 5.5 5.9 6.3 4000 1672 1.0e+00
beta[17] 6.3 3.8e-03 0.24 5.9 6.3 6.7 4000 1672 1.0e+00
beta[18] 5.8 3.9e-03 0.25 5.4 5.8 6.2 4000 1672 1.0e+00
beta[19] 6.4 3.9e-03 0.25 6.0 6.4 6.8 4000 1672 1.0e+00
beta[20] 6.1 3.9e-03 0.25 5.6 6.1 6.5 4000 1672 1.0e+00
beta[21] 6.4 3.8e-03 0.24 6.0 6.4 6.8 4000 1672 1.0e+00
beta[22] 5.9 3.9e-03 0.25 5.5 5.9 6.3 4000 1672 1.0e+00
beta[23] 5.7 4.0e-03 0.25 5.3 5.7 6.2 4000 1672 1.0e+00
beta[24] 5.9 3.9e-03 0.25 5.5 5.9 6.3 4000 1672 1.0e+00
beta[25] 6.9 4.1e-03 0.26 6.5 6.9 7.3 4000 1672 1.0e+00
beta[26] 6.5 3.8e-03 0.24 6.2 6.5 6.9 4000 1672 1.0e+00
beta[27] 5.9 3.9e-03 0.25 5.5 5.9 6.3 4000 1672 1.0e+00
beta[28] 5.8 3.9e-03 0.25 5.4 5.8 6.3 4000 1672 1.0e+00
beta[29] 5.7 3.8e-03 0.24 5.3 5.7 6.1 4000 1672 1.0e+00
beta[30] 6.1 3.9e-03 0.25 5.7 6.1 6.5 4000 1672 1.0e+00
mu_alpha 242 4.4e-02 2.8 238 242 247 4000 1672 1.0e+00
mu_beta 6.2 1.7e-03 0.11 6.0 6.2 6.4 4000 1672 1.0e+00
sigmasq_y 37 1.3e-01 5.8 29 37 48 2125 888 1.0e+00
sigmasq_alpha 219 1.0e+00 65 137 207 342 4000 1672 1.0e+00
sigmasq_beta 0.28 1.9e-03 0.10 0.15 0.26 0.47 2995 1252 1.0e+00
sigma_y 6.1 1.0e-02 0.47 5.4 6.1 6.9 2124 888 1.0e+00
sigma_alpha 15 3.3e-02 2.1 12 14 18 4000 1672 1.0e+00
sigma_beta 0.52 1.7e-03 0.093 0.38 0.51 0.68 2838 1187 1.0e+00
alpha0 106 5.8e-02 3.6 100 106 112 4000 1672 1.0e+00
Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at
convergence, R_hat=1).
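A note on the data layout in the rats example: reshape fills its output column by column, so the first 30 values of y become the first column (one weight per rat for the first measurement day). The toy sketch below uses made-up numbers for 3 rats and 2 days to show the same layout. That xbar is used to centre the covariate is an assumption based on xbar being passed as data.

```matlab
% Toy data: 3 rats measured on 2 days, listed day by day (made-up values).
v = [10 20 30, 11 21 31];
y_toy = reshape(v, 3, 2);   % rows are rats, columns are measurement days

x_toy = [8 15];             % measurement days
xbar  = mean(x_toy);        % mean day, passed to the model as xbar

rat1 = y_toy(1, :);         % all measurements for the first rat
data = struct('N', size(y_toy, 1), 'TT', size(y_toy, 2), ...
              'x', x_toy, 'y', y_toy, 'xbar', xbar);
```

If the vector were instead listed rat by rat, reshape(v, 2, 3)' would be needed to obtain the same matrix.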