Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewGhazi committed Jul 4, 2024
1 parent 5dc448f commit 3f0f168
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ inst/stan/**/*.exe
inst/stan/**/*.EXE
*.dll
dyingforacup.Rproj
R/setup_env.R
src/stan/gp_mod
52 changes: 39 additions & 13 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,17 @@ run_gp = function(dat, ..., max_grid_size = 2000,
param_ranges = create_ranges(), param_grid = NULL) {

check_df(dat)
cr_res = check_ranges(dat, param_ranges)
dat = cr_res[[1]]; param_ranges = cr_res[[2]]

cr_res = check_ranges(dat, param_ranges)
dat = cr_res[[1]]
param_ranges = cr_res[[2]]

x_grid = get_x_grid(max_grid_size, param_ranges, param_grid)
x_grid_cent = center_grid(x_grid, param_ranges)

centered_dat = center_dat(dat, param_ranges)

X = centered_dat |> get_vars("_cent", regex=TRUE) |> qM()
X = centered_dat |> get_vars("_cent", regex = TRUE) |> qM()

list(run_gp_model(X = X, y = dat$rating, X_pred = x_grid_cent, ...),
x_grid)
Expand All @@ -133,10 +135,17 @@ run_gp = function(dat, ..., max_grid_size = 2000,
#' @param ... arguments passed to cmdstanr's sample method
#' @param offset expected improvement hyperparameter. Higher values encourage more
#' exploration. Interpreted on the same scale as ratings.
#' @param lambda tradeoff between weighting posterior predictive variance and expected
#' improvement at grid values.
#' @details The acquisition function is \code{lambda*f_star_var + (1-lambda)*exp_imp}.
#' Higher values of lambda up-weight posterior predictive variance, leading to more
#' exploration over exploitation.
#'
#' @export
suggest_next = function(dat, ..., max_grid_size = 2000,
param_ranges = create_ranges(), param_grid = NULL,
offset = .25) {
offset = .25,
lambda = .01) {

run_res = run_gp(dat, ...)

Expand All @@ -145,32 +154,49 @@ suggest_next = function(dat, ..., max_grid_size = 2000,

obs_max = max(dat$rating)

minus_max = qM(gp_res |> get_vars("f_star", regex = TRUE)) - obs_max - offset
f_star_mat = qM(gp_res |> get_vars("f_star", regex = TRUE))

# expected improvement ----
minus_max = f_star_mat - obs_max - offset

w = 1*(minus_max > 0)

acq = minus_max * w
exp_imp = fmean(minus_max * w)

max_pred_dens = fmean(acq) |> which.max()
max_pred_dens = exp_imp |> which.max()

if (max_pred_dens == 1) cli::cli_warn("Selected the first grid point as maximum of the acquisition function. You may need to run the chains for longer or lower {.var offset}.")

# pred_g = x_grid[max_pred_dens,,drop=FALSE][,"gc"]

# acq_post = data.table(variable = colnames(acq),
# mean = acq |> colMeans(),
# i = 1:ncol(acq))
# exp_imp_post = data.table(variable = colnames(exp_imp),
# mean = exp_imp |> colMeans(),
# i = 1:ncol(exp_imp))
#
# post_range = acq_post$mean |> range()
# post_range = exp_imp_post$mean |> range()

# qDT(x_grid) |> mtt(i = 1:nrow(x_grid)) |>
# sbt(dplyr::near(gc, pred_g)) |>
# join(acq_post, on = "i", validate = "1:1") |>
# join(exp_imp_post, on = "i", validate = "1:1") |>
# ggplot(aes(tc, bc)) +
# geom_tile(aes(fill = mean)) +
# scale_fill_viridis_c(limits = post_range)

# posterior uncertainty ----
f_star_var = f_star_mat |> fvar()

# combined expected improvement and posterior uncertainty ----

combined_acq = lambda*f_star_var + (1-lambda)*exp_imp

acq_df = data.table(post_var = f_star_var,
exp_imp = exp_imp,
acq = combined_acq) |>
cbind(x_grid)

suggest = acq_df |> sbt(whichv(acq, fmax(acq)))

list(draws_df = gp_res,
x_grid = x_grid,
suggested = x_grid[max_pred_dens,] )
suggested = suggest )
}
11 changes: 10 additions & 1 deletion man/suggest_next.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 3f0f168

Please sign in to comment.