Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewGhazi committed Jul 6, 2024
1 parent 843aba1 commit 10685f5
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 21 deletions.
2 changes: 1 addition & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ run_gp_model = function(X, y, X_pred, ..., verbose) {
...)

fit$draws(format = "data.frame",
variables = c("alpha", "rho", "sigma", "f_star"))
variables = c("alpha", "rho", "sigma", "f_mean", "f_star"))
}
25 changes: 16 additions & 9 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,11 @@ run_gp = function(dat, ..., max_grid_size = 2000,

check_df(dat)

cr_res = check_ranges(dat, param_ranges)
dat = cr_res[[1]]
param_ranges = cr_res[[2]]
if (!is.null(param_grid)) {
cr_res = check_ranges(dat, param_ranges)
dat = cr_res[[1]]
param_ranges = cr_res[[2]]
}

x_grid = get_x_grid(max_grid_size, param_ranges, param_grid)
x_grid_cent = center_grid(x_grid, param_ranges)
Expand All @@ -139,7 +141,7 @@ run_gp = function(dat, ..., max_grid_size = 2000,
#' improvement at grid values.
#' @details The acquisition function is \code{lambda*f_mean_sd + (1-lambda)*exp_imp}.
#' Higher values of lambda up-weight posterior uncertainty (the standard deviation of the f_mean draws), leading to more
#' exploration over exploitation. Lower lambda values up-weight expected improvement over \code{max(dat$rating) - off}
#' exploration over exploitation. Lower lambda values up-weight expected improvement over \code{max(dat$rating) - offset}.
#' @returns a list with elements:
#' \itemize{
#' \item{draws_df}{a draws data frame of model parameters, grid point GP mean draws f_mean, and grid point predictive draws f_star}
Expand All @@ -153,14 +155,19 @@ suggest_next = function(dat, ..., max_grid_size = 2000,
offset = .25,
lambda = .01) {

run_res = run_gp(dat, ...)
run_res = run_gp(dat,
max_grid_size = max_grid_size,
param_ranges = param_ranges,
param_grid = param_grid,
...)

gp_res = run_res[[1]]
x_grid = run_res[[2]]

obs_max = max(dat$rating)

f_star_mat = qM(gp_res |> get_vars("f_star", regex = TRUE))
f_mean_mat = qM(gp_res |> get_vars("f_mean", regex = TRUE))

# expected improvement ----
minus_max = f_star_mat - (obs_max - offset)
Expand All @@ -171,7 +178,7 @@ suggest_next = function(dat, ..., max_grid_size = 2000,

max_pred_dens = exp_imp |> which.max()

if (max_pred_dens == 1) cli::cli_warn("Selected the first grid point as maximum of the acquisition function. You may need to run the chains for longer or lower {.var offset}.")
if (all(exp_imp < .Machine$double.eps^0.5)) cli::cli_warn("All expected improvement values near zero. You may need to run the chains for longer or raise {.var offset}.")

# pred_g = x_grid[max_pred_dens,,drop=FALSE][,"gc"]

Expand All @@ -189,13 +196,13 @@ suggest_next = function(dat, ..., max_grid_size = 2000,
# scale_fill_viridis_c(limits = post_range)

# posterior uncertainty ----
f_star_var = f_star_mat |> fvar()
f_mean_sd = f_mean_mat |> fsd()

# combined expected improvement and posterior uncertainty ----

combined_acq = lambda*f_star_var + (1-lambda)*exp_imp
combined_acq = lambda*f_mean_sd + (1-lambda)*exp_imp

acq_df = data.table(post_var = f_star_var,
acq_df = data.table(post_sd = f_mean_sd,
exp_imp = exp_imp,
acq = combined_acq) |>
cbind(x_grid)
Expand Down
3 changes: 2 additions & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Easy:
Medium:

* Non-normal outcome
* Fast GP approximations for 1D/2D datasets
* Fast GP approximations for 1D/2D datasets with [`gptools`](https://github.com/onnela-lab/gptools/tree/main)

Hard:

Expand All @@ -81,3 +81,4 @@ Nightmare:

* Fast GP approximations for 3D+
* I think this would require writing my own ND FFT function?
* Refactor to use INLA (preferably from scratch over `R-INLA`)
2 changes: 1 addition & 1 deletion man/suggest_next.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 10 additions & 9 deletions src/stan/gp_mod.stan
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ transformed data {

parameters {
real<lower=0> rho; // length scale
real<lower=0> alpha; // wiggliness
real<lower=0> alpha; // amplitude
real<lower=0> sigma; // I'm giving ratings with a known variance of 1
}

Expand All @@ -33,23 +33,24 @@ transformed parameters {
}

model {
rho ~ inv_gamma(3,3);
// rho ~ inv_gamma(3,3);
rho ~ std_normal();
// alpha ~ std_normal();
alpha ~ inv_gamma(3,3);
sigma ~ student_t(6, 0, .2);
alpha ~ inv_gamma(1.5,1.5);
sigma ~ student_t(3, 0, .2);

y ~ multi_normal_cholesky(mu, L_K);
}

generated quantities {
vector[N_pred] f_mean;
vector[N_pred] f_star;
{
matrix[N, N_pred] K_x_x_pred = gp_exp_quad_cov(x, x_pred, alpha, rho);
vector[N] K_div_y = mdivide_right_tri_low(mdivide_left_tri_low(L_K, y)', L_K)';
f_star = K_x_x_pred' * K_div_y;
}

for (i in 1:N_pred) {
f_star[i] += normal_rng(0, sigma);
f_mean = K_x_x_pred' * K_div_y;
for (i in 1:N_pred) {
f_star[i] = normal_rng(f_mean[i], sigma);
}
}
}

0 comments on commit 10685f5

Please sign in to comment.