Skip to content

Commit

Permalink
allow variable brew parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewGhazi committed Jun 30, 2024
1 parent 55d2860 commit dd0f0d1
Show file tree
Hide file tree
Showing 9 changed files with 255 additions and 52 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

export(create_ranges)
export(run_gp)
export(suggest_next)
import(collapse)
Expand Down
38 changes: 37 additions & 1 deletion R/check_input.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
get_params = function(dat) {
grep("rating", names(dat), value = TRUE, invert = TRUE)
}

all_within = function(param_vec, range_vec) {
all((param_vec >= range_vec[1]) & (param_vec <= range_vec[2]))
}


#' Check input data.frame
#' @description
#' The input data frame should have a limited number of columns and at least two rows
#'
#' @param dat data frame input
#' @param call calling environment
check_df = function(dat, call = rlang::caller_env()) {
cn = colnames(dat)

Expand All @@ -11,3 +21,29 @@ check_df = function(dat, call = rlang::caller_env()) {

# if (nrow(dat) < 2) cli::cli_abort("Input needs at least two existing observations.", call = call)
}

check_param_olap = function(dat, param_ranges) {
dat_params = get_params(dat)
rng_params = get_params(param_ranges)

param_int = intersect(dat_params, rng_params)

all_in_int = all(dat_params %in% param_int) & all(rng_params %in% param_int)

if (!all_in_int) cli::cli_warn("Non-overlapping columns between data and parameter ranges will be dropped.")

}

check_ranges = function(dat, param_ranges, call = rlang::caller_env()) {
params = get_params(dat)

check_param_olap(dat, param_ranges)

within_ranges = mapply(all_within,
dat |> get_vars(params), param_ranges |> get_vars(params))

if (!all(within_ranges)) cli::cli_abort("Provided parameter values fall outside the specified ranges. Those with values outside the provided ranges are {.val {names(within_ranges[!within_ranges])}}",
call = call)

list(dat |> get_vars(c(params, "rating")), param_ranges |> get_vars(params))
}
178 changes: 134 additions & 44 deletions R/run.R
Original file line number Diff line number Diff line change
@@ -1,41 +1,143 @@
get_grid_vec = function(param_id, param_range) {
if (param_id == "grinder_setting") {
res = seq(param_range[1], param_range[2], by = .5)
} else if (param_id == "temp") {
res = seq(param_range[1], param_range[2], by = 5)
} else if (param_id == "bloom_time") {
res = seq(param_range[1], param_range[2], by = 10)
} else {
res = seq(param_range[1], param_range[2], length.out = 6)
}

res
}

form_x_grid = function(max_grid_size,
param_ranges) {

params = get_params(param_ranges)

vec_list = mapply(get_grid_vec,
params, param_ranges,
SIMPLIFY = FALSE)

res = expand.grid(vec_list) |> qDT()

if (nrow(res) > max_grid_size) cli::cli_abort("Automated grid exceeded the specified {.var max_grid_size}. Either provide your own grid or increase {.var max_grid_size}")

res
}

get_x_grid = function(max_grid_size,
param_ranges, param_grid) {

if (!is.null(param_grid)) {
x_grid = param_grid
} else {
x_grid = form_x_grid(max_grid_size,
param_ranges)
}

x_grid
}


#' Create a range data frame
#' @description This function creates an example data frame of mins and maxs for brew
#' parameter settings. That is, the range of grinder settings I want to search is from 4
#' to 14, temperatures from 170 to 210F, and bloom times from 0 to 60s.
#'
#' @export
run_gp = function(dat, ...) {
create_ranges = function() {
data.frame(grinder_setting = c(4,14),
temp = c(170, 210),
bloom_time = c(0, 60))
}

get_centers_and_widths = function(param_ranges) {
centers = param_ranges |> sapply(fmean)
widths = (param_ranges |> sapply(diff)) / 2
list(centers, widths)
}

center_grid = function(x_grid, param_ranges) {
cents_widths = get_centers_and_widths(param_ranges)

check_df(dat)
centers = cents_widths[[1]]
widths = cents_widths[[2]]

res = x_grid |>
TRA(centers) |>
TRA(widths, FUN = "/") |>
TRA(rep(3, ncol(x_grid)), FUN = "*")

names(res) = paste0(names(res), "_cent")
res
}

center_dat = function(dat, param_ranges) {
cents_widths = get_centers_and_widths(param_ranges)

# TODO adapt centering/scaling, generalize to arbitrary # of parameters
dat = dat |>
mtt(gs_cent = (grinder_setting - 9) / 5 * 3,
temp_cent = (temp - 190) / (20) * 3,
bloom_cent = (bloom_time - 30) / 30 * 3) |>
qDT()
centers = cents_widths[[1]]
widths = cents_widths[[2]]

g_map = data.table(g = seq(4,14, by = .5)) |>
mtt(gc = (g - 9) / 5 * 3)
params = get_params(dat)

res = dat |>
get_vars(params) |>
TRA(centers) |>
TRA(widths, FUN = "/") |>
TRA(rep(3, length(params)), FUN = "*")

names(res) = paste0(names(res), "_cent")

res |>
add_vars(dat$rating)
}

#' Run the GP
#' @param dat data frame input of brew parameters and rating
#' @param ... arguments passed to cmdstanr's sample method
#' @param max_grid_size maximum number of grid points to evaluate
#' @param param_ranges upper and lower limits of parameter ranges to evaluate
#' @details
#' The function \code{\link{create_ranges()}} will create an example range df.
#'
#' @export
run_gp = function(dat, ..., max_grid_size = 2000,
param_ranges = create_ranges(), param_grid = NULL) {

t_map = data.table(t = seq(170, 210, by = 5),
tc = (seq(170, 210, by = 5) - 190) / 20 * 3)
check_df(dat)
cr_res = check_ranges(dat, param_ranges)
dat = cr_res[[1]]; param_ranges = cr_res[[2]]

b_map = data.table(b = seq(0, 60, by = 10),
bc = ((seq(0, 60, by = 10) - 30) / 30) * 3 )
x_grid = get_x_grid(max_grid_size, param_ranges, param_grid)
x_grid_cent = center_grid(x_grid, param_ranges)

x_grid = expand.grid(gc = g_map$gc,
tc = t_map$tc,
bc = b_map$bc) |>
qM()
centered_dat = center_dat(dat, param_ranges)

X = dat |> slt(gs_cent, temp_cent, bloom_cent) |> qM()
X = centered_dat |> get_vars("_cent", regex=TRUE) |> qM()

list(run_gp_model(X, dat$rating, x_grid, ...),
x_grid)
list(run_gp_model(X = X, y = dat$rating, X_pred = x_grid_cent, ...),
x_grid,
x_grid_cent)
}

#' Suggest the next point to try
#' @inheritParams run_gp
#' @param ... arguments passed to cmdstanr's sample method
#' @param offset expected improvement hyperparameter. Higher values encourage more
#' exploration. Interpreted on the same scale as ratings.
#' @export
suggest_next = function(dat, x_grid, ...) {
suggest_next = function(dat, ..., max_grid_size = 2000,
param_ranges = create_ranges(), param_grid = NULL,
offset = .25) {

run_res = run_gp(dat, ...)

gp_res = run_res[[1]]
x_grid = run_res[[2]]
x_grid_cent = run_res[[3]]

obs_max = max(dat$rating)

Expand All @@ -48,15 +150,15 @@ suggest_next = function(dat, x_grid, ...) {

max_pred_dens = fsum(acq) |> which.max()

if (max_pred_dens == 1) cli::cli_warn("Selected the first grid point as maximum of the acquisition function. You may need to run the chains for longer.")

pred_g = x_grid[max_pred_dens,,drop=FALSE][,"gc"]
if (max_pred_dens == 1) cli::cli_warn("Selected the first grid point as maximum of the acquisition function. You may need to run the chains for longer or lower {.var offset}.")

acq_post = data.table(variable = colnames(acq),
mean = acq |> colMeans(),
i = 1:ncol(acq))
# pred_g = x_grid[max_pred_dens,,drop=FALSE][,"gc"]

post_range = acq_post$mean |> range()
# acq_post = data.table(variable = colnames(acq),
# mean = acq |> colMeans(),
# i = 1:ncol(acq))
#
# post_range = acq_post$mean |> range()

# qDT(x_grid) |> mtt(i = 1:nrow(x_grid)) |>
# sbt(dplyr::near(gc, pred_g)) |>
Expand All @@ -65,19 +167,7 @@ suggest_next = function(dat, x_grid, ...) {
# geom_tile(aes(fill = mean)) +
# scale_fill_viridis_c(limits = post_range)

g_map = data.table(g = seq(4,14, by = .5)) |>
mtt(gc = (g - 9) / 5 * 3)

t_map = data.table(t = seq(170, 210, by = 5),
tc = (seq(170, 210, by = 5) - 190) / 20 * 3)

b_map = data.table(b = seq(0, 60, by = 10),
bc = ((seq(0, 60, by = 10) - 30) / 30) * 3 )

x_grid[max_pred_dens,,drop=FALSE] |>
qDT() |>
join(g_map, verbose = FALSE) |>
join(t_map, verbose = FALSE) |>
join(b_map, verbose = FALSE)

list(draws_df = gp_res,
x_grid = x_grid,
suggested = x_grid[max_pred_dens,] )
}
12 changes: 6 additions & 6 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@ Give the `suggest_next()` function a data frame of brew parameters with ratings

```{r eval=FALSE}
library(dyingforacup)
options(mc.cores = 4)
dat = data.frame(grinder_setting = c(8, 193, 25),
temp = c(7, 195, 20),
bloom_time = c(9, 179, 45),
rating = c(1.1, -.7, -1))
dat = data.frame(grinder_setting = c( 8, 7, 9),
temp = c(193, 195, 179),
bloom_time = c( 25, 20, 45),
rating = c(1.1, -0.7, -1))
suggest_next(dat,
iter_sampling = 4000,
refresh = 1250,
refresh = 0,
show_exceptions = FALSE,
adapt_delta = .95,
parallel_chains = 4)
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ probability of improving the rating.
``` r
library(dyingforacup)


dat = data.frame(grinder_setting = c(8, 193, 25),
temp = c(7, 195, 20),
bloom_time = c(9, 179, 45),
Expand Down
5 changes: 5 additions & 0 deletions man/check_df.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions man/create_ranges.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions man/run_gp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions man/suggest_next.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit dd0f0d1

Please sign in to comment.