diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0ee611d --- /dev/null +++ b/.gitignore @@ -0,0 +1,42 @@ +# History files +.Rhistory +.Rapp.history + +# Session Data files +.RData + +# User-specific files +.Ruserdata + +# Example code in package build process +*-Ex.R + +# Output files from R CMD build +/*.tar.gz + +# Output files from R CMD check +/*.Rcheck/ + +# RStudio files +.Rproj.user/ + +# produced vignettes +vignettes/*.html +vignettes/*.pdf + +# OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 +.httr-oauth + +# knitr and R markdown default cache directories +*_cache/ +/cache/ + +# Temporary files created by R markdown +*.utf8.md +*.knit.md + +# Results +Models/*.rds +Results/* +Plots/* +log.txt \ No newline at end of file diff --git a/Models/AR.stan b/Models/AR.stan new file mode 100644 index 0000000..2d5ebb2 --- /dev/null +++ b/Models/AR.stan @@ -0,0 +1,111 @@ +// Autoregressive model of order 1 + +functions { + real normal_lb_ub_rng(real mu, real sigma, real lb, real ub) { + // Sample from truncated normal of mean mu, standard deviation sigma, lower bound lb and upper bound ub + real u; + real p1; + real p2; + if (is_nan(mu) || is_inf(mu)) { + reject("normal_lb_ub_rng: mu must be finite; ", "found mu = ", mu); + } else if (is_nan(sigma) || is_inf(sigma) || sigma < 0) { + reject("normal_lb_ub_rng: sigma must be finite and non-negative; ", "found sigma = ", sigma); + } else if (lb >= ub) { + reject("normal_lb_ub_rng: lb must be less than ub; ", "found lb = ", lb, "and ub = ", ub); + } else { + p1 = normal_cdf(lb, mu, sigma); + p2 = normal_cdf(ub, mu, sigma); + if (p1 >= p2) { + reject("normal_lb_ub_rng: p1 >= p2. p1 = ", p1, " and p2 = ", p2, ". mu = ", mu, "and sigma = ", sigma); + } else { + u = uniform_rng(p1, p2); + } + } + return (sigma * inv_Phi(u)) + mu; // inverse cdf + } +} + +data { + int N_obs; // Number of non-missing observations + int N_mis; // Number of missing observations + int N_pt; // Number of patients + real max_score; // Maximum value that the score can take + int idx_obs[N_obs]; // Index of non-missing observations + int idx_mis[N_mis]; // Index of missing observations + real S_obs[N_obs]; // Observed score + + int N_test; // Number of predictions to evaluate (predictions are set to missing in the data) + int idx_test[N_test]; // Index of test observations + real S_test[N_test]; // Observed score for predictions + + int run; // Switch to evaluate the likelihood + int rep; // Switch to generate replications +} + +transformed data { + int N = N_obs + N_mis; // Total number of observations + int mult = N / N_pt; // Number of timepoints per patient (same by construction) + int start[N_pt]; // index of first observation for patient each patient + int end[N_pt]; // index of last observation for patient each patient + + for (k in 1:N_pt) { + start[k] = (k - 1) * mult + 1; + end[k] = k * mult; + } +} + +parameters { + real S_mis[N_mis]; // Missing S (useful for predictions (cf. bounds)) + real sigma; // Standard deviation + real alpha; // Autocorrelation parameter + real S_inf; // Autoregression mean +} + +transformed parameters { + real S_meas[N]; // Measured S + real b = S_inf * (1 - alpha); // Intercept + + S_meas[idx_obs] = S_obs; + S_meas[idx_mis] = S_mis; +} + +model { + // implicit (bounded) uniform prior for S_mis + // implicit uniform prior for alpha + S_inf / max_score ~ normal(0.5, 0.25); + sigma / max_score ~ lognormal(-log(20), 0.5 * log(5)); // 95% CI is [.01, 0.25] * max_score + + for (k in 1:N_pt) { + to_vector(S_meas[(start[k] + 1):end[k]]) ~ normal(alpha * to_vector(S_meas[start[k]:(end[k] - 1)]) + b, sigma); // Vectorise for efficiency + } + +} + +generated quantities { + real S_rep[N]; // Replications of S_meas[t + 1] given S_lat[t] + real S_pred[N_test]; // Predictive sample of S_test + real lpd[N_test]; // Log predictive density + + // Replications + if (rep == 1) { + for (k in 1:N_pt) { + S_rep[start[k]] = S_meas[start[k]]; + for (t in (start[k] + 1):end[k]) { + S_rep[t] = normal_lb_ub_rng(alpha * S_meas[t - 1] + b, sigma, 0, max_score); + } + } + S_pred = S_rep[idx_test]; + } + + // Log predictive density + { + real linpred; // Linear predictor + real Z; // Normalisation constant + for (i in 1:N_test) { + linpred = alpha * S_meas[idx_test[i]] + b; + Z = normal_cdf(max_score, linpred, sigma) - normal_cdf(0, linpred, sigma); + lpd[i] = normal_lpdf(S_test[i] | linpred, sigma) - log(Z); + } + } + +} diff --git a/Models/MixedAR.stan b/Models/MixedAR.stan new file mode 100644 index 0000000..2a10d2b --- /dev/null +++ b/Models/MixedAR.stan @@ -0,0 +1,133 @@ +// Autoregressive model of order 1 with patient-dependent parameters + +functions { + real normal_lb_ub_rng(real mu, real sigma, real lb, real ub) { + // Sample from truncated normal of mean mu, standard deviation sigma, lower bound lb and upper bound ub + real u; + real p1; + real p2; + if (is_nan(mu) || is_inf(mu)) { + reject("normal_lb_ub_rng: mu must be finite; ", "found mu = ", mu); + } else if (is_nan(sigma) || is_inf(sigma) || sigma < 0) { + reject("normal_lb_ub_rng: sigma must be finite and non-negative; ", "found sigma = ", sigma); + } else if (lb >= ub) { + reject("normal_lb_ub_rng: lb must be less than ub; ", "found lb = ", lb, "and ub = ", ub); + } else { + p1 = normal_cdf(lb, mu, sigma); + p2 = normal_cdf(ub, mu, sigma); + if (p1 >= p2) { + reject("normal_lb_ub_rng: p1 >= p2. p1 = ", p1, " and p2 = ", p2, ". mu = ", mu, "and sigma = ", sigma); + } else { + u = uniform_rng(p1, p2); + } + } + return (sigma * inv_Phi(u)) + mu; // inverse cdf + } +} + +data { + int N_obs; // Number of non-missing observations + int N_mis; // Number of missing observations + int N_pt; // Number of patients + real max_score; // Maximum value that the score can take + int idx_obs[N_obs]; // Index of non-missing observations + int idx_mis[N_mis]; // Index of missing observations + real S_obs[N_obs]; // Observed score + + int N_test; // Number of predictions to evaluate (predictions are set to missing in the data) + int idx_test[N_test]; // Index of test observations + real S_test[N_test]; // Observed score for predictions + + int run; // Switch to evaluate the likelihood + int rep; // Switch to generate replications +} + +transformed data { + int N = N_obs + N_mis; // Total number of observations + int mult = N / N_pt; // Number of timepoints per patient (same by construction) + int start[N_pt]; // index of first observation for patient each patient + int end[N_pt]; // index of last observation for patient each patient + + for (k in 1:N_pt) { + start[k] = (k - 1) * mult + 1; + end[k] = k * mult; + } +} + +parameters { + real S_mis[N_mis]; // Missing S (useful for predictions (cf. bounds)) + real sigma; // Standard deviation + + // Autocorrelation parameter + real alpha[N_pt]; // Autocorrelation parameter + real mu_alpha; // Population mean of alpha + real phi_alpha; // Population pseudo "sample size" of alpha + + // Population autoregression mean (for S_lat) + real mu_inf; // Population mean + real sigma_inf; // Population std + real eta_inf[N_pt]; // Error term +} + +transformed parameters { + real S_meas[N]; // Measured S + real S_inf[N_pt]; + real b[N_pt]; + + for (k in 1:N_pt) { + S_inf[k] = mu_inf + sigma_inf * eta_inf[k]; + b[k] = S_inf[k] * (1 - alpha[k]); + } + + S_meas[idx_obs] = S_obs; + S_meas[idx_mis] = S_mis; +} + +model { + // implicit (bounded) uniform prior for S_mis + sigma / max_score ~ lognormal(-log(20), 0.5 * log(5)); // 95% CI is [.01, 0.25] * max_score + + eta_inf ~ std_normal(); + mu_inf / max_score ~ normal(0.5, 0.25); + sigma_inf / max_score ~ normal(0, 0.125); + + mu_alpha ~ beta(2, 2); + phi_alpha ~ lognormal(1 * log(10), 0.5 * log(10)); // mass between 1 and 100 + alpha ~ beta(mu_alpha * phi_alpha, (1 - mu_alpha) * phi_alpha); + + for (k in 1:N_pt) { + to_vector(S_meas[(start[k] + 1):end[k]]) ~ normal(alpha[k] * to_vector(S_meas[start[k]:(end[k] - 1)]) + b[k], sigma); // Vectorise for efficiency + } + +} + +generated quantities { + real S_rep[N]; // Replications of S_meas[t + 1] given S_lat[t] + real S_pred[N_test]; // Predictive sample of S_test + real lpd[N_test]; // Log predictive density + + // Replications + if (rep == 1) { + for (k in 1:N_pt) { + S_rep[start[k]] = S_meas[start[k]]; + for (t in (start[k] + 1):end[k]) { + S_rep[t] = normal_lb_ub_rng(alpha[k] * S_meas[t - 1] + b[k], sigma, 0, max_score); + } + } + S_pred = S_rep[idx_test]; + } + + // Log predictive density + { + int k_test; // Patient index + real linpred; // Linear predictor + real Z; // Normalisation constant + for (i in 1:N_test) { + k_test = ((idx_test[i] - 1) % mult) + 1; + linpred = alpha[k_test] * S_meas[idx_test[i]] + b[k_test]; + Z = normal_cdf(max_score, linpred, sigma) - normal_cdf(0, linpred, sigma); + lpd[i] = normal_lpdf(S_test[i] | linpred, sigma) - log(Z); + } + } + +} diff --git a/Models/RW.stan b/Models/RW.stan new file mode 100644 index 0000000..5746982 --- /dev/null +++ b/Models/RW.stan @@ -0,0 +1,103 @@ +// Random Walk model + +functions { + real normal_lb_ub_rng(real mu, real sigma, real lb, real ub) { + // Sample from truncated normal of mean mu, standard deviation sigma, lower bound lb and upper bound ub + real u; + real p1; + real p2; + if (is_nan(mu) || is_inf(mu)) { + reject("normal_lb_ub_rng: mu must be finite; ", "found mu = ", mu); + } else if (is_nan(sigma) || is_inf(sigma) || sigma < 0) { + reject("normal_lb_ub_rng: sigma must be finite and non-negative; ", "found sigma = ", sigma); + } else if (lb >= ub) { + reject("normal_lb_ub_rng: lb must be less than ub; ", "found lb = ", lb, "and ub = ", ub); + } else { + p1 = normal_cdf(lb, mu, sigma); + p2 = normal_cdf(ub, mu, sigma); + if (p1 >= p2) { + reject("normal_lb_ub_rng: p1 >= p2. p1 = ", p1, " and p2 = ", p2, ". mu = ", mu, "and sigma = ", sigma); + } else { + u = uniform_rng(p1, p2); + } + } + return (sigma * inv_Phi(u)) + mu; // inverse cdf + } +} + +data { + int N_obs; // Number of non-missing observations + int N_mis; // Number of missing observations + int N_pt; // Number of patients + real max_score; // Maximum value that the score can take + int idx_obs[N_obs]; // Index of non-missing observations + int idx_mis[N_mis]; // Index of missing observations + real S_obs[N_obs]; // Observed score + + int N_test; // Number of predictions to evaluate (predictions are set to missing in the data) + int idx_test[N_test]; // Index of test observations + real S_test[N_test]; // Observed score for predictions + + int run; // Switch to evaluate the likelihood + int rep; // Switch to generate replications +} + +transformed data { + int N = N_obs + N_mis; // Total number of observations + int mult = N / N_pt; // Number of timepoints per patient (same by construction) + int start[N_pt]; // index of first observation for patient each patient + int end[N_pt]; // index of last observation for patient each patient + + for (k in 1:N_pt) { + start[k] = (k - 1) * mult + 1; + end[k] = k * mult; + } +} + +parameters { + real S_mis[N_mis]; // Missing S (useful for predictions (cf. bounds)) + real sigma; // Standard deviation +} + +transformed parameters { + real S_meas[N]; // Measured S + S_meas[idx_obs] = S_obs; + S_meas[idx_mis] = S_mis; +} + +model { + // implicit (bounded) uniform prior for S_mis + sigma / max_score ~ lognormal(-log(20), 0.5 * log(5)); // 95% CI is [.01, 0.25] * max_score + + for (k in 1:N_pt) { + to_vector(S_meas[(start[k] + 1):end[k]]) ~ normal(to_vector(S_meas[start[k]:(end[k] - 1)]), sigma); // Vectorise for efficiency + } + +} + +generated quantities { + real S_rep[N]; // Replications of S_meas[t + 1] given S_lat[t] + real S_pred[N_test]; // Predictive sample of S_test + real lpd[N_test]; // Log predictive density + + // Replications + if (rep == 1) { + for (k in 1:N_pt) { + S_rep[start[k]] = S_meas[start[k]]; + for (t in (start[k] + 1):end[k]) { + S_rep[t] = normal_lb_ub_rng(S_meas[t - 1], sigma, 0, max_score); + } + } + S_pred = S_rep[idx_test]; + } + + // Log predictive density + { + real Z; // Normalisation constant + for (i in 1:N_test) { + Z = normal_cdf(max_score, S_meas[idx_test[i]], sigma) - normal_cdf(0, S_meas[idx_test[i]], sigma); + lpd[i] = normal_lpdf(S_test[i] | S_meas[idx_test[i]], sigma) - log(Z); + } + } + +} diff --git a/Models/SSM.stan b/Models/SSM.stan new file mode 100644 index 0000000..3b106f4 --- /dev/null +++ b/Models/SSM.stan @@ -0,0 +1,146 @@ +functions { + real normal_lb_ub_rng(real mu, real sigma, real lb, real ub) { + // Sample from truncated normal of mean mu, standard deviation sigma, lower bound lb and upper bound ub + real u; + real p1; + real p2; + if (is_nan(mu) || is_inf(mu)) { + reject("normal_lb_ub_rng: mu must be finite; ", "found mu = ", mu); + } else if (is_nan(sigma) || is_inf(sigma) || sigma < 0) { + reject("normal_lb_ub_rng: sigma must be finite and non-negative; ", "found sigma = ", sigma); + } else if (lb >= ub) { + reject("normal_lb_ub_rng: lb must be less than ub; ", "found lb = ", lb, "and ub = ", ub); + } else { + p1 = normal_cdf(lb, mu, sigma); + p2 = normal_cdf(ub, mu, sigma); + if (p1 >= p2) { + reject("normal_lb_ub_rng: p1 >= p2. p1 = ", p1, " and p2 = ", p2, ". mu = ", mu, "and sigma = ", sigma); + } else { + u = uniform_rng(p1, p2); + } + } + return (sigma * inv_Phi(u)) + mu; // inverse cdf + } + + real soft_uniform_lpdf(real x, real lb, real ub) { + return(log(inv_logit(x - lb) - inv_logit(x - ub)) - log(ub - lb)); + } + +} + +data { + int N_obs; // Number of non-missing observations + int N_mis; // Number of missing observations + int N_pt; // Number of patients + real max_score; // Maximum value that the score can take + int idx_obs[N_obs]; // Index of non-missing observations + int idx_mis[N_mis]; // Index of missing observations + real S_obs[N_obs]; // Observed score + + int N_test; // Number of predictions to evaluate (predictions are set to missing in the data) + int idx_test[N_test]; // Index of test observations + real S_test[N_test]; // Observed score for predictions + + int run; // Switch to evaluate the likelihood + int rep; // Switch to generate replications +} + +transformed data { + int N = N_obs + N_mis; // Total number of observations + int mult = N / N_pt; // Number of timepoints per patient (same by construction) + int start[N_pt]; // index of first observation for patient each patient + int end[N_pt]; // index of last observation for patient each patient + + for (k in 1:N_pt) { + start[k] = (k - 1) * mult + 1; + end[k] = k * mult; + } +} + +parameters { + real S_lat_eta[N]; // Error term for S_lat, cf. non-centered parametrisation + + // Autocorrelation parameter + real alpha[N_pt]; // Autocorrelation parameter + real mu_alpha; // Population mean of alpha + real phi_alpha; // Population pseudo "sample size" of alpha + + // Population autoregression mean (for S_lat) + real mu_inf; // Population mean + real sigma_inf; // Population std + real eta_inf[N_pt]; // Error term + + real sigma_tot; // total noise std + real rho2; // proportion of the stochastic noise variance in the total noise variance +} + +transformed parameters { + real S_lat[N]; // Latent S + real S_inf[N_pt]; + real b[N_pt]; + real sigma_meas = sqrt(rho2) * sigma_tot; // measurement noise std + real sigma_lat = sqrt(1 - rho2) * sigma_tot; // stochastic noise std + real MDC = 1.96 * sigma_meas; // minimum detectable change (95% level) + + for (k in 1:N_pt) { + S_inf[k] = mu_inf + sigma_inf * eta_inf[k]; + b[k] = S_inf[k] * (1 - alpha[k]); + } + + // Non-centered parametrisation + for (k in 1:N_pt) { + S_lat[start[k]] = max_score * (0.5 + 0.25 * S_lat_eta[start[k]]); // prior covering the full range of the score + for (t in (start[k] + 1):end[k]) { + S_lat[t] = alpha[k] * S_lat[t - 1] + b[k] + sigma_lat * S_lat_eta[t]; + } + } + +} + +model { + S_lat_eta ~ std_normal(); + eta_inf ~ std_normal(); + mu_inf / max_score ~ normal(0.5, 0.25); + sigma_inf / max_score ~ normal(0, 0.125); + sigma_tot / max_score ~ lognormal(-log(20), 0.5 * log(5)); // 95% CI is [.01, 0.25] * max_score + rho2 ~ beta(4, 2); // process noise expected to be small compared to the measurement noise + + mu_alpha ~ beta(2, 2); + phi_alpha ~ lognormal(1 * log(10), 0.5 * log(10)); // mass between 1 and 100 + alpha ~ beta(mu_alpha * phi_alpha, (1 - mu_alpha) * phi_alpha); + + for (i in 1:N) { + // S_lat[i] ~ soft_uniform(-1, max_score + 1); + S_lat[i] ~ soft_uniform(-.01 * max_score, 1.01 * max_score); + } + + if (run == 1) { + for (i in 1:N_obs) { + S_obs[i] ~ normal(S_lat[idx_obs[i]], sigma_meas) T[0, max_score]; + } + } +} + +generated quantities { + real S_rep[N]; // Replications of S_meas[t + 1] given S_lat[t] + real S_pred[N_test]; // Predictive sample of S_test + real lpd[N_test]; // Log predictive density + + // Replications + if (rep == 1) { + for (i in 1:N) { + S_rep[i] = normal_lb_ub_rng(S_lat[i], sigma_meas, 0, max_score); + } + S_pred = S_rep[idx_test]; + } + + // Log predictive density + { + real Z; // Normalisation constant + for (i in 1:N_test) { + Z = normal_cdf(max_score, S_lat[idx_test[i]], sigma_meas) - normal_cdf(0, S_lat[idx_test[i]], sigma_meas); + lpd[i] = normal_lpdf(S_test[i] | S_lat[idx_test[i]], sigma_meas) - log(Z); + } + } + +} diff --git a/Models/SSMX.stan b/Models/SSMX.stan new file mode 100644 index 0000000..aa3f8b5 --- /dev/null +++ b/Models/SSMX.stan @@ -0,0 +1,212 @@ +// State space model with eXogeneous variables (covariates) following a regularised horseshoe prior +// 2 parametrisations are implemented for the Horseshoe +// +// The prior for tau (global shrinkage) is a function of the expected number of "important" features... +// ... and a function of the number of observations (here N_pt, cf. scale_global) +// The slab prior (i.e. for non-zero coefficients) is a scaled-inverse chi-squared distribution... +// ... where the tail is similar to a Student t distribution with slab_df degree of freedom + +functions { + real normal_lb_ub_rng(real mu, real sigma, real lb, real ub) { + // Sample from truncated normal of mean mu, standard deviation sigma, lower bound lb and upper bound ub + real u; + real p1; + real p2; + if (is_nan(mu) || is_inf(mu)) { + reject("normal_lb_ub_rng: mu must be finite; ", "found mu = ", mu); + } else if (is_nan(sigma) || is_inf(sigma) || sigma < 0) { + reject("normal_lb_ub_rng: sigma must be finite and non-negative; ", "found sigma = ", sigma); + } else if (lb >= ub) { + reject("normal_lb_ub_rng: lb must be less than ub; ", "found lb = ", lb, "and ub = ", ub); + } else { + p1 = normal_cdf(lb, mu, sigma); + p2 = normal_cdf(ub, mu, sigma); + if (p1 >= p2) { + reject("normal_lb_ub_rng: p1 >= p2. p1 = ", p1, " and p2 = ", p2, ". mu = ", mu, "and sigma = ", sigma); + } else { + u = uniform_rng(p1, p2); + } + } + return (sigma * inv_Phi(u)) + mu; // inverse cdf + } + + real soft_uniform_lpdf(real x, real lb, real ub) { + return(log(inv_logit(x - lb) - inv_logit(x - ub)) - log(ub - lb)); + } + +} + +data { + int N_obs; // Number of non-missing observations + int N_mis; // Number of missing observations + int N_pt; // Number of patients + real max_score; // Maximum value that the score can take + int idx_obs[N_obs]; // Index of non-missing observations + int idx_mis[N_mis]; // Index of missing observations + real S_obs[N_obs]; // Observed score + + int N_test; // Number of predictions to evaluate (predictions are set to missing in the data) + int idx_test[N_test]; // Index of test observations + real S_test[N_test]; // Observed score for predictions + + real p0; // Horseshoe: guess on the number of non-zero parameters + real slab_scale; // Horseshoe: slab scale + real slab_df; // Horseshoe: slab degrees of freedom (1 for cauchy, more for closer to gaussian) + int N_cov; // Number of covariates at t0 + matrix[N_pt, N_cov] X_cov; // Matrix of covariates at t0 + + int parametrisation; // Switch to change parametrisation of the horseshoe + int run; // Switch to evaluate the likelihood + int rep; // Switch to generate replications +} + +transformed data { + int N = N_obs + N_mis; // Total number of observations + int mult = N / N_pt; // Number of timepoints per patient (same by construction) + int start[N_pt]; // index of first observation for patient each patient + int end[N_pt]; // index of last observation for patient each patient + + real scale_global = p0 / (N_cov - p0) / sqrt(N_pt); // Horseshoe: scale for tau + real nu_local = 1; // Horseshoe: degree of freedom for lambdas prior (for horseshoe it's 1 (cauchy)) + real nu_global = 1; // Horseshoe: degree of freedom for tau prior (1 is cauchy) + + for (k in 1:N_pt) { + start[k] = (k - 1) * mult + 1; + end[k] = k * mult; + } +} + +parameters { + real S_mis[N_mis]; // Missing S (useful for predictions (cf. bounds)) + real S_lat_eta[N]; // cf. non-centered parametrisation + + // Autocorrelation parameter + real alpha[N_pt]; // Autocorrelation parameter + real mu_alpha; // Population mean of alpha + real phi_alpha; // Population pseudo "sample size" of alpha + + // Population autoregression mean (for S_lat) + real mu_inf; // Population mean + real sigma_inf; // Population std + real eta_inf[N_pt]; // Error term + + real sigma_tot; // total noise std + real rho2; // proportion of the stochastic noise variance in the total noise variance + + // Horseshoe + vector[N_cov] z; // Horseshoe noise (non-centered) + real caux; // Horseshoe: cauchy noise for c + // Horseshoe parametrisation 0 + real tau0[1 - parametrisation]; // Horseshoe, parametrisation 0: global shrinkage parameter + vector[N_cov] lambda0[1 - parametrisation]; // Horseshoe, parametrisation 0: local shrinkage parameter + // Horseshoe parametrisation 1 + real aux1_global[parametrisation]; // Horseshoe: cf. parametrisation of tau + real aux2_global[parametrisation]; // Horseshoe: cf. parametrisation of tau + vector[N_cov] aux1_local[parametrisation]; // Horseshoe: cf. parametrisation of lambdas + vector[N_cov] aux2_local[parametrisation]; // Horseshoe: cf. parametrisation of lambdas + +} + +transformed parameters { + real S_lat[N]; // Latent S + real S_inf[N_pt]; + real b[N_pt]; + real sigma_meas = sqrt(rho2) * sigma_tot; // measurement noise std + real sigma_lat = sqrt(1 - rho2) * sigma_tot; // stochastic noise std + real MDC = 1.96 * sigma_meas; // minimum detectable change (95% level) + + // Horseshoe + vector[N_cov] lambda_tilde; // Horseshoe: truncated local shrinkage parameter + real c; // Horseshoe: slab scale + vector[N_cov] beta; // Horseshoe: regularised coefficients + vector[N_pt] f; // Horseshoe: latent function values (x*beta) + real tau; // Horseshoe: global shrinkage parameter + vector[N_cov] lambda; // Horseshoe: local shrinkage parameter + + for (k in 1:N_pt) { + S_inf[k] = mu_inf + sigma_inf * eta_inf[k]; + b[k] = S_inf[k] * (1 - alpha[k]); + } + + // Horseshoe + if (parametrisation == 0) { + tau = tau0[1]; + lambda = lambda0[1]; + } else { + tau = aux1_global[1] * sqrt(aux2_global[1]) * scale_global * sigma_lat; + lambda = aux1_local[1] .* sqrt(aux2_local[1]); + } + c = slab_scale * sqrt(caux); + lambda_tilde = sqrt(c^2 * square(lambda) ./ (c^2 + tau^2 * square(lambda))); + beta = z .* lambda_tilde * tau; + f = X_cov * beta; + + for (k in 1:N_pt) { + S_lat[start[k]] = max_score * (0.5 + 0.25 * S_lat_eta[start[k]]); // prior covering the full range of the score + for (t in (start[k] + 1):end[k]) { + S_lat[t] = alpha[k] * S_lat[t - 1] + b[k] + f[k] + sigma_lat * S_lat_eta[t]; + } + } +} + +model { + S_lat_eta ~ std_normal(); + eta_inf ~ std_normal(); + mu_inf / max_score ~ normal(0.5, 0.25); + sigma_inf / max_score ~ normal(0, 0.125); + sigma_tot / max_score ~ lognormal(-log(20), 0.5 * log(5)); // 95% CI is [.01, 0.25] * max_score + rho2 ~ beta(4, 2); // process noise expected to be small compared to the measurement noise + + mu_alpha ~ beta(2, 2); + phi_alpha ~ lognormal(1 * log(10), 0.5 * log(10)); // mass between 1 and 100 + alpha ~ beta(mu_alpha * phi_alpha, (1 - mu_alpha) * phi_alpha); + + for (i in 1:N) { + // S_lat[i] ~ soft_uniform(-1, max_score + 1); + S_lat[i] ~ soft_uniform(-.01 * max_score, 1.01 * max_score); + } + + if (run) { + for (i in 1:N_obs) { + S_obs[i] ~ normal(S_lat[idx_obs[i]], sigma_meas) T[0, max_score]; + } + } + + // + if (parametrisation == 0) { + lambda0[1] ~ student_t(nu_local, 0, 1); + tau0[1] ~ student_t(nu_global, 0, scale_global * sigma_lat); + } else { + aux1_local[1] ~ std_normal(); + aux2_local[1] ~ inv_gamma (0.5 * nu_local, 0.5 * nu_local); + aux1_global[1] ~ std_normal(); + aux2_global[1] ~ inv_gamma (0.5 * nu_global , 0.5 * nu_global ); + } + z ~ std_normal(); + caux ~ inv_gamma (0.5 * slab_df , 0.5 * slab_df); +} + +generated quantities { + real S_rep[N]; // Replications of S_meas[t + 1] given S_lat[t] + real S_pred[N_test]; // Predictive sample of S_test + real lpd[N_test]; // Log predictive density + + // Replications + if (rep == 1) { + for (i in 1:N) { + S_rep[i] = normal_lb_ub_rng(S_lat[i], sigma_meas, 0, max_score); + } + S_pred = S_rep[idx_test]; + } + + // Log predictive density + { + real Z; // Normalisation constant + for (i in 1:N_test) { + Z = normal_cdf(max_score, S_lat[idx_test[i]], sigma_meas) - normal_cdf(0, S_lat[idx_test[i]], sigma_meas); + lpd[i] = normal_lpdf(S_test[i] | S_lat[idx_test[i]], sigma_meas) - log(Z); + } + } + + +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..3837d3e --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +# Predicting eczema severity using serum biomarkers + +This repository contains the code for the article by [**Hurault et al. (preprint), "Can serum biomarkers predict the outcome of systemic therapy for atopic dermatitis?"**](). + +The code is written in the R language for statistical computing and the models using the probabilistic programming language [Stan](https://mc-stan.org/). + +## File structure + +The dataset used in this study is not available according to our data sharing agreement. +During the analysis, the dataset is loaded from a proprietary package `TanakaData` which includes the raw files as well as data processing functions. + +Utility functions used within the scripts are available in [`functions.R`](functions.R). +In addition, we used functions from Guillem Hurault's personal package, [HuraultMisc](https://github.com/ghurault/HuraultMisc). + +The `Models` folder contains the different Stan models developed in this project: + +- [`RW.stan`](Models/RW.stan): the random walk model, one of the reference model. +- [`AR.stan`](Models/AR.stan): the autoregressive model, one of the reference model. +- [`MixedAR.stan`](Models/MixedAR.stan): the mixed effect autoregressive model, one of the reference model. +- [`SSM.stan`](Models/SSM.stan): the Bayesian state space model without covariates. +- [`SSMX.stan`](Models/SSMX.stan): the Bayesian state space model with covariates (following a horseshoe prior). + +The modelling workflow in separated into different scripts: + +- [`check_models.R`](check_models.R): Conduct prior predictive checks and fake data check of the different models. +This script is notably useful to simulate data that resembles the one we used. +- [`fit_models.R`](fit_models.R): Fit the different models to real data, perform diagnostics and posterior predictive checks. +- [`run_validation.R`](run_validation.R): Run the validation process (K-fold cross-validation and forward chaining). +- [`check_performance.R`](check_performance.R): Analyse validation results. + +## License + +This open source version of this project is licensed under the GPLv3 license, which can be seen in the [LICENSE](LICENSE) file. diff --git a/check_models.R b/check_models.R new file mode 100644 index 0000000..fd9c6e0 --- /dev/null +++ b/check_models.R @@ -0,0 +1,358 @@ +# Notes ------------------------------------------------------------------- + +# Master file to do: +# - Prior predictive checks +# - Fake data simulation (from priors) +# - Fit fake data (to see if we can recover parameters) +# For different models: +# - RW: random walk model +# - AR: autoregressive model (order 1, fixed effects) +# - MixedAR: mixed effect autoregressive model (order 1) +# - SSM: Hidden Markov Model (Gaussian measurement error and mixed autoregressive model for the latent dynamic) +# - SSMX: Hidden Markov Model with eXogeneous variables, following a horseshoe prior (two parametrisations for the horseshoe) + +# NB: for AR/SSM, identifiability issue when alpha -> 1 (we can't determine S_inf); make sure draw is OK + +# Initialisation ---------------------------------------------------------- + +rm(list = ls()) # Clear Workspace (but better to restart session) + +set.seed(1744834965) # Reproducibility (Stan use a different seed) + +library(tidyverse) +library(cowplot) +library(HuraultMisc) # Functions shared across projects +library(rstan) +rstan_options(auto_write = TRUE) # Save compiled model +options(mc.cores = parallel::detectCores()) # Parallel computing +source("functions.R") # Additional functions + +#### OPTIONS +score <- "EASI" +model_name <- "SSM" +run_prior <- FALSE # prior distribution +run_fake <- FALSE # fit fake data +n_it <- 2000 +n_chains <- 4 +#### + +score <- match.arg(score, c("EASI", "oSCORAD", "SCORAD", "POEM")) +model_name <- match.arg(model_name, c("RW", "AR", "MixedAR", "SSM", "SSMX")) + +# Results files +stan_code <- file.path("Models", paste0(model_name, ".stan")) +prior_file <- file.path("Results", paste0("prior_", score, "_", model_name, ".rds")) +par0_file <- file.path("Results", paste0("par0_", score, "_", model_name, ".rds")) +fake_file <- file.path("Results", paste0("fake_", score, "_", model_name, ".rds")) + +# Model parameters +if (model_name == "RW") { + param_pop <- c("sigma") + param_ind <- c() +} +if (model_name == "AR") { + param_pop <- c("sigma", "alpha", "S_inf", "b") + param_ind <- c() +} +if (model_name == "MixedAR") { + param_pop <- c("sigma", + "mu_alpha", "phi_alpha", + "mu_inf", "sigma_inf") + param_ind <- c("alpha", "S_inf", "b") +} +if (model_name == "SSM") { + param_pop <- c("sigma_tot", "rho2", "sigma_lat", "sigma_meas", "MDC", + "mu_alpha", "phi_alpha", + "mu_inf", "sigma_inf") + param_ind <- c("alpha", "S_inf", "b") +} +if (model_name == "SSMX") { + param_pop <- c("sigma_tot", "rho2", "sigma_lat", "sigma_meas", "MDC", + "mu_alpha", "phi_alpha", + "mu_inf", "sigma_inf", + "beta") + param_ind <- c("alpha", "S_inf", "b", "f") +} + +if (any(run_prior, run_fake)) { + compiled_model <- stan_model(stan_code) +} + +score_char <- data.frame(Score = c("SCORAD", "oSCORAD", "EASI", "POEM"), + Range = c(103, 83, 72, 28), + MCID = c(8.7, 8.2, 6.6, 3.4)) %>% + filter(Score == score) + +# Data -------------------------------------------------------------------- +# Load data to generate fake data with similar characteristics +# Also to extract the matrix of biomarkers to avoid generating one + +l <- load_dataset() +dp <- l$patient_data +dt <- l$severity_data +pt <- unique(dt[["Patient"]]) +bio <- as.matrix(dp[, colnames(dp) != "Patient"]) # matrix of biomarkers (including treatment, age, sex...) + +n_pt <- length(pt) +n_dur <- 24 / 2 + 1 + +# Prior predictive check ------------------------------------------------------ + +param_obs <- c("S_rep") +param <- c(param_pop, param_ind, param_obs) + +data_prior <- list( + N_obs = n_pt, + N_mis = n_pt * (n_dur - 1), + N_pt = n_pt, + max_score = score_char$Range, + idx_obs = ((1:n_pt) - 1) * (n_dur) + 1, + idx_mis = setdiff(1:(n_pt * n_dur), ((1:n_pt) - 1) * (n_dur) + 1), + S_obs = runif(n_pt, 0, score_char$Range), + + N_test = 0, + idx_test = vector(), + S_test = vector(), + + # For horseshoe + p0 = 5, + slab_scale = 1, + slab_df = 5, + N_cov = ncol(bio), + X_cov = bio, + parametrisation = 1, + + run = 0, + rep = 1 +) + +if (run_prior) { + fit_prior <- sampling(compiled_model, + data = data_prior, + pars = param, + iter = n_it, + chains = n_chains, + control = list(adapt_delta = 0.9)) + saveRDS(fit_prior, file = prior_file) + par0 <- extract_parameters(fit_prior, param, param_ind, param_obs, pt, data_prior) # save for comparing prior to posterior + saveRDS(par0, file = par0_file) +} else { + fit_prior <- readRDS(prior_file) + par0 <- readRDS(par0_file) +} + +# If divergent transitions +# Probably caused by the fact that prior but probability on regions of high curvature (e.g. hierarchical sigma close to 0) +# In that case shouldn't be a problem +# But check anyway if chains are not getting stuck... + +if (FALSE) { + check_hmc_diagnostics(fit_prior) + # pairs(fit_prior, pars = param_pop) + # plot(fit_prior, pars = param_pop, plotfun = "trace") + + # Distribution of parameters + plot(fit_prior, pars = setdiff(param_pop, "beta"), plotfun = "hist") + if (length(param_ind) > 0) {plot(fit_prior, pars = paste0(param_ind, "[1]"), plotfun = "hist")} + if (model_name == "SSMX") {plot(fit_prior, pars = "beta[1]", plotfun = "hist")} + + # Predictive distribution + lapply(sample(pt, 4), + function(pid) { + library(ggplot2) + ggplot(data = subset(par0, Variable == "S_rep" & Patient == pid), + aes(x = Week, ymin = pmax(`5%`, 0), ymax = pmin(`95%`, score_char$Range))) + + geom_ribbon(alpha = 0.5) + + scale_y_continuous(limits = c(0, score_char$Range)) + + theme_bw(base_size = 15) + + theme(panel.grid.minor.y = element_blank()) + }) %>% + plot_grid(plotlist = ., ncol = 2) + + # score-50 (e.g. EASI-50) + apply(rstan::extract(fit_prior, pars = "S_rep")[[1]], + 1, + function(x) { + k <- with(data_prior, 1:N_pt) + mult <- with(data_prior, (N_obs + N_mis) / N_pt) + start_k <- (k - 1) * mult + 1 + end_k <- k * mult + mean(x[end_k] < 0.5 * x[start_k]) + }) %>% + hist(., breaks = with(data_prior, (0:N_pt) / N_pt), probability = TRUE, col = "#B2001D", main = "", xlab = "EASI-50") + +} + +# Generate fake data --------------------------------------------------------------- + +# Take one draw (different draws corresponds to different a priori pattern in the data) +draw <- 35 + +# True parameters +true_param <- extract_parameters_from_draw(fit_prior, c(param_pop, param_ind), draw) +true_param[["Patient"]] <- pt[true_param[["Index"]]] + +# Dataframe +sim <- rstan::extract(fit_prior, pars = "S_rep")[[1]][draw, ] +# sim <- round(sim) # Round for observed severity +fd <- data.frame(Patient = rep(pt, each = n_dur), + Week = 2 * rep(0:(n_dur - 1), n_pt), + S = sim) +fd$S[!(fd$Week %in% c(0, 2, 4, 8, 12, 24))] <- NA + +# Plot trajectories +lapply(sample(pt, 4), + function(pid) { + library(ggplot2) + ggplot(data = subset(fd, Patient == pid), + aes(x = Week)) + + geom_point(aes(y = S)) + + scale_y_continuous(limits = c(0, score_char$Range)) + + labs(title = paste("Patient", pid), + subtitle = paste("alpha = ", signif(true_param %>% filter(Parameter == "alpha" & Patient == pid) %>% pull(Value), 2))) + + theme_bw(base_size = 20) + + theme(panel.grid.minor.y = element_blank()) + }) %>% + plot_grid(plotlist = ., ncol = 2) + +if (model_name == "SSMX") { + true_param %>% + filter(Parameter == "beta") %>% + pull(Value) %>% + sort() %>% + barplot() +} + +# Fit fake data --------------------------------------------------------- + +data_fake <- with(fd, + list( + N_obs = sum(!is.na(S)), + N_mis = sum(is.na(S)), + N_pt = length(unique(Patient)), + max_score = score_char$Range, + idx_obs = which(!is.na(S)), + idx_mis = which(is.na(S)), + S_obs = na.omit(S), + + N_test = 0, + idx_test = vector(), + S_test = vector(), + + # For horseshoe + p0 = 5, + slab_scale = 1, + slab_df = 5, + N_cov = ncol(bio), + X_cov = bio, + parametrisation = 0, + + run = 1, + rep = 1 + )) + +if (run_fake) { + fit_fake <- sampling(compiled_model, + data = data_fake, + pars = param, + iter = n_it, + chains = n_chains, + control = list(adapt_delta = case_when(model_name %in% c("SSM", "SSMX") ~ 0.99, + TRUE ~ 0.9))) + saveRDS(fit_fake, file = fake_file) +} else { + fit_fake <- readRDS(fake_file) +} + +# Fake data check ---------------------------------------------------------- + +if (FALSE) { + + check_hmc_diagnostics(fit_fake) + + pairs(fit_fake, pars = setdiff(param_pop, "beta")) + # pairs(fit_fake, pars = paste0("beta[", 1:5, "]")) + # print(fit_fake, pars = param_pop) + + # Sensitivity to prior + par <- extract_parameters(fit_fake, param, param_ind, param_obs, pt, data_fake) + HuraultMisc::check_model_sensitivity(par0, par, param) + + ## Can we recover known parameters? + tmp <- HuraultMisc::summary_statistics(fit_fake, param) %>% + inner_join(true_param, by = c("Variable" = "Parameter", "Index")) %>% + rename(True = Value) + tmp$Patient <- factor(tmp$Patient, levels = pt) + + # Population parameters + tmp %>% + filter(Variable %in% setdiff(param_pop, "beta")) %>% + ggplot(aes(x = Variable)) + + geom_pointrange(aes(y = Mean, ymin = `5%`, ymax = `95%`)) + + geom_point(aes(y = True), col = "#E69F00", size = 2) + + coord_flip() + + labs(x = "", y = "Estimate") + + theme_bw(base_size = 20) + + # Patient parameters + lapply(param_ind, + function(par_name) { + tmp %>% + filter(Variable == par_name) %>% + mutate(Patient = fct_reorder(Patient, True)) %>% + ggplot(aes(x = Patient)) + + geom_pointrange(aes(y = Mean, ymin = `5%`, ymax = `95%`)) + + geom_point(aes(y = True), col = "#E69F00", size = 2) + + coord_flip() + + labs(y = par_name) + + theme_bw(base_size = 15) + }) %>% + plot_grid(plotlist = ., nrow = 1) + + if (model_name == "SSMX") { + # beta + tmp %>% + filter(Variable == "beta") %>% + mutate(Index = fct_reorder(factor(Index), True)) %>% + ggplot(aes(x = Index)) + + geom_pointrange(aes(y = Mean, ymin = `5%`, ymax = `95%`)) + + geom_point(aes(y = True), col = "#E69F00", size = 2) + + coord_flip() + + labs(x = "", y = "Estimate") + + theme_bw(base_size = 20) + + # Coverage of beta + HuraultMisc::plot_coverage(extract(fit_fake, pars = "beta")[[1]], + subset(true_param, Parameter == "beta")$Value) + } + + # Coverage of the posterior predictive distribution + yrep_fake <- rstan::extract(fit_fake, pars = "S_rep")[[1]] + HuraultMisc::plot_coverage(yrep_fake, fd[["S"]]) + + # Coverage of patient-dependent parameters + lapply(param_ind, + function(x) { + HuraultMisc::plot_coverage(rstan::extract(fit_fake, pars = x)[[1]], + subset(true_param, Parameter == x)$Value) + }) %>% + plot_grid(plotlist = .) + + ## Posterior predictive checks + ssi <- full_join(HuraultMisc::extract_distribution(fit_fake, "S_rep", type = "hdi", CI_level = seq(0.1, 0.9, 0.1)), + get_index(pt, data_fake), + by = "Index") + pl <- lapply(sample(pt, 4), + function(pid) { + PPC_fanchart(ssi, fd %>% rename(y = "S"), pid, score_char$Range) + + labs(y = score, title = paste("Patient", pid)) + }) + plot_grid(get_legend(pl[[1]] + theme(legend.position = "top")), + plot_grid(plotlist = lapply(pl, + function(p) { + p + theme(legend.position = "none") + }), + nrow = 2, labels = "AUTO"), + nrow = 2, rel_heights = c(.1, .9)) + +} diff --git a/check_performance.R b/check_performance.R new file mode 100644 index 0000000..99b76b6 --- /dev/null +++ b/check_performance.R @@ -0,0 +1,252 @@ +# Notes ------------------------------------------------------------------- + +# Analyse the predictive performance of the different models +# - Performance one-step-ahead +# - Raw performance estimates +# - Learning curve from meta-model (controlling for prediction horizon) + +# Initialisation ---------------------------------------------------------- + +rm(list = ls()) # Clear Workspace (but better to restart session) + +library(HuraultMisc) # Functions shared across projects +library(tidyverse) +library(cowplot) +source("functions.R") # Additional functions + +#### OPTIONS +score <- "EASI" +metric <- "lpd" +model_names <- c("Uniform", "RW", "AR", "MixedAR", "SSM") # "SSMX" +#### + +score <- match.arg(score, c("EASI", "SCORAD", "oSCORAD", "POEM")) +metric <- match.arg(metric, c("lpd", "CRPS", "ProbAccuracy", "Accuracy", "RMSE")) +stopifnot(all(model_names %in% c("Uniform", "RW", "AR", "MixedAR", "SSM", "SSMX"))) +res_files <- file.path("Results", paste0("val_", score, "_", model_names, ".rds")) +stopifnot(all(file.exists(res_files))) + +score_char <- data.frame(Score = c("SCORAD", "oSCORAD", "EASI", "POEM"), + Range = c(103, 83, 72, 28), + MCID = c(8.7, 8.2, 6.6, 3.4)) %>% + filter(Score == score) + +# Process results --------------------------------------------------------- + +perf <- do.call(bind_rows, + lapply(1:length(model_names), + function(i) { + res <- readRDS(res_files[i]) + + # Probabilistic accuracy + if (model_names[i] == "Uniform") { + ub <- pmin(score_char$Range, res$S + score_char$MCID) + lb <- pmax(0, res$S - score_char$MCID) + acc <- (ub - lb) / score_char$Range + } else { + acc <- sapply(1:nrow(res), function(j) { + mean(abs(res$S[j] - res$Samples[j][[1]]) < score_char$MCID) + }) + } + + res %>% + mutate(SquaredError = (S - Mean_pred)^2, + ProbAccuracy = acc, + Accuracy = as.numeric(abs(S - Mean_pred) < score_char$MCID)) %>% + mutate(Model = model_names[i]) %>% + select(-Samples) + })) %>% + mutate(Model = factor(Model, levels = rev(model_names))) + +# One-steap-ahead performance --------------------------------------------- +# Prediction for the next clinical visits +# Prediction horizon differ though + +# Select one-step-ahead prediction +cv_osa <- perf %>% + group_by(Model, TrainingWeek) %>% + filter(TestingWeek == min(TestingWeek)) %>% + ungroup() +# Compute performance for each fold +cv_osa <- cv_osa %>% + group_by(Model, Fold) %>% + summarise(lpd = mean(lpd), + CRPS = mean(CRPS), + ProbAccuracy = mean(ProbAccuracy), + Accuracy = mean(Accuracy), + RMSE = sqrt(mean(SquaredError))) %>% + ungroup() +# Average performance across fold +cv_osa <- cv_osa %>% + pivot_longer(cols = all_of(c("lpd", "CRPS", "ProbAccuracy", "Accuracy", "RMSE")), names_to = "Metric", values_to = "Value") %>% + group_by(Model, Metric) %>% + summarise(Mean = mean(Value), SE = sd(Value) / sqrt(n())) + +p1 <- cv_osa %>% + filter(Metric == metric) %>% + ggplot(aes(x = Model, y = Mean, ymin = Mean - SE, ymax = Mean + SE)) + + geom_pointrange() + + coord_flip() + + labs(x = "", y = metric) + + theme_bw(base_size = 15) +if (metric == "Accuracy") { + p1 <- p1 + scale_y_continuous(limits = c(0, 1)) +} +if (metric %in% c("CRPS", "RMSE")) { + p1 <- p1 + scale_y_continuous(limits = c(0, NA)) +} +p1 + +# Raw performance estimates -------------------------------------------------------- + +# Compute performance for each fold (and each condition) +cv <- perf %>% + group_by(Model, TrainingWeek, TestingWeek, Fold) %>% + summarise(lpd = mean(lpd), + CRPS = mean(CRPS), + ProbAccuracy = mean(ProbAccuracy), + Accuracy = mean(Accuracy), + RMSE = sqrt(mean(SquaredError))) %>% + ungroup() +# Average performance across fold +cv <- cv %>% + pivot_longer(cols = all_of(c("lpd", "CRPS", "ProbAccuracy", "Accuracy", "RMSE")), names_to = "Metric", values_to = "Value") %>% + group_by(Model, TrainingWeek, TestingWeek, Metric) %>% + summarise(Mean = mean(Value), SE = sd(Value) / sqrt(n())) %>% + ungroup() +# Compute prediction horizon +cv <- cv %>% + mutate(Horizon = TestingWeek - TrainingWeek) + +# Performance as a function of prediction horizon, for each training week (Model in colour) +# Alternatively plot as a function of prediction horizon, for each model (training week in colour); harder for model comparison +p2 <- cv %>% + filter(Metric == metric) %>% + ggplot(aes(x = Horizon, y = Mean, ymin = Mean - SE, ymax = Mean + SE, colour = factor(Model))) + + facet_grid(rows = vars(TrainingWeek)) + + # ggplot(aes(x = Horizon, y = Mean, ymin = Mean - SE, ymax = Mean + SE, colour = factor(TrainingWeek))) + + # facet_grid(cols = vars(Model)) + + geom_pointrange() + + geom_line() + + scale_color_manual(values = cbbPalette) + + labs(x = "Prediction Horizon (weeks)", + y = metric, + colour = "") + + scale_x_continuous(breaks = sort(unique(cv[["Horizon"]]))) + + theme_bw(base_size = 15) + + theme(panel.grid.minor.x = element_blank()) +if (metric == "Accuracy") { + p2 <- p2 + scale_y_continuous(limits = c(0, 1)) +} +if (metric %in% c("CRPS", "RMSE")) { + p2 <- p2 + scale_y_continuous(limits = c(0, NA)) +} +p2 + +if (FALSE) { + ggsave(file.path("Plots", paste0(score, "_", metric, "_rawperf.jpg")), + width = 10, height = 10, units = "cm", dpi = 300, scale = 2) +} + +# Learning curves from meta-model ----------------------------------------- + +cv <- perf %>% + group_by(Model, TrainingWeek, TestingWeek, Fold) %>% + summarise(lpd = mean(lpd), + CRPS = mean(CRPS), + ProbAccuracy = mean(ProbAccuracy), + Accuracy = mean(Accuracy), + RMSE = sqrt(mean(SquaredError))) %>% + ungroup() %>% + mutate(Horizon = TestingWeek - TrainingWeek) + +estimate_performance <- function(df, metric, adjust_horizon = TRUE) { + # Estimate learning curves with a meta-model (linear regression) + # + # Args: + # df: Dataframe of performance metric per fold + # metric: Metric name + # adjust_horizon: Whether to adjust for prediciton horizon in the model + # + # Returns: + # Dataframe with columns: TrainingWeek, Horizon, Mean, SE, Variable + + stopifnot(is.data.frame(df), + is.character(metric), + all(c("TrainingWeek", "Horizon", "Fold", metric) %in% colnames(df)), + is.logical(adjust_horizon)) + + f <- paste0(metric, " ~ factor(TrainingWeek) + 0") + if (adjust_horizon) { + f <- paste0(f, " + Horizon") + } + f <- formula(f) + + meta_model <- glm(f, + family = "gaussian", + data = df) + lm_fit <- data.frame(TrainingWeek = c(0, 2, 4, 8, 12), Horizon = 2) + pred <- predict(meta_model, newdata = lm_fit, se.fit = TRUE) + lm_fit <- lm_fit %>% + mutate(Mean = pred$fit, + SE = pred$se.fit, + Variable = "Fit") + + if (adjust_horizon) { + s <- summary(meta_model) + lm_horizon <- data.frame(Mean = s$coefficients["Horizon", "Estimate"], + SE = s$coefficients["Horizon", "Std. Error"], + Variable = "Horizon") + } else { + lm_horizon <- data.frame(Mean = 0, + SE = 0, + Variable = "Horizon") + } + + bind_rows(lm_fit, lm_horizon) %>% + mutate(Metric = metric) +} + +fit_perf <- do.call(rbind, + lapply(model_names, + function(x) { + cv %>% + filter(Model == x) %>% + estimate_performance(., metric, adjust_horizon = !((x == "Uniform") & (metric == "lpd"))) %>% + mutate(Model = x) + })) %>% + mutate(Model = factor(Model, levels = rev(model_names))) + +p3 <- fit_perf %>% + filter(Variable == "Fit") %>% + ggplot(aes(x = TrainingWeek, y = Mean, ymin = Mean - SE, ymax = Mean + SE, colour = Model, fill = Model)) + + geom_line() + + # geom_pointrange(position = position_dodge(width = 1)) + + geom_point() + + geom_ribbon(alpha = 0.5) + + scale_colour_manual(values = cbbPalette) + + scale_fill_manual(values = cbbPalette) + + scale_x_continuous(breaks = sort(unique(fit_perf[["TrainingWeek"]]))) + + labs(x = "Training week", y = metric, colour = "", fill = "") + + theme_bw(base_size = 15) + + theme(panel.grid.minor.x = element_blank()) + +p4 <- fit_perf %>% + filter(Variable == "Horizon") %>% + ggplot(aes(x = Model, y = Mean, ymin = Mean - SE, ymax = Mean + SE, colour = Model)) + + geom_pointrange(size = 1.5) + + scale_colour_manual(values = cbbPalette) + + labs(x = "", y = paste0(metric, " change with increasing \nprediction horizon of 2 weeks"), colour = "") + + theme_bw(base_size = 15) + + theme(legend.position = "bottom", + axis.text.x = element_blank(), + axis.ticks.x = element_blank()) + +plot_grid(p3 + theme(legend.position = "none"), + p4 + theme(legend.position = "none"), + get_legend(p3 + theme(legend.position = "right")), + nrow = 1, rel_widths = c(4, 3, 1), labels = c("A", "B", "")) +if (FALSE) { + ggsave(file.path("Plots", paste0(score, "_", "metric", "_metaperf.jpg")), + width = 13, height = 8, units = "cm", dpi = 300, scale = 2) +} diff --git a/fit_models.R b/fit_models.R new file mode 100644 index 0000000..b6a2e3e --- /dev/null +++ b/fit_models.R @@ -0,0 +1,227 @@ +# Notes ------------------------------------------------------------------- + +# Master file to fit models for the different scores: +# - RW: random walk model +# - AR: autoregressive model (order 1, fixed effects) +# - MixedAR: mixed effect autoregressive model (order 1) +# - SSM: Hidden Markov Model (Gaussian measurement error and mixed autoregressive model for the latent dynamic) +# - SSMX: Hidden Markov Model with eXogeneous variables, following a horseshoe prior (two parametrisations for the horseshoe) + +# Initialisation ---------------------------------------------------------- + +rm(list = ls()) # Clear Workspace (but better to restart session) + +library(HuraultMisc) # Functions shared across projects +library(tidyverse) +library(cowplot) +library(rstan) +rstan_options(auto_write = TRUE) # Save compiled model +options(mc.cores = parallel::detectCores()) # Parallel computing +source("functions.R") # Additional functions + +seed <- 462528635 # seed also used for stan +set.seed(seed) + +#### OPTIONS +score <- "EASI" +model_name <- "SSM" +run <- FALSE +n_it <- 2000 +n_chains <- 4 +#### + +score <- match.arg(score, c("EASI", "SCORAD", "oSCORAD", "POEM")) +model_name <- match.arg(model_name, c("RW", "AR", "MixedAR", "SSM", "SSMX")) + +stan_code <- file.path("Models", paste0(model_name, ".stan")) +res_file <- file.path("Results", paste0("fit_", score, "_", model_name, ".rds")) +par_file <- file.path("Results", paste0("par_", score, "_", model_name, ".rds")) +par0_file <- file.path("Results", paste0("par0_", score, "_", model_name, ".rds")) + +if (model_name == "RW") { + param_pop <- c("sigma") + param_ind <- c() + param_obs <- c("S_rep") # "S_mis" +} +if (model_name == "AR") { + param_pop <- c("sigma", "alpha", "S_inf", "b") + param_ind <- c() + param_obs <- c("S_rep") # "S_mis" +} +if (model_name == "MixedAR") { + param_pop <- c("sigma", + "mu_alpha", "phi_alpha", + "mu_inf", "sigma_inf") + param_ind <- c("alpha", "S_inf", "b") + param_obs <- c("S_rep") # "S_mis" +} +if (model_name == "SSM") { + param_pop <- c("sigma_tot", "rho2", "sigma_lat", "sigma_meas", "MDC", + "mu_alpha", "phi_alpha", + "mu_inf", "sigma_inf") + param_ind <- c("alpha", "S_inf", "b") + param_obs <- c("S_lat", "S_rep") +} +if (model_name == "SSMX") { + param_pop <- c("sigma_tot", "rho2", "sigma_lat", "sigma_meas", "MDC", + "mu_alpha", "phi_alpha", + "mu_inf", "sigma_inf", + "beta") + param_ind <- c("alpha", "S_inf", "b", "f") + param_obs <- c("S_lat", "S_rep") +} +param <- c(param_pop, param_ind, param_obs) + +score_char <- data.frame(Score = c("SCORAD", "oSCORAD", "EASI", "POEM"), + Range = c(103, 83, 72, 28), + MCID = c(8.7, 8.2, 6.6, 3.4)) %>% + filter(Score == score) + +# Data -------------------------------------------------------------------- + +l <- load_dataset() +dp <- l$patient_data +dt <- l$severity_data +pt <- unique(dt[["Patient"]]) +bio <- as.matrix(dp[, colnames(dp) != "Patient"]) # matrix of biomarkers (including treatment, age, sex...) + +# Model ------------------------------------------------------------------- + +format_data <- function(df, score) { + list( + N_obs = sum(!is.na(df[, score])), + N_mis = sum(is.na(df[, score])), + N_pt = length(unique(df$Patient)), + max_score = score_char$Range, + idx_obs = which(!is.na(df[, score])), + idx_mis = which(is.na(df[, score])), + S_obs = na.omit(df[, score]), + + N_test = 0, + idx_test = vector(), + S_test = vector(), + + # For horsehoe + p0 = 5, + slab_scale = 1, + slab_df = 5, + N_cov = ncol(bio), + X_cov = bio, + parametrisation = 0, + + run = 1, + rep = 1 + ) +} + +data_stan <- dt %>% + rename(y = score) %>% + # mutate(y = replace(y, Week > 12, NA)) %>% # cf. remove test set + format_data(., "y") + +if (run) { + fit <- stan(file = stan_code, + data = data_stan, + iter = n_it, + chains = n_chains, + pars = param, + seed = seed, + control = list(adapt_delta = case_when(model_name %in% c("SSM", "SSMX") ~ 0.99, + TRUE ~ 0.9))) + saveRDS(fit, file = res_file) + par <- extract_parameters(fit, param, param_ind, param_obs, pt, data_stan) + saveRDS(par, file = par_file) +} else { + fit <- readRDS(res_file) + par <- readRDS(par_file) +} + +par0 <- readRDS(par0_file) + +# Diagnostics and fit ---------------------------------------------------------------- + +if (FALSE) { + + # shinystan::launch_shinystan(fit) + check_hmc_diagnostics(fit) + # max(par[["Rhat"]], na.rm = TRUE) + + pairs(fit, pars = setdiff(param_pop, "beta")) + # pairs(fit, pars = paste0("beta[", 1:5, "]")) + plot(fit, pars = setdiff(param_pop, "beta"), plotfun = "trace") + plot(fit, pars = setdiff(param_pop, "beta"), plotfun = "hist") + + if (model_name %in% c("MixedAR", "SSM", "SSMX")) { + # plot(fit, pars = "alpha") + plot_grid( + plot_coef(fit, "alpha", pt, limits = c(0, 1), ylab = "Patient"), + plot_coef(fit, "b", pt, ylab = "Patient"), + plot_coef(fit, "S_inf", pt, ylab = "Patient"), + nrow = 1 + ) + } + if (model_name == "SSMX") { + plot_grid( + plot_coef(fit, "beta", colnames(bio), CI = c(.05, 0.95), limits = c(-1, 1)), + plot_coef(fit, "f", pt, CI = c(0.05, 0.95), limits = c(-1, 1)) + + labs(x = "Patient", y = "x * beta") + + theme(axis.text.y = element_blank()), + labels = "AUTO", nrow = 1, rel_widths = c(.55, .45) + ) + if (FALSE) { + ggsave(file.path("Plots", paste0(score, "_covariates.jpg")), + width = 10, height = 10, units = "cm", dpi = 300, scale = 2) + } + + } + # print(fit, pars = param_pop) + + # Check priors + param01 <- intersect(param_pop, c("alpha", "mu_alpha", "rho2")) # parameters in 0-1 + HuraultMisc::plot_prior_posterior(par0, par, setdiff(param_pop, param01)) + if (length(param01) > 0) { + # cf. 0-1 scale + HuraultMisc::plot_prior_posterior(par0, par, param01) + + coord_flip(ylim = c(0, 1)) + + theme(legend.position = "none") + } + plot_prior_influence(par0, par, c(param_pop, param_ind)) + # compute_prior_influence(par0, par, param_pop) + + lapply(param_ind, function(x) {PPC_group_distribution(fit, x, 100)}) %>% + plot_grid(plotlist = .) + +} + +# PPC Trajectories ------------------------------------------------------------ + +if (FALSE) { + + ssi <- full_join(HuraultMisc::extract_distribution(fit, "S_rep", type = "hdi", CI_level = seq(0.1, 0.9, 0.1)), + get_index(pt, data_stan), + by = "Index") + pl <- lapply(c(108, 119, 134, 137), # sort(sample(pt, 4, replace = FALSE)), + function(pid) { + tmp <- dt %>% + rename(y = score) + if (FALSE) { + # Identify training and testing data when the fit is not on the full dataset + tmp <- tmp %>% + mutate(Validation = case_when(Week <= 12 ~ "Training", + Week > 12 ~ "Testing"), + Validation = fct_relevel(Validation, "Training", "Testing")) + } + + PPC_fanchart(ssi, tmp, pid, score_char$Range) + + labs(y = score) # , title = paste("Patient", pid)) + }) + plot_grid(get_legend(pl[[1]] + theme(legend.position = "top")), + plot_grid(plotlist = lapply(pl, + function(p) { + p + theme(legend.position = "none") + }), + nrow = 2, labels = "AUTO"), + nrow = 2, rel_heights = c(.1, .9)) + # ggsave(file.path("Plots", paste0(score, "_PPC.jpg")), width = 30, height = 20, units = "cm", dpi = 300) + +} diff --git a/functions.R b/functions.R new file mode 100644 index 0000000..d55e5d7 --- /dev/null +++ b/functions.R @@ -0,0 +1,231 @@ +cbbPalette <- c("#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7") + +# Data -------------------------------------------------------------------- + +load_dataset <- function() { + # Load and prepare systemic therapy dataset + # + # - Select week 0 biomarkers + # - Remove missing biomarkers + # - Remove missing patients + # - Log and standardize features + # - Impute missing biomarkers and demographics + # - Add row for missing weeks in severity dataframe + # + # Returns: + # List containing the patient and the severity time-series dataframe + + library(TanakaData) # # Contains data and data processing functions + library(dplyr) + + # Process biomarkers + bio <- biomarkers_SystemicTherapy %>% + # Select biomarkers at week 0 + filter(Week == 0) %>% + select(-Week) %>% + # Remove almost missing biomarkers + select(-IL1a, -GCSF) %>% + # Log and standardize biomarkers + mutate(across(-matches("Patient"), ~scale(log10(.x)))) + + # Process demographics + demo <- patient_SystemicTherapy %>% + # Transform to numeric binary variables and standardize Age + mutate(FLG = as.numeric(FLG), + Sex = as.numeric(Sex == "M"), # 0 female, 1 male + Treatment = as.numeric(Treatment == "AZA"), # 0 MTX, 1 AZA + Age = scale(Age)) + + # Patient dataframe + dp <- full_join(demo, bio, by = "Patient") + # Severity dataframe + dt <- severity_SystemicTherapy + + pt <- intersect(unique(dp[["Patient"]]), unique(dt[["Patient"]])) + # Exclude patient 140 (mostly missing) + pt <- pt[pt != 140] + dp <- dp %>% filter(Patient %in% pt) + dt <- dt %>% filter(Patient %in% pt) + + # Impute missing in dp by 0 (mean/default value for binary data) + dp <- dp %>% + mutate(across(-matches("Patient"), ~tidyr::replace_na(., 0))) + + # Add rows for missing severity + dt <- bind_rows(dt, + setdiff(expand_grid(Patient = pt, + Week = seq(0, 24, 2)), + dt %>% select(Patient, Week))) + stopifnot(nrow(dt) == length(pt) * 13) + + # Reorder dataframes + dp <- dp %>% arrange(Patient) + dt <- dt %>% arrange(Patient, Week) + + return(list(patient_data = dp, severity_data = dt)) +} + +# Fitting -------------------------------------------------------------- + +plot_coef <- function(fit, parName, parLabel = NULL, CI = c(.05, .95), limits = NULL, ylab = "") { + # Plot patient coefficient estimates from stan model (custom function) + # + # Args: + # fit: stanfit object + # parName: name of the patient-dependent parameter in fit + # parLabel: vector of names for parName + # CI: vector of length two indicating the credible interval lower and upper bounds + # limits: vector of length two indicating the range of estimates to plot + # + # Returns: + # Ggplot of patient coefficient estimates + + library(ggplot2) + + tmp <- rstan::extract(fit, pars = parName)[[1]] + + if (is.null(parLabel)) {parLabel <- paste(parName, 1:ncol(tmp) ,sep= "_")} + + d <- data.frame(Parameter = factor(parLabel, levels = parLabel, labels = parLabel), + Mean = apply(tmp, 2, mean), + Lower = apply(tmp, 2, function(x) {quantile(x, probs = min(CI))}), + Upper = apply(tmp, 2, function(x) {quantile(x, probs = max(CI))})) + d$Parameter = factor(d$Parameter, levels = rev(parLabel)) + + p <- ggplot(data = d, aes(x = Parameter, y = Mean, ymin = Lower, ymax = Upper)) + + geom_pointrange() + + labs(y = parName, x = ylab) + + coord_flip(ylim = limits) + + theme_bw(base_size = 20) + theme(panel.grid.minor.x = element_blank()) + + return(p) +} + +get_index <- function(pt, data_stan) { + # Associate (Patient, Week) pairs to corresponding index in the model + # + # Args: + # pt: vector of patients ID (same order as the patient parameters in stanfit) + # data_stan: data input to the stan function + # + # Returns: + # Dataframe + + mult <- with(data_stan, (N_obs + N_mis)/N_pt) + max_week <- (mult - 1) * 2 + out <- data.frame(Patient = pt[rep(1:data_stan$N_pt, each = mult)], + Week = rep(seq(0, max_week, 2), length(pt))) + out$Patient <- as.character(out$Patient) + out$Index <- 1:nrow(out) + return(out) +} + +extract_parameters <- function(fit, param, param_ind, param_obs, pt, data_stan) { + # Extract parameters' summary + # + # Args: + # fit: stanfit object + # param: parameters to extract + # param_ind: individual parameters in param + # param_obs: observation parameters in param + # pt: vector of patients ID (same order as the patient parameters in stanfit) + # data_stan: data input to the stan function + # + # Returns: + # Dataframe containing posterior summary statistics of the parameters + + par <- HuraultMisc::summary_statistics(fit, param) + par$Patient <- NA + par$Week <- NA + + pt <- as.character(pt) + + ## Patient-dependent parameter + for (i in intersect(param_ind, param)) { + idx <- which(par$Variable == i) + par$Patient[idx] <- pt[par$Index[idx]] + } + + ## Patient and time-dependent parameter (observation parameters) + dict <- get_index(pt, data_stan) + for (i in intersect(param_obs, param)) { + idx <- sort(which(par$Variable == i)) + par[idx, c("Patient", "Week")] <- dict[, c("Patient", "Week")] + } + + ## Missing score + for (i in intersect("S_mis", param)) { + idx <- which(par$Variable == i) + id_mis <- data_stan$idx_mis[par$Index[idx]] + par[idx, c("Patient", "Week")] <- dict[id_mis, c("Patient", "Week")] + } + + # par$Index <- NULL + return(par) +} + +PPC_fanchart <- function(ssi, df = NULL, patient_id, max_score = NULL) { + # PPC plot with stacked prediction intervals (fan chart) centered around the median + # + # Args: + # ssi: Dataframe summarising predictive distribution as credible intervals (with columns: Lower, Upper, Level, Patient, Day) + # df: dataframe of observed trajectory (can be NULL, in that case the actual trajectory is not overlapped) + # patient_id: patient ID + # max_score: maximum value that the measure can take (for plotting) + # + # Returns: + # Ggplot + + library(ggplot2) + + stopifnot(is.data.frame(ssi), + all(c("Lower", "Upper", "Level", "Patient", "Week") %in% colnames(ssi)), + patient_id %in% unique(ssi[["Patient"]]), + is.null(df) || is.data.frame(df)) + + if (is.data.frame(df)) { + stopifnot(all(c("Patient", "Week", "y") %in% colnames(df)), + patient_id %in% unique(df[["Patient"]])) + } + if (!is.null(max_score)) { + stopifnot(is.numeric(max_score), + max_score > 0) + } + + lvl <- sort(unique(ssi[["Level"]]), decreasing = TRUE) + + p <- ggplot() + # Prediction intervals (cf. fill cannot be an aesthetic with a ribbon) + for (i in 1:length(lvl)) { + p <- p + geom_ribbon(data = subset(ssi, Patient == patient_id & Level == lvl[i]), + aes(x = Week, ymin = Lower, ymax = Upper, fill = Level)) + } + # Actual trajectory + if (is.data.frame(df)) { + sub_df <- subset(df, Patient == patient_id) + if ("Validation" %in% colnames(df)) { + p <- p + + geom_point(data = sub_df, aes(x = Week, y = y, colour = Validation), size = 1) + } else { + p <- p + + geom_point(data = sub_df, aes(x = Week, y = y), size = 1) + } + } + # Formatting + p <- p + + scale_x_continuous(expand = expansion(mult = .01)) + + scale_fill_gradientn(colours = rev(c("#FFFFFF", RColorBrewer::brewer.pal(n = 6, "Blues")))[-1], + limits = c(0, 1), breaks = c(.1, .5, .9)) + # seq(0, 1, 0.25) + scale_colour_manual(values = c("#000000", "#E69F00")) + + labs(fill = "Confidence level", colour = "") + + theme_classic(base_size = 15) + + if (!is.null(max_score)) { + p <- p + + scale_y_continuous(limits = c(0, max_score), + breaks = c(round(seq(0, max_score, length.out = 5), -1)[-5], max_score), + expand = c(0, 0.01 * max_score)) + } + + return(p) +} diff --git a/run_validation.R b/run_validation.R new file mode 100644 index 0000000..4a2e508 --- /dev/null +++ b/run_validation.R @@ -0,0 +1,202 @@ +# Notes ------------------------------------------------------------------- + +# We implement a mixture of K-fold cross-validation (leave N patients out) and forward chaining +# 1/k patients are used for testing and 1-1/k patients for training +# The model is trained on the complete data of training patients and the data up to week w = TrainingWeek (included) for testing patients +# The model is tested on the data after week w for testing patients +# This process is repeated for different w (i.e. we provide more data for the testing patients) +# And this process is repeated for different subsets of training and testing patients + +# If k = 1, there is no cross-validation, just forward chaining (predict N step ahead) +# In that case, w cannot be 0! +# If TestingWeek(i) = TrainingWeek(i + 1) then this is prediction one step ahead + +# If k > 0 and w = TrainingWeek = 0, this is "predict given initial point" + +# Initialisation ---------------------------------------------------------- + +rm(list = ls()) # Clear Workspace (but better to restart session) + +library(HuraultMisc) # Functions shared across projects +library(tidyverse) +library(cowplot) +library(rstan) +rstan_options(auto_write = TRUE) # Save compiled model +options(mc.cores = parallel::detectCores()) # Parallel computing +library(foreach) +library(doParallel) +source("functions.R") # Additional functions + +seed <- 462528635 # seed also used for stan +set.seed(seed) + +#### OPTIONS +score <- "EASI" +model_name <- "SSM" +k <- 7 # Number of folds for k-fold cross-validation, set to 1 if you don't want k-fold +run <- FALSE +n_it <- 2000 +n_chains <- 4 +n_cluster <- 7 +#### + +score <- match.arg(score, c("EASI", "SCORAD", "oSCORAD", "POEM")) +model_name <- match.arg(model_name, c("Uniform", "RW", "AR", "MixedAR", "SSM", "SSMX")) + +stan_code <- file.path("Models", paste0(model_name, ".stan")) +res_file <- file.path("Results", paste0("val_", score, "_", model_name, ".rds")) +res_dir <- file.path("Results", paste0("val_", score, "_", model_name)) # temporary directory + +param <- c("S_pred", "lpd") + +if (run & model_name != "Uniform") { + compiled_model <- rstan::stan_model(stan_code) +} + +score_char <- data.frame(Score = c("SCORAD", "oSCORAD", "EASI", "POEM"), + Range = c(103, 83, 72, 28), + MCID = c(8.7, 8.2, 6.6, 3.4)) %>% + filter(Score == score) + +# Data -------------------------------------------------------------------- + +l <- load_dataset() +dp <- l$patient_data +dt <- l$severity_data +pt <- unique(dt[["Patient"]]) +bio <- as.matrix(dp[, colnames(dp) != "Patient"]) # matrix of biomarkers (including treatment, age, sex...) + +# Validation ------------------------------------------------------------------- + +stopifnot(k == round(k), + k > 0 & k < length(pt)) +weeks <- c(0, 2, 4, 8, 12, 24) +if (k > 1) { + folds <- sample(cut(1:length(pt), breaks = k, labels = FALSE)) # K-fold + it <- expand.grid(Fold = 1:k, TrainingWeek = weeks[-length(weeks)]) # we can train from week 0 since we test only a subset of patients +} else { + folds <- rep(1, length(pt)) + it <- expand.grid(Fold = 1:k, TrainingWeek = weeks[c(-1, -length(weeks))]) +} + +format_data <- function(df, score, idx_test) { + list( + N_obs = sum(!is.na(df[, score])), + N_mis = sum(is.na(df[, score])), + N_pt = length(unique(df$Patient)), + max_score = score_char$Range, + idx_obs = which(!is.na(df[, score])), + idx_mis = which(is.na(df[, score])), + S_obs = na.omit(df[, score]), + + N_test = length(idx_test), + idx_test = idx_test, + S_test = df[idx_test, "S"], + + # For horseshoe + p0 = 5, + slab_scale = 1, + slab_df = 5, + N_cov = ncol(bio), + X_cov = bio, + parametrisation = 0, + + run = 1, + rep = 1 + ) +} + +if (run) { + + duration <- Sys.time() + cl <- makeCluster(n_cluster) + registerDoParallel(cl) + + writeLines(c(""), "log.txt") + dir.create(res_dir) + + out <- foreach(i = 1:nrow(it)) %dopar% { + w <- it$TrainingWeek[i] + f <- it$Fold[i] + + library(tidyverse) + library(rstan) + rstan_options(auto_write = TRUE) # Save compiled model + options(mc.cores = parallel::detectCores()) # Parallel computing + source("functions.R") + + sink("log.txt", append = TRUE) + cat(paste0("Starting training at week ", w, ", fold ", f, " \n")) + + ### + + ## Prepare data + dt_wf <- dt + dt_wf$S <- dt_wf[, score] + dt_wf[, c("SCORAD", "oSCORAD", "EASI", "POEM", "ITCH", "SLEEP")] <- NULL + + dt_wf$S_train <- dt_wf$S + idx_pred <- which((dt_wf$Week %in% weeks) & !is.na(dt_wf$S) & (dt_wf$Week > w) & (dt_wf$Patient %in% pt[folds == f])) + dt_wf$S_train[idx_pred] <- NA + + data_stan <- format_data(dt_wf, "S_train", idx_pred) + + perf <- data.frame(Patient = dt_wf$Patient[idx_pred], + TrainingWeek = w, + TestingWeek = dt_wf$Week[idx_pred], + Fold = f, + S = dt_wf$S[idx_pred]) + + if (model_name == "Uniform") { + + perf <- perf %>% + mutate(Mean_pred = score_char$Range / 2, + lpd = -log(score_char$Range), + CRPS = scoringRules::crps_unif(perf[["S"]], min = 0, max = score_char$Range), + Samples = NA) + + } else { + ## Fit + fit <- sampling(compiled_model, + data = data_stan, + pars = param, + iter = n_it, + chains = n_chains, + seed = seed, + control = list(adapt_delta = case_when(model_name %in% c("SSM", "SSMX") ~ 0.99, + TRUE ~ 0.9))) + + ## Prepare ouput + lpd <- extract(fit, pars = "lpd")[[1]] + pred <- extract(fit, pars = "S_pred")[[1]] + smp <- sapply(1:ncol(pred), function(i) {list(pred[, i])}) + + perf <- perf %>% + mutate(Mean_pred = apply(pred, 2, mean), # cf. point prediction (mean) + lpd = apply(lpd, 2, function(x) {log(mean(exp(x)))}), # marginalise lpd + CRPS = scoringRules::crps_sample(perf[["S"]], t(pred)), + Samples = smp) + } + + ## Save (intermediate results) + saveRDS(perf, file = file.path(res_dir, paste0("val_", i, ".rds"))) + + cat(paste0("Ending training at week ", w, ", fold ", f, " \n")) + } + stopCluster(cl) + (duration = Sys.time() - duration) + + # Recombine results + files <- list.files(res_dir) + if (length(files) < nrow(it)) { + warning("Number of files (", length(files), ") less than the number of iterations (", max_it + 1, "). Some runs may have failed.") + } + res <- do.call(rbind, + lapply(files, + function(f) { + readRDS(file.path(res_dir, f)) + })) + saveRDS(res, file = res_file) +} else { + res <- readRDS(res_file) +} diff --git a/ssm-eczema-biomarkers.Rproj b/ssm-eczema-biomarkers.Rproj new file mode 100644 index 0000000..8e3c2eb --- /dev/null +++ b/ssm-eczema-biomarkers.Rproj @@ -0,0 +1,13 @@ +Version: 1.0 + +RestoreWorkspace: Default +SaveWorkspace: Default +AlwaysSaveHistory: Default + +EnableCodeIndexing: Yes +UseSpacesForTab: Yes +NumSpacesForTab: 2 +Encoding: UTF-8 + +RnwWeave: Sweave +LaTeX: pdfLaTeX