This is a MATLAB package to fit generalised additive models (GAM) to neural data. If your primary goal is to fit a generalised linear model (GLM), consider using dedicated packages available elsewhere (e.g. this or that).
A generalised additive model (GAM) is a framework for simultaneously estimating the relationship between a response variable and multiple predictor variables. Unlike generalised linear models (GLMs), GAMs can account for arbitrary nonlinear relationships between predictors and response. Since the relationship between stimulus variables and spikes is often nonlinear and nonmonotonic, GAMs are an attractive option for modeling neural responses. In neuroscience experiments, the neural response may be influenced by external inputs, which may be either continuous-valued signals (e.g. velocity, position) or binary events (e.g. stimulus onset, reward), as well as by the neuron's own past response (spike history). In its most general form, the model can be written as:
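$$g(\mu) = \sum_{i=1}^{N} f_i(x_i) + \sum_{j=1}^{M} h_j * e_j + h_y * y$$

where $x_i$ is the $i$-th continuous-valued input variable, $f_i$ is any generic (nonlinear) function operating on $x_i$, $e_j$ is the $j$-th binary event, $h_j$ is the temporal filter operating on $e_j$ (with $*$ denoting convolution in time), $h_y$ is the spike-history kernel, $g$ is the link function, $\mu$ is the expectation value of the response $y$, and $N$ & $M$ denote the total number of continuous-valued inputs and binary events respectively.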
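Given response $y$ and inputs $x_i$ & $e_j$, the goal is to recover $f_i$, $h_j$, and $h_y$ under some assumed link function $g$. This can be solved by computing the maximum a posteriori (MAP) estimates as:

$$\hat{\theta} = \arg\max_{\theta} \; P(y \mid x, e, \theta) \, P(\theta)$$

where $\theta$ denotes the model parameters ($f_i$, $h_j$, and $h_y$), $P(y \mid x, e, \theta)$ is the model likelihood and $P(\theta)$ is the prior over the model parameters.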
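We begin by discretising the value of each input variable $x_i$ into $n_i$ bins. We recode the value of $x_i$ at each time point $t$ as a one-hot binary string where the $k$-th bit is:

$$x_{i,k}(t) = \begin{cases} 1 & \text{if } x_i(t) \text{ falls in bin } k \\ 0 & \text{otherwise} \end{cases}$$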
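
The one-hot recoding described above can be sketched in a few lines of MATLAB. This is only an illustration, not the package's internal code: the variable names and the toy data are hypothetical, and the convention that a sample on the upper bound goes into the last bin is an assumption.

```matlab
% Hypothetical example: one-hot encode a 1D input variable.
x     = [0.05; 0.30; 0.55; 0.99; 1.00];  % T x 1 samples of one input variable
nbins = 4;                               % number of bins (cf. prs.nbins)
lims  = [0; 1];                          % lower & upper bounds (cf. prs.binrange)

% Map each sample to a bin index in 1..nbins.
% Assumption: a sample equal to the upper bound is placed in the last bin.
k = floor((x - lims(1)) ./ (lims(2) - lims(1)) * nbins) + 1;
k = min(max(k, 1), nbins);

% Build the T x nbins one-hot matrix: row t has a single 1 in column k(t).
T = numel(x);
X = zeros(T, nbins);
X(sub2ind([T, nbins], (1:T)', k)) = 1;
```

Each row of `X` sums to 1, so multiplying `X` by a vector of per-bin values simply looks up the value for the bin occupied at each time point.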
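Let $f_{i,k}$ denote the value of $f_i$ when $x_i$ falls in bin $k$. If $f_i$ is known to vary smoothly, we can write the prior $P(f_i)$ as:

$$P(f_i) \propto \exp\left(-\lambda_i \sum_{k=1}^{n_i - 1} \left(f_{i,k+1} - f_{i,k}\right)^2\right)$$

where we have assumed a factorisable gaussian prior for simplicity, and $\lambda_i$ is a hyperparameter capturing the degree of smoothness of $f_i$. We solve for $f$ by maximising:

$$\log P(y \mid x, e, f) - \sum_{i} \lambda_i \sum_{k=1}^{n_i - 1} \left(f_{i,k+1} - f_{i,k}\right)^2$$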
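
A common way to encode a gaussian smoothness prior of this kind is to penalise squared differences between the values in neighbouring bins. The sketch below assumes that form; the exact scaling used inside the package may differ, and the values of `f` and `lambda` are hypothetical.

```matlab
% Hypothetical per-bin tuning values and smoothness hyperparameter.
f      = [0; 1; 3; 3];   % value of the tuning function in each of 4 bins
lambda = 10;             % smoothness hyperparameter (cf. prs.lambda)

% First-difference operator: (D * f)(k) = f(k+1) - f(k).
D = diff(eye(numel(f)));

% Assumed penalty: lambda * sum of squared neighbouring differences.
penalty = lambda * sum((D * f).^2);

% Gradient of the penalty with respect to f, as needed by a gradient-based optimiser.
grad = 2 * lambda * (D' * (D * f));
```

A flat tuning function incurs no penalty under this form, and larger values of `lambda` pull neighbouring bins towards each other more strongly.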
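Once we determine the best-fit model parameters $\hat{f}$, we can construct the marginal 'tuning' to each individual input variable $x_i$ by computing the conditional mean of $y$ given $x_i$, i.e. by averaging the model prediction over the remaining input variables $x_j$ ($j \neq i$), with each bin $l$ weighted by the probability $P(x_j \in \text{bin } l)$ that $x_j$ falls in it.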
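
As a rough illustration of this marginalisation, the sketch below computes the tuning to one variable in a toy model with two 1D inputs. It rests on two assumptions that are not taken from the package: a `'log'` link, and statistical independence between the two inputs (so that the average over the second variable factorises). All names and numbers are hypothetical.

```matlab
% Hypothetical fitted parameters for two 1D variables under a 'log' link.
f1 = log([1; 2]);   % fitted values for variable 1 (2 bins)
f2 = log([2; 4]);   % fitted values for variable 2 (2 bins)
P2 = [0.5; 0.5];    % probability that variable 2 falls in each of its bins

% Marginal tuning to variable 1: average the model prediction exp(f1 + f2)
% over the bins of variable 2, weighted by P2.
% Assumption: the variables are independent, so the average factorises.
tuning1 = exp(f1) * sum(P2 .* exp(f2));
```

In practice `P2` could be estimated as the fraction of time points spent in each bin, for example the column means of a one-hot matrix for that variable.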
To fit the model using your data, use the function `BuildGAM`. `BuildGAM` takes three inputs: `xt`, `yt`, and `prs`.
`xt` must be an n x 1 cell array containing the values of the input variables (continuous-valued and/or binary events), where n is the total number of input variables. Each cell in `xt` corresponds to one input variable. If the i-th variable is one-dimensional, `xt{i}` must be a T x 1 vector, T being the total number of observations. If it is two-dimensional (e.g. position), `xt{i}` must be a T x 2 array, with the two columns corresponding to the two dimensions. If it is a binary event, `xt{i}` must be a T x 1 vector whose elements are equal to `1` at all the bins in which that event occurred and `0` everywhere else. To fit a spike-history kernel, make one of the cells in `xt` equal to the response `yt` (see below); in this case n will be the total number of input variables plus 1.
`yt` must be a T x 1 array of spike counts. It is advisable to record your observations at a sampling rate of at least 50 Hz so that `yt` consists mostly of 0s and 1s.
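
The sketch below shows one way to assemble `xt` and `yt` in the shapes described above. The raw signals, event times, and spike times are made up for illustration, and the helper variable names (`speed`, `position`, `rewardAt`, `spikeAt`) are hypothetical.

```matlab
% Hypothetical raw data: 10 s sampled at 50 Hz.
dt = 0.02;               % sampling interval in seconds
T  = 500;                % total number of observations
t  = (0:T-1)' * dt;      % T x 1 time stamps

speed    = abs(sin(2*pi*0.1*t));                    % 1D variable, T x 1
position = [cos(2*pi*0.05*t), sin(2*pi*0.05*t)];    % 2D variable, T x 2
rewardAt = [2.01; 6.51];                            % event times in seconds
spikeAt  = [0.013; 1.27; 1.29; 4.401; 9.99];        % spike times in seconds

% Binary event vector: 1 in the bins where the event occurred, 0 elsewhere.
reward = zeros(T, 1);
reward(floor(rewardAt / dt) + 1) = 1;

% Spike counts per bin (T x 1).
yt = accumarray(floor(spikeAt / dt) + 1, 1, [T, 1]);

% n x 1 cell array of inputs; the last cell is yt itself,
% which requests a spike-history kernel.
xt = {speed; position; reward; yt};
```

Here n is 4: three input variables plus the extra cell holding `yt` for the spike-history term.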
`prs` is a structure specifying analysis parameters. The fields of this structure must be created in the following order:

| Field | Description |
| --- | --- |
| `prs.varname` | 1 x n cell array of names of the input variables (only used for labeling plots) |
| `prs.vartype` | 1 x n cell array of types (`'1D'`, `'1Dcirc'`, `'2D'` or `'event'`) of the input variables |
| `prs.nbins` | 1 x n cell array of number of bins to discretise input (for continuous-valued) or the number of bins to discretise the temporal filter (for binary events) |
| `prs.binrange` | 1 x n cell array of 2 x 1 vectors specifying lower & upper bounds of the input (for continuous-valued) or lower & upper bounds of the desired time window around the input (for binary events) |
| `prs.nfolds` | Number of folds for cross-validation |
| `prs.dt` | Time (in secs) between consecutive observation samples (1/samplingfrequency) |
| `prs.filtwidth` | Width of gaussian filter (in samples) to smooth spike train |
| `prs.linkfunc` | Choice of link function (`'log'`, `'identity'` or `'logit'`) |
| `prs.lambda` | 1 x n cell array of hyper-parameters for imposing smoothness prior on tuning functions |
| `prs.alpha` | Significance level for comparing likelihood values |
| `prs.varchoose` | 1 x n array of ones and zeros indicating the inclusion status of each variable. Use `1` to forcibly include a variable in the bestmodel, `0` to let the method determine whether to include a variable |
| `prs.method` | Method for selecting the best model: `'Forward'`, `'Backward'`, `'FastForward'` or `'FastBackward'` |
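
As an illustration, a `prs` structure for three inputs (a 1D variable, a circular variable, and a binary event) might be created as follows, with the fields assigned in the order listed above. Every value here is a hypothetical placeholder rather than a recommended setting, and the units of the event time window are assumed to be seconds.

```matlab
% Hypothetical settings for n = 3 inputs; fields are created in the documented order.
prs = struct();
prs.varname   = {'speed', 'heading', 'reward'};       % labels for plots
prs.vartype   = {'1D', '1Dcirc', 'event'};            % one type per input
prs.nbins     = {10, 12, 20};                         % bins per input / per temporal filter
prs.binrange  = {[0; 50], [-180; 180], [-0.5; 0.5]};  % 2 x 1 bounds; event window assumed in seconds
prs.nfolds    = 5;                                    % cross-validation folds
prs.dt        = 0.02;                                 % 1 / (50 Hz sampling frequency)
prs.filtwidth = 3;                                    % gaussian filter width in samples
prs.linkfunc  = 'log';                                % 'log', 'identity' or 'logit'
prs.lambda    = {10, 10, 10};                         % smoothness hyper-parameters
prs.alpha     = 0.05;                                 % significance level
prs.varchoose = [0, 0, 0];                            % let the method decide for every variable
prs.method    = 'Forward';                            % model selection method
```

Starting from an empty `struct()` and assigning fields one at a time keeps them in creation order, which can be checked with `fieldnames(prs)`.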
For more details about the role of these parameters, use `help BuildGAM` in MATLAB. Once you have `xt`, `yt`, and `prs`, you can fit the model by running the following command:
models = BuildGAM(xt,yt,prs); % the output is saved in the variable called models
And then use this command to plot the results:
PlotGAM(models,prs); % plot model likelihoods and marginal tuning functions
Although meant for neural data, you can use this code to model any point process. Check out `demo.m` for examples. Write to me if you have questions.
This implementation builds on the LNP model described in Hardcastle et al., 2017, available here for MEC neurons.