From 441d754772457cbe7aefc045d09a5a2091993ff9 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 16 Apr 2024 14:20:18 +0200 Subject: [PATCH 01/34] feat: aggr function for inner tuning --- DESCRIPTION | 2 +- NAMESPACE | 2 ++ NEWS.md | 5 +++++ R/Domain.R | 18 ++++++++++++++---- R/ParamDbl.R | 4 ++-- R/ParamFct.R | 4 ++-- R/ParamInt.R | 4 ++-- R/ParamLgl.R | 4 ++-- R/ParamSet.R | 7 ++++++- R/ParamUty.R | 4 ++-- R/helper.R | 8 ++++++++ R/to_tune.R | 35 +++++++++++++++++++++++++++++++++++ man/Domain.Rd | 20 +++++++++++++++----- man/Sampler.Rd | 2 +- man/Sampler1D.Rd | 4 ++-- man/Sampler1DCateg.Rd | 6 +++--- man/Sampler1DNormal.Rd | 6 +++--- man/Sampler1DRfun.Rd | 6 +++--- man/Sampler1DUnif.Rd | 6 +++--- man/SamplerHierarchical.Rd | 6 +++--- man/SamplerJointIndep.Rd | 6 +++--- man/SamplerUnif.Rd | 6 +++--- man/in_tune.Rd | 19 +++++++++++++++++++ 23 files changed, 139 insertions(+), 45 deletions(-) create mode 100644 man/in_tune.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 6c79cb3a..621fe625 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -61,7 +61,7 @@ Config/testthat/edition: 3 Config/testthat/parallel: false NeedsCompilation: no Roxygen: list(markdown = TRUE, r6 = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 VignetteBuilder: knitr Collate: 'Condition.R' diff --git a/NAMESPACE b/NAMESPACE index b3eefa14..8863f291 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -49,6 +49,7 @@ S3method(format,Condition) S3method(print,Condition) S3method(print,Domain) S3method(print,FullTuneToken) +S3method(print,InnerTuneToken) S3method(print,ObjectTuneToken) S3method(print,RangeTuneToken) S3method(rd_info,ParamSet) @@ -85,6 +86,7 @@ export(generate_design_grid) export(generate_design_lhs) export(generate_design_random) export(generate_design_sobol) +export(in_tune) export(p_dbl) export(p_fct) export(p_int) diff --git a/NEWS.md b/NEWS.md index 79a119c3..1b2c7018 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +# dev + +* feat: added `aggr`(egation function) to `Domain` which can be used for inner +tuning. + # paradox 0.12.0 * Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes: * `ParamSet` now supports `extra_trafo` natively; it behaves like `.extra_trafo` of the `ps()` call. diff --git a/R/Domain.R b/R/Domain.R index 56b98a55..e9224a63 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -63,6 +63,10 @@ #' @param init (`any`)\cr #' Initial value. When this is given, then the corresponding entry in `ParamSet$values` is initialized with this #' value upon construction. +#' @param aggr (`function`)\cr +#' Function with one argument, which is a list of parameter values. +#' The function specifies how this list of parameter values is aggregated to form one parameter value. +#' This is used in the context of inner tuning. The default is to aggregate the values. #' #' @return A `Domain` object. #' @@ -134,7 +138,8 @@ Domain = function(cls, grouping, trafo = NULL, depends_expr = NULL, storage_type = "list", - init) { + init, + aggr = NULL) { assert_string(cls) assert_string(grouping) @@ -146,7 +151,11 @@ Domain = function(cls, grouping, if (length(special_vals) && !is.null(trafo)) stop("trafo and special_values can not both be given at the same time.") assert_character(tags, any.missing = FALSE, unique = TRUE) assert_function(trafo, null.ok = TRUE) + assert_function(aggr, null.ok = TRUE, nargs = 1L) + if (is.null(aggr) && "inner_tuning" %in% tags) { + aggr = default_aggr + } # depends may be an expression, but may also be quote() or expression() if (length(depends_expr) == 1) { @@ -168,9 +177,9 @@ Domain = function(cls, grouping, .tags = list(tags), .trafo = list(trafo), .requirements = list(parse_depends(depends_expr, parent.frame(2))), - .init_given = !missing(init), - .init = list(if (!missing(init)) init) + .init = list(if (!missing(init)) init), + .aggr = list(aggr) ) class(param) = c(cls, "Domain", class(param)) @@ -215,7 +224,8 @@ empty_domain = data.table(id = character(0), cls = character(0), grouping = char .trafo = list(), .requirements = list(), .init_given = logical(0), - .init = list() + .init = list(), + .aggr = list() ) domain_names = names(empty_domain) diff --git a/R/ParamDbl.R b/R/ParamDbl.R index 73000591..609ab14c 100644 --- a/R/ParamDbl.R +++ b/R/ParamDbl.R @@ -1,6 +1,6 @@ #' @rdname Domain #' @export -p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init) { +p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) { assert_number(tolerance, lower = 0) assert_number(lower) assert_number(upper) @@ -18,7 +18,7 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ } Domain(cls = "ParamDbl", grouping = "ParamDbl", lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = "numeric", - depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale") + depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale", aggr = aggr) } #' @export diff --git a/R/ParamFct.R b/R/ParamFct.R index cad930db..38e0b3b6 100644 --- a/R/ParamFct.R +++ b/R/ParamFct.R @@ -1,6 +1,6 @@ #' @rdname Domain #' @export -p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init) { +p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) { constargs = as.list(match.call()[-1]) levels = eval.parent(constargs$levels) if (!is.character(levels)) { @@ -22,7 +22,7 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact # group p_fct by levels, so the group can be checked in a vectorized fashion. # We escape '"' and '\' to '\"' and '\\', respectively. grouping = str_collapse(gsub("([\\\\\"])", "\\\\\\1", sort(real_levels)), quote = '"', sep = ",") - Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init) + Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, aggr = aggr) } #' @export diff --git a/R/ParamInt.R b/R/ParamInt.R index 495d4d86..4a273cac 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -1,7 +1,7 @@ #' @rdname Domain #' @export -p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init) { +p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) { assert_number(tolerance, lower = 0, upper = 0.5) # assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower) @@ -25,7 +25,7 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ Domain(cls = cls, grouping = cls, lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = storage_type, - depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale") + depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale", aggr = aggr) } #' @export diff --git a/R/ParamLgl.R b/R/ParamLgl.R index 123fe73b..9a0522ed 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -1,8 +1,8 @@ #' @rdname Domain #' @export -p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init) { +p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) { Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default, - tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init) + tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, aggr = aggr) } #' @export diff --git a/R/ParamSet.R b/R/ParamSet.R index e389c211..3c289f48 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -90,6 +90,10 @@ ParamSet = R6Class("ParamSet", private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id") } + if (".aggr" %in% names(paramtbl)) { + private$.aggrs = setkeyv(paramtbl[!map_lgl(.aggr, is.null), .(id, aggr = .aggr)], "id") + } + if (".requirements" %in% names(paramtbl)) { requirements = paramtbl$.requirements private$.params = paramtbl # self$add_dep needs this @@ -645,7 +649,7 @@ ParamSet = R6Class("ParamSet", if (nrow(deps)) { # add a nice extra charvec-col to the tab, which lists all parents-ids on = NULL dd = deps[, list(parents = list(unlist(on))), by = "id"] - d = merge(d, dd, on = "id", all.x = TRUE) + d = merge(d, dd, by = "id", all.x = TRUE) } v = named_list(d$id) # add values to last col of print-dt as list col v = insert_named(v, self$values) @@ -872,6 +876,7 @@ ParamSet = R6Class("ParamSet", .tags = data.table(id = character(0L), tag = character(0), key = "id"), .deps = data.table(id = character(0L), on = character(0L), cond = list()), .trafos = data.table(id = character(0L), trafo = list(), key = "id"), + .aggrs = data.table(id = character(0L), aggr = list(), key = "id"), get_tune_ps = function(values) { values = keep(values, inherits, "TuneToken") diff --git a/R/ParamUty.R b/R/ParamUty.R index 61cbd143..b487d7c6 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -1,7 +1,7 @@ #' @rdname Domain #' @export -p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init) { +p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL) { assert_function(custom_check, null.ok = TRUE) if (!is.null(custom_check)) { custom_check_result = custom_check(1) @@ -12,7 +12,7 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t } else { "NoDefault" } - Domain(cls = "ParamUty", grouping = "ParamUty", cargo = list(custom_check = custom_check, repr = repr), special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init) + Domain(cls = "ParamUty", grouping = "ParamUty", cargo = list(custom_check = custom_check, repr = repr), special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init, aggr = aggr) } #' @export diff --git a/R/helper.R b/R/helper.R index 53ca3f44..70268885 100644 --- a/R/helper.R +++ b/R/helper.R @@ -53,3 +53,11 @@ col_to_nl = function(dt, col = 1, idcol = 2) { names(data) = dt[[idcol]] data } + +default_aggr = function(x) { + if (!test_numeric(x[[1]], len = 1L)) { + stopf("Provide a custom aggregator for non-numeric and non-scalar parameters.") + } + ceiling(mean(unlist(x))) +} + diff --git a/R/to_tune.R b/R/to_tune.R index 9c41e6ea..a069d0fb 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -183,6 +183,27 @@ to_tune = function(...) { set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken")) } +#' @title Create an Inner Tuning Token +#' @description +#' Works just like [`to_tune()`], but marks the parameter for inner tuning. +#' See [`mlr3::Learner`] for more information. +#' @inheritParams to_tune +#' @param aggr (`function`)\cr +#' The aggregator function that determines how to aggregate a list of parameter values into one value. +#' a single parameter value. The default is to average them. +#' @export +in_tune = function(..., aggr = NULL) { + if (is.null(aggr)) { + aggr = default_aggr + } else { + test_function(aggr, nargs = 1L) + } + tt = to_tune(...) + tt$aggr = aggr + tt = set_class(tt, classes = c("InnerTuneToken", class(tt))) + return(tt) +} + #' @export print.FullTuneToken = function(x, ...) { catf("Tuning over:\n\n", @@ -201,6 +222,12 @@ print.ObjectTuneToken = function(x, ...) { print(x$content) } +#' @export +print.InnerTuneToken = function(x, ...) { + cat("Inner ") + NextMethod() +} + # tunetoken_to_ps: Convert a `TuneToken` to a `ParamSet` that tunes over this. # Needs the corresponding `Domain` to which the `TuneToken` refers, both to # get the range (e.g. if `to_tune()` was used) and to verify that the `TuneToken` @@ -212,6 +239,13 @@ tunetoken_to_ps = function(tt, param) { UseMethod("tunetoken_to_ps") } +tunetoken_to_ps.InnerTuneToken = function(tt, params) { + ps = NextMethod() + browser() + ps$tags = map(ps$tags, function(tags) union(tags, "inner_tune")) + return(ps) +} + tunetoken_to_ps.FullTuneToken = function(tt, param) { if (!domain_is_bounded(param)) { stopf("%s must give a range for unbounded parameter %s.", tt$call, param$id) @@ -224,6 +258,7 @@ tunetoken_to_ps.FullTuneToken = function(tt, param) { } } + tunetoken_to_ps.RangeTuneToken = function(tt, param) { if (!domain_is_number(param)) { stopf("%s for non-numeric param must have zero or one argument.", tt$call) diff --git a/man/Domain.Rd b/man/Domain.Rd index d0c25a03..553d4cd1 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -20,7 +20,8 @@ p_dbl( depends = NULL, trafo = NULL, logscale = FALSE, - init + init, + aggr = NULL ) p_fct( @@ -30,7 +31,8 @@ p_fct( tags = character(), depends = NULL, trafo = NULL, - init + init, + aggr = NULL ) p_int( @@ -43,7 +45,8 @@ p_int( depends = NULL, trafo = NULL, logscale = FALSE, - init + init, + aggr = NULL ) p_lgl( @@ -52,7 +55,8 @@ p_lgl( tags = character(), depends = NULL, trafo = NULL, - init + init, + aggr = NULL ) p_uty( @@ -63,7 +67,8 @@ p_uty( depends = NULL, trafo = NULL, repr = substitute(default), - init + init, + aggr = NULL ) } \arguments{ @@ -139,6 +144,11 @@ defining domains or hyperparameter ranges of learning algorithms, because these Initial value. When this is given, then the corresponding entry in \code{ParamSet$values} is initialized with this value upon construction.} +\item{aggr}{(\code{function})\cr +Function with one argument, which is a list of parameter values. +The function specifies how this list of parameter values is aggregated to form one parameter value. +This is used in the context of inner tuning. The default is to aggregate the values.} + \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr Allowed categorical values of the parameter. If this is not a \code{character}, then a \code{trafo} is generated that converts the names (if not given: \code{as.character()} of the values) of the \code{levels} argument to the values. diff --git a/man/Sampler.Rd b/man/Sampler.Rd index f66fc8d4..178c7a07 100644 --- a/man/Sampler.Rd +++ b/man/Sampler.Rd @@ -8,11 +8,11 @@ This is the abstract base class for sampling objects like \link{Sampler1D}, \lin } \seealso{ Other Sampler: +\code{\link{Sampler1D}}, \code{\link{Sampler1DCateg}}, \code{\link{Sampler1DNormal}}, \code{\link{Sampler1DRfun}}, \code{\link{Sampler1DUnif}}, -\code{\link{Sampler1D}}, \code{\link{SamplerHierarchical}}, \code{\link{SamplerJointIndep}}, \code{\link{SamplerUnif}} diff --git a/man/Sampler1D.Rd b/man/Sampler1D.Rd index b2e6a4d2..f1acc2bd 100644 --- a/man/Sampler1D.Rd +++ b/man/Sampler1D.Rd @@ -9,14 +9,14 @@ } \seealso{ Other Sampler: +\code{\link{Sampler}}, \code{\link{Sampler1DCateg}}, \code{\link{Sampler1DNormal}}, \code{\link{Sampler1DRfun}}, \code{\link{Sampler1DUnif}}, \code{\link{SamplerHierarchical}}, \code{\link{SamplerJointIndep}}, -\code{\link{SamplerUnif}}, -\code{\link{Sampler}} +\code{\link{SamplerUnif}} } \concept{Sampler} \section{Super class}{ diff --git a/man/Sampler1DCateg.Rd b/man/Sampler1DCateg.Rd index 8ebbf668..979ff08b 100644 --- a/man/Sampler1DCateg.Rd +++ b/man/Sampler1DCateg.Rd @@ -8,14 +8,14 @@ Sampling from a discrete distribution, for a \code{\link{ParamSet}} containing a } \seealso{ Other Sampler: +\code{\link{Sampler}}, +\code{\link{Sampler1D}}, \code{\link{Sampler1DNormal}}, \code{\link{Sampler1DRfun}}, \code{\link{Sampler1DUnif}}, -\code{\link{Sampler1D}}, \code{\link{SamplerHierarchical}}, \code{\link{SamplerJointIndep}}, -\code{\link{SamplerUnif}}, -\code{\link{Sampler}} +\code{\link{SamplerUnif}} } \concept{Sampler} \section{Super classes}{ diff --git a/man/Sampler1DNormal.Rd b/man/Sampler1DNormal.Rd index c043e228..cc5e2d1e 100644 --- a/man/Sampler1DNormal.Rd +++ b/man/Sampler1DNormal.Rd @@ -8,14 +8,14 @@ Normal sampling (potentially truncated) for \code{\link[=p_dbl]{p_dbl()}}. } \seealso{ Other Sampler: +\code{\link{Sampler}}, +\code{\link{Sampler1D}}, \code{\link{Sampler1DCateg}}, \code{\link{Sampler1DRfun}}, \code{\link{Sampler1DUnif}}, -\code{\link{Sampler1D}}, \code{\link{SamplerHierarchical}}, \code{\link{SamplerJointIndep}}, -\code{\link{SamplerUnif}}, -\code{\link{Sampler}} +\code{\link{SamplerUnif}} } \concept{Sampler} \section{Super classes}{ diff --git a/man/Sampler1DRfun.Rd b/man/Sampler1DRfun.Rd index dd479245..c2d65417 100644 --- a/man/Sampler1DRfun.Rd +++ b/man/Sampler1DRfun.Rd @@ -8,14 +8,14 @@ Arbitrary sampling from 1D RNG functions from R. } \seealso{ Other Sampler: +\code{\link{Sampler}}, +\code{\link{Sampler1D}}, \code{\link{Sampler1DCateg}}, \code{\link{Sampler1DNormal}}, \code{\link{Sampler1DUnif}}, -\code{\link{Sampler1D}}, \code{\link{SamplerHierarchical}}, \code{\link{SamplerJointIndep}}, -\code{\link{SamplerUnif}}, -\code{\link{Sampler}} +\code{\link{SamplerUnif}} } \concept{Sampler} \section{Super classes}{ diff --git a/man/Sampler1DUnif.Rd b/man/Sampler1DUnif.Rd index 0254bf28..5bb3b6c1 100644 --- a/man/Sampler1DUnif.Rd +++ b/man/Sampler1DUnif.Rd @@ -8,14 +8,14 @@ Uniform random sampler for arbitrary (bounded) parameters. } \seealso{ Other Sampler: +\code{\link{Sampler}}, +\code{\link{Sampler1D}}, \code{\link{Sampler1DCateg}}, \code{\link{Sampler1DNormal}}, \code{\link{Sampler1DRfun}}, -\code{\link{Sampler1D}}, \code{\link{SamplerHierarchical}}, \code{\link{SamplerJointIndep}}, -\code{\link{SamplerUnif}}, -\code{\link{Sampler}} +\code{\link{SamplerUnif}} } \concept{Sampler} \section{Super classes}{ diff --git a/man/SamplerHierarchical.Rd b/man/SamplerHierarchical.Rd index e0b6f9fd..9973825c 100644 --- a/man/SamplerHierarchical.Rd +++ b/man/SamplerHierarchical.Rd @@ -10,14 +10,14 @@ and if dependencies do not hold, values are set to \code{NA} in the resulting \c } \seealso{ Other Sampler: +\code{\link{Sampler}}, +\code{\link{Sampler1D}}, \code{\link{Sampler1DCateg}}, \code{\link{Sampler1DNormal}}, \code{\link{Sampler1DRfun}}, \code{\link{Sampler1DUnif}}, -\code{\link{Sampler1D}}, \code{\link{SamplerJointIndep}}, -\code{\link{SamplerUnif}}, -\code{\link{Sampler}} +\code{\link{SamplerUnif}} } \concept{Sampler} \section{Super class}{ diff --git a/man/SamplerJointIndep.Rd b/man/SamplerJointIndep.Rd index ee4a3349..b05a487f 100644 --- a/man/SamplerJointIndep.Rd +++ b/man/SamplerJointIndep.Rd @@ -8,14 +8,14 @@ Create joint, independent sampler out of multiple other samplers. } \seealso{ Other Sampler: +\code{\link{Sampler}}, +\code{\link{Sampler1D}}, \code{\link{Sampler1DCateg}}, \code{\link{Sampler1DNormal}}, \code{\link{Sampler1DRfun}}, \code{\link{Sampler1DUnif}}, -\code{\link{Sampler1D}}, \code{\link{SamplerHierarchical}}, -\code{\link{SamplerUnif}}, -\code{\link{Sampler}} +\code{\link{SamplerUnif}} } \concept{Sampler} \section{Super class}{ diff --git a/man/SamplerUnif.Rd b/man/SamplerUnif.Rd index 77a038b3..df6283d6 100644 --- a/man/SamplerUnif.Rd +++ b/man/SamplerUnif.Rd @@ -10,14 +10,14 @@ Hence, also works for \link{ParamSet}s sets with dependencies. } \seealso{ Other Sampler: +\code{\link{Sampler}}, +\code{\link{Sampler1D}}, \code{\link{Sampler1DCateg}}, \code{\link{Sampler1DNormal}}, \code{\link{Sampler1DRfun}}, \code{\link{Sampler1DUnif}}, -\code{\link{Sampler1D}}, \code{\link{SamplerHierarchical}}, -\code{\link{SamplerJointIndep}}, -\code{\link{Sampler}} +\code{\link{SamplerJointIndep}} } \concept{Sampler} \section{Super classes}{ diff --git a/man/in_tune.Rd b/man/in_tune.Rd new file mode 100644 index 00000000..080a8447 --- /dev/null +++ b/man/in_tune.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/to_tune.R +\name{in_tune} +\alias{in_tune} +\title{Create an Inner Tuning Token} +\usage{ +in_tune(..., aggr = NULL) +} +\arguments{ +\item{...}{if given, restricts the range to be tuning over, as described above.} + +\item{aggr}{(\code{function})\cr +The aggregator function that determines how to aggregate a list of parameter values into one value. +a single parameter value. The default is to average them.} +} +\description{ +Works just like \code{\link[=to_tune]{to_tune()}}, but marks the parameter for inner tuning. +See \code{\link[mlr3:Learner]{mlr3::Learner}} for more information. +} From 8fc9f5f79db805d0f8bc1dc4afec51d297d12814 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 16 Apr 2024 15:11:26 +0200 Subject: [PATCH 02/34] add aggr function to ParamSet, tests --- R/Domain.R | 2 +- R/ParamSet.R | 22 ++++++++++++++++++++++ R/to_tune.R | 7 +++---- man/Domain.Rd | 2 +- man/ParamSet.Rd | 23 +++++++++++++++++++++++ man/ParamSetCollection.Rd | 1 + man/in_tune.Rd | 4 ++-- tests/testthat/test_ParamSet.R | 19 +++++++++++++++++++ tests/testthat/test_domain.R | 13 +++++++++++++ 9 files changed, 85 insertions(+), 8 deletions(-) diff --git a/R/Domain.R b/R/Domain.R index e9224a63..07feae10 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -65,7 +65,7 @@ #' value upon construction. #' @param aggr (`function`)\cr #' Function with one argument, which is a list of parameter values. -#' The function specifies how this list of parameter values is aggregated to form one parameter value. +#' The function specifies how a list of parameter values is aggregated to form one parameter value. #' This is used in the context of inner tuning. The default is to aggregate the values. #' #' @return A `Domain` object. diff --git a/R/ParamSet.R b/R/ParamSet.R index 3c289f48..588091c9 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -259,6 +259,28 @@ ParamSet = R6Class("ParamSet", x }, + #' @description + #' + #' Aggregate parameter values according to the aggregation rules. + #' + #' @param x (named `list()` of `list()`s)\cr + #' The value(s) to be aggregated. Names are parameter values. + #' The aggregation function is selected accordingly for each parameter. + #' @return (named `list()`) + aggr = function(x) { + assert_list(x, types = "list") + assert_permutation(names(x), private$.aggrs$id) + if (!(length(unique(lengths(x))) == 1L)) { + stopf("The same number of values are required for each parameter") + } + if (nrow(private$.aggrs) && !length(x[[1L]])) { + stopf("More than one value is required to aggregate them") + } + imap(x, function(value, .id) { + aggr = private$.aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value) + }) + }, + #' @description #' \pkg{checkmate}-like test-function. Takes a named list. #' Return `FALSE` if the given `$constraint` is not satisfied, `TRUE` otherwise. diff --git a/R/to_tune.R b/R/to_tune.R index a069d0fb..4eb3955a 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -189,8 +189,8 @@ to_tune = function(...) { #' See [`mlr3::Learner`] for more information. #' @inheritParams to_tune #' @param aggr (`function`)\cr -#' The aggregator function that determines how to aggregate a list of parameter values into one value. -#' a single parameter value. The default is to average them. +#' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. +#' The default is to average the values and round them up. #' @export in_tune = function(..., aggr = NULL) { if (is.null(aggr)) { @@ -241,8 +241,7 @@ tunetoken_to_ps = function(tt, param) { tunetoken_to_ps.InnerTuneToken = function(tt, params) { ps = NextMethod() - browser() - ps$tags = map(ps$tags, function(tags) union(tags, "inner_tune")) + ps$tags = map(ps$tags, function(tags) union(tags, "inner_tuning")) return(ps) } diff --git a/man/Domain.Rd b/man/Domain.Rd index 553d4cd1..1af85d2d 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -146,7 +146,7 @@ value upon construction.} \item{aggr}{(\code{function})\cr Function with one argument, which is a list of parameter values. -The function specifies how this list of parameter values is aggregated to form one parameter value. +The function specifies how a list of parameter values is aggregated to form one parameter value. This is used in the context of inner tuning. The default is to aggregate the values.} \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index b74daebc..165905cd 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -172,6 +172,7 @@ Named with param IDs.} \item \href{#method-ParamSet-get_values}{\code{ParamSet$get_values()}} \item \href{#method-ParamSet-set_values}{\code{ParamSet$set_values()}} \item \href{#method-ParamSet-trafo}{\code{ParamSet$trafo()}} +\item \href{#method-ParamSet-aggr}{\code{ParamSet$aggr()}} \item \href{#method-ParamSet-test_constraint}{\code{ParamSet$test_constraint()}} \item \href{#method-ParamSet-test_constraint_dt}{\code{ParamSet$test_constraint_dt()}} \item \href{#method-ParamSet-check}{\code{ParamSet$check()}} @@ -340,6 +341,28 @@ In almost all cases, the default \code{param_set = self} should be used.} } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-aggr}{}}} +\subsection{Method \code{aggr()}}{ +Aggregate parameter values according to the aggregation rules. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ParamSet$aggr(x)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{x}}{(named \code{list()} of \code{list()}s)\cr +The value(s) to be aggregated. Names are parameter values. +The aggregation function is selected accordingly for each parameter.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +(named \code{list()}) +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ParamSet-test_constraint}{}}} \subsection{Method \code{test_constraint()}}{ diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd index 93e4c93e..3a863f06 100644 --- a/man/ParamSetCollection.Rd +++ b/man/ParamSetCollection.Rd @@ -79,6 +79,7 @@ This field provides direct references to the \code{\link{ParamSet}} objects.}
Inherited methods
  • paradox::ParamSet$add_dep()
  • +
  • paradox::ParamSet$aggr()
  • paradox::ParamSet$assert()
  • paradox::ParamSet$assert_dt()
  • paradox::ParamSet$check()
  • diff --git a/man/in_tune.Rd b/man/in_tune.Rd index 080a8447..b0f5be51 100644 --- a/man/in_tune.Rd +++ b/man/in_tune.Rd @@ -10,8 +10,8 @@ in_tune(..., aggr = NULL) \item{...}{if given, restricts the range to be tuning over, as described above.} \item{aggr}{(\code{function})\cr -The aggregator function that determines how to aggregate a list of parameter values into one value. -a single parameter value. The default is to average them.} +The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. +The default is to average the values and round them up.} } \description{ Works just like \code{\link[=to_tune]{to_tune()}}, but marks the parameter for inner tuning. diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index 260e5538..bb220235 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -431,3 +431,22 @@ test_that("set_values allows to unset parameters by setting them to NULL", { param_set$set_values(.values = list(a = NULL), .insert = FALSE) expect_identical(param_set$values, list(a = NULL)) }) + +test_that("aggr", { + param_set = ps( + a = p_uty(aggr = function(x) "a"), + b = p_fct(levels = c("a", "b"), aggr = function(x) "b"), + c = p_lgl(aggr = function(x) "c"), + d = p_int(aggr = function(x) "d"), + e = p_dbl(aggr = function(x) "e") + ) + expect_class(param_set, "ParamSet") + + vals = param_set$aggr(list(a = list(1), b = list(1), c = list(1), d = list(1), e = list(1))) + expect_equal(vals, list(a = "a", b = "b", c = "c", d = "d", e = "e")) + + expect_error(param_set$aggr(1), "list") + expect_error(param_set$aggr(list(1)), "list") + expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list())), "permutation") + expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "More than one") +}) diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index 1a29fe1d..f18d3c32 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -347,3 +347,16 @@ test_that("$extra_trafo flag works", { search_space = pps$search_space() expect_false(search_space$has_extra_trafo) }) + +test_that("in_tune", { + it = in_tune(1) + expect_class(it, "InnerTuneToken") + expect_function(it$aggr) + tt = to_tune(1) + expect_equal(it$content, tt$content) + expect_equal(it$aggr(list(1, 2)), 2L) + + it1 = in_tune(aggr = function(x) min(unlist(x))) + expect_equal(it1$aggr(list(1, 2)), 1) + expect_true("inner_tuning" %in% ps(a = p_dbl(1, 10))$set_values(a = in_tune())$search_space()$tags) +}) From cc5c828b6b6c5d1688b82e7b74d17d0dea789579 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 16 Apr 2024 18:56:12 +0200 Subject: [PATCH 03/34] ... --- NEWS.md | 3 +-- R/Design.R | 3 ++- R/Domain.R | 6 +----- R/ParamSet.R | 10 ++++++++-- R/ParamSetCollection.R | 5 +++++ R/to_tune.R | 21 ++++++++++++--------- man/Domain.Rd | 2 +- tests/testthat/test_ParamSet.R | 11 +++++++++++ tests/testthat/test_domain.R | 13 ++++++++----- tests/testthat/test_to_tune.R | 2 +- 10 files changed, 50 insertions(+), 26 deletions(-) diff --git a/NEWS.md b/NEWS.md index 1b2c7018..f3b1f9ab 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,6 @@ # dev -* feat: added `aggr`(egation function) to `Domain` which can be used for inner -tuning. +* feat: added support for `aggr`(egation function) which can be used for inner tuning. # paradox 0.12.0 * Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes: diff --git a/R/Design.R b/R/Design.R index 41a6b425..71106fb1 100644 --- a/R/Design.R +++ b/R/Design.R @@ -34,7 +34,8 @@ Design = R6Class("Design", # set fixed param vals to their constant values # FIXME: this might also be problematic for LHS # do we still create an LHS like this? - imap(param_set$values, function(v, n) set(data, j = n, value = v)) + + imap(param_set$values, function(v, n) {set(data, j = n, value = list(v))}) self$data = data if (param_set$has_deps) { private$set_deps_to_na() diff --git a/R/Domain.R b/R/Domain.R index 07feae10..652fa754 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -66,7 +66,7 @@ #' @param aggr (`function`)\cr #' Function with one argument, which is a list of parameter values. #' The function specifies how a list of parameter values is aggregated to form one parameter value. -#' This is used in the context of inner tuning. The default is to aggregate the values. +#' This is used in the context of inner tuning. The default is to aggregate the values and round up. #' #' @return A `Domain` object. #' @@ -153,10 +153,6 @@ Domain = function(cls, grouping, assert_function(trafo, null.ok = TRUE) assert_function(aggr, null.ok = TRUE, nargs = 1L) - if (is.null(aggr) && "inner_tuning" %in% tags) { - aggr = default_aggr - } - # depends may be an expression, but may also be quote() or expression() if (length(depends_expr) == 1) { depends_expr = eval(depends_expr, envir = parent.frame(2)) diff --git a/R/ParamSet.R b/R/ParamSet.R index 588091c9..e3ea6c48 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -97,7 +97,7 @@ ParamSet = R6Class("ParamSet", if (".requirements" %in% names(paramtbl)) { requirements = paramtbl$.requirements private$.params = paramtbl # self$add_dep needs this - for (row in seq_len(nrow(paramtbl))) { + for (row in seq_len(nrow(paramtbl))) { for (req in requirements[[row]]) { invoke(self$add_dep, id = paramtbl$id[[row]], allow_dangling_dependencies = allow_dangling_dependencies, .args = req) @@ -276,6 +276,7 @@ ParamSet = R6Class("ParamSet", if (nrow(private$.aggrs) && !length(x[[1L]])) { stopf("More than one value is required to aggregate them") } + imap(x, function(value, .id) { aggr = private$.aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value) }) @@ -528,6 +529,7 @@ ParamSet = R6Class("ParamSet", .trafo = private$.trafos[id, trafo], .requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps .init_given = id %in% names(vals), + .aggr = private$.aggrs[id, get("aggr")], .init = unname(vals[id])) ] @@ -562,6 +564,7 @@ ParamSet = R6Class("ParamSet", result$.__enclos_env__$private$.params = setindexv(private$.params[ids, on = "id"], c("id", "cls", "grouping")) result$.__enclos_env__$private$.trafos = setkeyv(private$.trafos[ids, on = "id", nomatch = NULL], "id") + result$.__enclos_env__$private$.aggrs = setkeyv(private$.aggrs[ids, on = "id", nomatch = NULL], "id") result$.__enclos_env__$private$.tags = setkeyv(private$.tags[ids, on = "id", nomatch = NULL], "id") result$assert_values = FALSE result$deps = deps[ids, on = "id", nomatch = NULL] @@ -589,6 +592,7 @@ ParamSet = R6Class("ParamSet", result$.__enclos_env__$private$.params = setindexv(private$.params[get_id, on = "id"], c("id", "cls", "grouping")) # setkeyv not strictly necessary since get_id is scalar, but we do it for consistency result$.__enclos_env__$private$.trafos = setkeyv(private$.trafos[get_id, on = "id", nomatch = NULL], "id") + result$.__enclos_env__$private$.aggrs = setkeyv(private$.aggrs[get_id, on = "id", nomatch = NULL], "id") result$.__enclos_env__$private$.tags = setkeyv(private$.tags[get_id, on = "id", nomatch = NULL], "id") result$assert_values = FALSE result$values = values[match(get_id, names(values), nomatch = 0)] @@ -740,6 +744,7 @@ ParamSet = R6Class("ParamSet", result = copy(private$.params) result[, .tags := list(self$tags)] result[private$.trafos, .trafo := list(trafo), on = "id"] + result[private$.aggrs, .aggr := list(aggr), on = "id"] result[self$deps, .requirements := transpose_list(.(on, cond)), on = "id"] vals = self$values result[, `:=`( @@ -904,13 +909,14 @@ ParamSet = R6Class("ParamSet", values = keep(values, inherits, "TuneToken") if (!length(values)) return(ParamSet$new()) params = map(names(values), function(pn) { - domain = private$.params[pn, on = "id"] + domain = self$params[pn, on = "id"] set_class(domain, c(domain$cls, "Domain", class(domain))) }) names(params) = names(values) # package-internal S3 fails if we don't call the function indirectly here partsets = pmap(list(values, params), function(...) tunetoken_to_ps(...)) + pars = ps_union(partsets) # partsets does not have names here, wihch is what we want. names(partsets) = names(values) diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 2d604245..fbbfe8b4 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -66,6 +66,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, private$.tags = paramtbl[, .(tag = unique(unlist(.tags))), keyby = "id"] private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id") + private$.aggrs = setkeyv(paramtbl[!map_lgl(.aggr, is.null), .(id, aggr = .aggr)], "id") private$.translation = paramtbl[, c("id", "original_id", "owner_ps_index", "owner_name"), with = FALSE] setkeyv(private$.translation, "id") @@ -125,6 +126,10 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, if (nrow(newtrafos)) { private$.trafos = setkeyv(rbind(private$.trafos, newtrafos), "id") } + newaggrs = paramtbl[!map_lgl(.aggr, is.null), .(id, trafo = .aggr)] + if (nrow(newaggrs)) { + private$.aggrs = setkeyv(rbind(private$.aggrs, newaggrs), "id") + } private$.translation = rbind(private$.translation, paramtbl[, c("id", "original_id", "owner_ps_index", "owner_name"), with = FALSE]) setkeyv(private$.translation, "id") diff --git a/R/to_tune.R b/R/to_tune.R index 4eb3955a..bf375d44 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -193,13 +193,9 @@ to_tune = function(...) { #' The default is to average the values and round them up. #' @export in_tune = function(..., aggr = NULL) { - if (is.null(aggr)) { - aggr = default_aggr - } else { - test_function(aggr, nargs = 1L) - } + test_function(aggr, nargs = 1L, null.ok = TRUE) tt = to_tune(...) - tt$aggr = aggr + if (!is.null(aggr)) tt$content$aggr = aggr tt = set_class(tt, classes = c("InnerTuneToken", class(tt))) return(tt) } @@ -239,7 +235,11 @@ tunetoken_to_ps = function(tt, param) { UseMethod("tunetoken_to_ps") } -tunetoken_to_ps.InnerTuneToken = function(tt, params) { +tunetoken_to_ps.InnerTuneToken = function(tt, param) { + tt$content$aggr = tt$content$aggr %??% param$.aggr + if (is.null(tt$content$aggr)) { + stopf("%s (%s): Provide an aggregation function for inner tuning.", tt$call, param$id) + } ps = NextMethod() ps$tags = map(ps$tags, function(tags) union(tags, "inner_tuning")) return(ps) @@ -251,7 +251,7 @@ tunetoken_to_ps.FullTuneToken = function(tt, param) { } if (isTRUE(tt$content$logscale)) { if (!domain_is_number(param)) stop("%s (%s): logscale only valid for numeric / integer parameters.", tt$call, param$id) - tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale), tt$call), param) + tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale, aggr = tt$content$aggr), tt$call), param) } else { pslike_to_ps(param, tt$call, param) } @@ -264,6 +264,7 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param) { } invalidpoints = discard(tt$content, function(x) is.null(x) || domain_test(param, set_names(list(x), param$id))) invalidpoints$logscale = NULL + invalidpoints$aggr = NULL if (length(invalidpoints)) { stopf("%s range not compatible with param %s.\nBad value(s):\n%s\nParameter:\n%s", tt$call, param$id, repr(invalidpoints), repr(param)) @@ -279,7 +280,9 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param) { # create p_int / p_dbl object. Doesn't work if there is a numeric param class that we don't know about :-/ constructor = switch(param$cls, ParamInt = p_int, ParamDbl = p_dbl, stopf("%s: logscale for parameter %s of class %s not supported", tt$call, param$id, param$class)) - content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale) + content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale, + aggr = tt$content$aggr) + pslike_to_ps(content, tt$call, param) } diff --git a/man/Domain.Rd b/man/Domain.Rd index 1af85d2d..d34e3f5a 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -147,7 +147,7 @@ value upon construction.} \item{aggr}{(\code{function})\cr Function with one argument, which is a list of parameter values. The function specifies how a list of parameter values is aggregated to form one parameter value. -This is used in the context of inner tuning. The default is to aggregate the values.} +This is used in the context of inner tuning. The default is to aggregate the values and round up.} \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr Allowed categorical values of the parameter. If this is not a \code{character}, then a \code{trafo} is generated that diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index bb220235..a20ddea0 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -450,3 +450,14 @@ test_that("aggr", { expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list())), "permutation") expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "More than one") }) + +test_that("in_tune", { + param_set = ps(a = p_dbl(lower = 1, upper = 2)) + param_set$set_values( + a = in_tune(lower = 1, upper = 2, aggr = function(x) 1.5) + ) + + ss = param_set$search_space() + + ss$aggr(list(a = list(1, 2))) +}) diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index f18d3c32..78f53a39 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -351,12 +351,15 @@ test_that("$extra_trafo flag works", { test_that("in_tune", { it = in_tune(1) expect_class(it, "InnerTuneToken") - expect_function(it$aggr) - tt = to_tune(1) + expect_null(it$aggr) + tt = in_tune(1) expect_equal(it$content, tt$content) - expect_equal(it$aggr(list(1, 2)), 2L) it1 = in_tune(aggr = function(x) min(unlist(x))) - expect_equal(it1$aggr(list(1, 2)), 1) - expect_true("inner_tuning" %in% ps(a = p_dbl(1, 10))$set_values(a = in_tune())$search_space()$tags) + expect_equal(it1$content$aggr(list(1, 2)), 1) + param_set = ps( + a = p_dbl(1, 10, aggr = default_aggr) + ) + param_set$set_values(a = in_tune()) + expect_class(param_set$values$a, "InnerTuneToken") }) diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index 339b989c..bd8e6576 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -396,5 +396,5 @@ test_that("logscale in tunetoken", { expect_output(print(to_tune(lower = 1, logscale = TRUE)), "range \\[1, \\.\\.\\.] \\(log scale\\)") expect_output(print(to_tune(upper = 1, logscale = TRUE)), "range \\[\\.\\.\\., 1] \\(log scale\\)") expect_output(print(to_tune(lower = 0, upper = 1, logscale = TRUE)), "range \\[0, 1] \\(log scale\\)") - + expect_output(print(in_tune()), "Inner") }) From 874cd7f629af3e1e6df06115a9b10c2264c93b3b Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 16 Apr 2024 19:21:18 +0200 Subject: [PATCH 04/34] fix bug --- R/ParamSet.R | 6 ++++-- R/to_tune.R | 16 ++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/R/ParamSet.R b/R/ParamSet.R index e3ea6c48..5cb64b5d 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -909,13 +909,15 @@ ParamSet = R6Class("ParamSet", values = keep(values, inherits, "TuneToken") if (!length(values)) return(ParamSet$new()) params = map(names(values), function(pn) { - domain = self$params[pn, on = "id"] + domain = private$.params[pn, on = "id"] set_class(domain, c(domain$cls, "Domain", class(domain))) }) names(params) = names(values) # package-internal S3 fails if we don't call the function indirectly here - partsets = pmap(list(values, params), function(...) tunetoken_to_ps(...)) + partsets = pmap(list(values, params), function(tt, param) { + tunetoken_to_ps(tt, param, param_aggr = private$.aggrs[list(param$id), "aggr", on = "id"][[1L]][[1L]]) + }) pars = ps_union(partsets) # partsets does not have names here, wihch is what we want. diff --git a/R/to_tune.R b/R/to_tune.R index bf375d44..442d7b9a 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -231,12 +231,12 @@ print.InnerTuneToken = function(x, ...) { # # Makes liberal use to `pslike_to_ps` (converting Param, ParamSet, Domain to ParamSet) # param is a data.table that is potentially modified by reference using data.table set() methods. -tunetoken_to_ps = function(tt, param) { +tunetoken_to_ps = function(tt, param, param_aggr) { UseMethod("tunetoken_to_ps") } -tunetoken_to_ps.InnerTuneToken = function(tt, param) { - tt$content$aggr = tt$content$aggr %??% param$.aggr +tunetoken_to_ps.InnerTuneToken = function(tt, param, param_aggr) { + tt$content$aggr = tt$content$aggr %??% param_aggr if (is.null(tt$content$aggr)) { stopf("%s (%s): Provide an aggregation function for inner tuning.", tt$call, param$id) } @@ -245,20 +245,20 @@ tunetoken_to_ps.InnerTuneToken = function(tt, param) { return(ps) } -tunetoken_to_ps.FullTuneToken = function(tt, param) { +tunetoken_to_ps.FullTuneToken = function(tt, param, param_aggr) { if (!domain_is_bounded(param)) { stopf("%s must give a range for unbounded parameter %s.", tt$call, param$id) } if (isTRUE(tt$content$logscale)) { if (!domain_is_number(param)) stop("%s (%s): logscale only valid for numeric / integer parameters.", tt$call, param$id) - tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale, aggr = tt$content$aggr), tt$call), param) + tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale, aggr = tt$content$aggr), tt$call), param, param_aggr) } else { - pslike_to_ps(param, tt$call, param) + pslike_to_ps(param, tt$call, param, param_aggr) } } -tunetoken_to_ps.RangeTuneToken = function(tt, param) { +tunetoken_to_ps.RangeTuneToken = function(tt, param, param_aggr) { if (!domain_is_number(param)) { stopf("%s for non-numeric param must have zero or one argument.", tt$call) } @@ -286,7 +286,7 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param) { pslike_to_ps(content, tt$call, param) } -tunetoken_to_ps.ObjectTuneToken = function(tt, param) { +tunetoken_to_ps.ObjectTuneToken = function(tt, param, param_aggr) { pslike_to_ps(tt$content, tt$call, param) } From d83abfb2a6567856adffb4d7b17fb458fa1d9fbe Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 16 Apr 2024 19:24:05 +0200 Subject: [PATCH 05/34] fix bug --- R/Design.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/Design.R b/R/Design.R index 71106fb1..af0348bf 100644 --- a/R/Design.R +++ b/R/Design.R @@ -35,7 +35,7 @@ Design = R6Class("Design", # FIXME: this might also be problematic for LHS # do we still create an LHS like this? - imap(param_set$values, function(v, n) {set(data, j = n, value = list(v))}) + imap(param_set$values, function(v, n) {set(data, j = n, value = v)}) self$data = data if (param_set$has_deps) { private$set_deps_to_na() From 5fd31c2f41c31ecc2a5cf51f28aae90d8a7505fd Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 17 Apr 2024 19:06:33 +0200 Subject: [PATCH 06/34] ... --- R/ParamSet.R | 4 +--- R/to_tune.R | 19 +++++++++++-------- tests/testthat/test_domain.R | 2 +- tests/testthat/test_to_tune.R | 5 +++++ 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/R/ParamSet.R b/R/ParamSet.R index 5cb64b5d..bd0b9889 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -915,9 +915,7 @@ ParamSet = R6Class("ParamSet", names(params) = names(values) # package-internal S3 fails if we don't call the function indirectly here - partsets = pmap(list(values, params), function(tt, param) { - tunetoken_to_ps(tt, param, param_aggr = private$.aggrs[list(param$id), "aggr", on = "id"][[1L]][[1L]]) - }) + partsets = pmap(list(values, params), function(...) tunetoken_to_ps(..., param_set = param_set)) pars = ps_union(partsets) # partsets does not have names here, wihch is what we want. diff --git a/R/to_tune.R b/R/to_tune.R index 442d7b9a..75ef24c9 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -231,12 +231,15 @@ print.InnerTuneToken = function(x, ...) { # # Makes liberal use to `pslike_to_ps` (converting Param, ParamSet, Domain to ParamSet) # param is a data.table that is potentially modified by reference using data.table set() methods. -tunetoken_to_ps = function(tt, param, param_aggr) { +tunetoken_to_ps = function(tt, param, param_set) { UseMethod("tunetoken_to_ps") } -tunetoken_to_ps.InnerTuneToken = function(tt, param, param_aggr) { - tt$content$aggr = tt$content$aggr %??% param_aggr +tunetoken_to_ps.InnerTuneToken = function(tt, param, param_set) { + tt$content$aggr = tt$content$aggr %??% get_private(param_set)$.aggrs[list(param$id), "aggr", on = "id"][[1L]][[1L]] + if ("inner_tuning" %nin% param_set$tags[[param$id]]) { + stopf("%s (%s): Parameter not eligible for inner tuning", tt$call, param$id) + } if (is.null(tt$content$aggr)) { stopf("%s (%s): Provide an aggregation function for inner tuning.", tt$call, param$id) } @@ -245,20 +248,20 @@ tunetoken_to_ps.InnerTuneToken = function(tt, param, param_aggr) { return(ps) } -tunetoken_to_ps.FullTuneToken = function(tt, param, param_aggr) { +tunetoken_to_ps.FullTuneToken = function(tt, param, param_set) { if (!domain_is_bounded(param)) { stopf("%s must give a range for unbounded parameter %s.", tt$call, param$id) } if (isTRUE(tt$content$logscale)) { if (!domain_is_number(param)) stop("%s (%s): logscale only valid for numeric / integer parameters.", tt$call, param$id) - tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale, aggr = tt$content$aggr), tt$call), param, param_aggr) + tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale, aggr = tt$content$aggr), tt$call), param, param_set) } else { - pslike_to_ps(param, tt$call, param, param_aggr) + pslike_to_ps(param, tt$call, param, param_set) } } -tunetoken_to_ps.RangeTuneToken = function(tt, param, param_aggr) { +tunetoken_to_ps.RangeTuneToken = function(tt, param, param_set) { if (!domain_is_number(param)) { stopf("%s for non-numeric param must have zero or one argument.", tt$call) } @@ -286,7 +289,7 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param, param_aggr) { pslike_to_ps(content, tt$call, param) } -tunetoken_to_ps.ObjectTuneToken = function(tt, param, param_aggr) { +tunetoken_to_ps.ObjectTuneToken = function(tt, param, param_set) { pslike_to_ps(tt$content, tt$call, param) } diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index 78f53a39..08c9f20e 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -358,7 +358,7 @@ test_that("in_tune", { it1 = in_tune(aggr = function(x) min(unlist(x))) expect_equal(it1$content$aggr(list(1, 2)), 1) param_set = ps( - a = p_dbl(1, 10, aggr = default_aggr) + a = p_dbl(1, 10, aggr = default_aggr, tags = "inner_tuning") ) param_set$set_values(a = in_tune()) expect_class(param_set$values$a, "InnerTuneToken") diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index bd8e6576..0e8ae92f 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -398,3 +398,8 @@ test_that("logscale in tunetoken", { expect_output(print(to_tune(lower = 0, upper = 1, logscale = TRUE)), "range \\[0, 1] \\(log scale\\)") expect_output(print(in_tune()), "Inner") }) + +test_that("inner tune", { + param_set = ps(a = p_int(1, 10)) + expect_error(param_set$set_values(a = in_tune(aggr = default_aggr)), "inner tuning") +}) From b6fee528981e5c702b4053fc62df49e7d34672d0 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 22 Apr 2024 15:46:55 +0200 Subject: [PATCH 07/34] aggr is now part of cargo --- NAMESPACE | 1 - R/Domain.R | 10 +++----- R/ParamDbl.R | 7 +++++- R/ParamFct.R | 3 ++- R/ParamInt.R | 7 +++++- R/ParamLgl.R | 3 ++- R/ParamSet.R | 26 +++++++------------- R/ParamSetCollection.R | 5 ---- R/ParamUty.R | 6 ++++- R/helper.R | 8 ------ R/to_tune.R | 45 ++++++++++++++++++---------------- man/ParamSet.Rd | 4 +-- man/in_tune.Rd | 19 -------------- man/to_tune.Rd | 14 ++++++++--- tests/testthat/test_ParamSet.R | 14 ++++++----- tests/testthat/test_domain.R | 12 ++++----- tests/testthat/test_to_tune.R | 7 +----- 17 files changed, 85 insertions(+), 106 deletions(-) delete mode 100644 man/in_tune.Rd diff --git a/NAMESPACE b/NAMESPACE index 8863f291..35a657b1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -86,7 +86,6 @@ export(generate_design_grid) export(generate_design_lhs) export(generate_design_random) export(generate_design_sobol) -export(in_tune) export(p_dbl) export(p_fct) export(p_int) diff --git a/R/Domain.R b/R/Domain.R index 652fa754..580abc1e 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -138,8 +138,7 @@ Domain = function(cls, grouping, trafo = NULL, depends_expr = NULL, storage_type = "list", - init, - aggr = NULL) { + init) { assert_string(cls) assert_string(grouping) @@ -151,7 +150,6 @@ Domain = function(cls, grouping, if (length(special_vals) && !is.null(trafo)) stop("trafo and special_values can not both be given at the same time.") assert_character(tags, any.missing = FALSE, unique = TRUE) assert_function(trafo, null.ok = TRUE) - assert_function(aggr, null.ok = TRUE, nargs = 1L) # depends may be an expression, but may also be quote() or expression() if (length(depends_expr) == 1) { @@ -174,8 +172,7 @@ Domain = function(cls, grouping, .trafo = list(trafo), .requirements = list(parse_depends(depends_expr, parent.frame(2))), .init_given = !missing(init), - .init = list(if (!missing(init)) init), - .aggr = list(aggr) + .init = list(if (!missing(init)) init) ) class(param) = c(cls, "Domain", class(param)) @@ -220,8 +217,7 @@ empty_domain = data.table(id = character(0), cls = character(0), grouping = char .trafo = list(), .requirements = list(), .init_given = logical(0), - .init = list(), - .aggr = list() + .init = list() ) domain_names = names(empty_domain) diff --git a/R/ParamDbl.R b/R/ParamDbl.R index 609ab14c..495da692 100644 --- a/R/ParamDbl.R +++ b/R/ParamDbl.R @@ -1,6 +1,7 @@ #' @rdname Domain #' @export p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) { + assert_function(aggr, null.ok = TRUE, nargs = 1L) assert_number(tolerance, lower = 0) assert_number(lower) assert_number(upper) @@ -17,8 +18,12 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ real_upper = upper } + cargo = list() + if (logscale) cargo$logscale = TRUE + cargo$aggr = aggr + Domain(cls = "ParamDbl", grouping = "ParamDbl", lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = "numeric", - depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale", aggr = aggr) + depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } #' @export diff --git a/R/ParamFct.R b/R/ParamFct.R index 38e0b3b6..13e986c4 100644 --- a/R/ParamFct.R +++ b/R/ParamFct.R @@ -1,6 +1,7 @@ #' @rdname Domain #' @export p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) { + assert_function(aggr, null.ok = TRUE, nargs = 1L) constargs = as.list(match.call()[-1]) levels = eval.parent(constargs$levels) if (!is.character(levels)) { @@ -22,7 +23,7 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact # group p_fct by levels, so the group can be checked in a vectorized fashion. # We escape '"' and '\' to '\"' and '\\', respectively. grouping = str_collapse(gsub("([\\\\\"])", "\\\\\\1", sort(real_levels)), quote = '"', sep = ",") - Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, aggr = aggr) + Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, cargo = if (!is.null(aggr)) list(aggr = aggr)) } #' @export diff --git a/R/ParamInt.R b/R/ParamInt.R index 4a273cac..41d06d8c 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -2,6 +2,7 @@ #' @rdname Domain #' @export p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) { + assert_function(aggr, null.ok = TRUE, nargs = 1L) assert_number(tolerance, lower = 0, upper = 0.5) # assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower) @@ -23,9 +24,13 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ real_upper = upper } + cargo = list() + if (logscale) cargo$logscale = TRUE + cargo$aggr = aggr + Domain(cls = cls, grouping = cls, lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = storage_type, - depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale", aggr = aggr) + depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } #' @export diff --git a/R/ParamLgl.R b/R/ParamLgl.R index 9a0522ed..b46d7cb5 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -1,8 +1,9 @@ #' @rdname Domain #' @export p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) { + assert_function(aggr, null.ok = TRUE, nargs = 1L) Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default, - tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, aggr = aggr) + tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = if (!is.null(aggr)) list(aggr = aggr)) } #' @export diff --git a/R/ParamSet.R b/R/ParamSet.R index bd0b9889..d3204894 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -28,7 +28,7 @@ #' - special_vals: list col of list #' - default: list col #' - storage_type: character -#' - tags: list col of character vectors +#' - tags: list col of character vectorssearch #' @examples #' pset = ParamSet$new( #' params = list( @@ -90,10 +90,6 @@ ParamSet = R6Class("ParamSet", private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id") } - if (".aggr" %in% names(paramtbl)) { - private$.aggrs = setkeyv(paramtbl[!map_lgl(.aggr, is.null), .(id, aggr = .aggr)], "id") - } - if (".requirements" %in% names(paramtbl)) { requirements = paramtbl$.requirements private$.params = paramtbl # self$add_dep needs this @@ -265,20 +261,21 @@ ParamSet = R6Class("ParamSet", #' #' @param x (named `list()` of `list()`s)\cr #' The value(s) to be aggregated. Names are parameter values. - #' The aggregation function is selected accordingly for each parameter. + #' The aggregation function is selected based on the parameter. #' @return (named `list()`) aggr = function(x) { assert_list(x, types = "list") - assert_permutation(names(x), private$.aggrs$id) + aggrs = private$.params[map_lgl(get("cargo"), function(cargo) is.function(cargo$aggr)), list(id = get("id"), aggr = map(get("cargo"), "aggr"))] + assert_permutation(names(x), aggrs$id) if (!(length(unique(lengths(x))) == 1L)) { stopf("The same number of values are required for each parameter") } - if (nrow(private$.aggrs) && !length(x[[1L]])) { - stopf("More than one value is required to aggregate them") + if (nrow(aggrs) && !length(x[[1L]])) { + stopf("At least one value is required to aggregate them") } imap(x, function(value, .id) { - aggr = private$.aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value) + aggr = aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value) }) }, @@ -529,7 +526,6 @@ ParamSet = R6Class("ParamSet", .trafo = private$.trafos[id, trafo], .requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps .init_given = id %in% names(vals), - .aggr = private$.aggrs[id, get("aggr")], .init = unname(vals[id])) ] @@ -564,7 +560,6 @@ ParamSet = R6Class("ParamSet", result$.__enclos_env__$private$.params = setindexv(private$.params[ids, on = "id"], c("id", "cls", "grouping")) result$.__enclos_env__$private$.trafos = setkeyv(private$.trafos[ids, on = "id", nomatch = NULL], "id") - result$.__enclos_env__$private$.aggrs = setkeyv(private$.aggrs[ids, on = "id", nomatch = NULL], "id") result$.__enclos_env__$private$.tags = setkeyv(private$.tags[ids, on = "id", nomatch = NULL], "id") result$assert_values = FALSE result$deps = deps[ids, on = "id", nomatch = NULL] @@ -592,7 +587,6 @@ ParamSet = R6Class("ParamSet", result$.__enclos_env__$private$.params = setindexv(private$.params[get_id, on = "id"], c("id", "cls", "grouping")) # setkeyv not strictly necessary since get_id is scalar, but we do it for consistency result$.__enclos_env__$private$.trafos = setkeyv(private$.trafos[get_id, on = "id", nomatch = NULL], "id") - result$.__enclos_env__$private$.aggrs = setkeyv(private$.aggrs[get_id, on = "id", nomatch = NULL], "id") result$.__enclos_env__$private$.tags = setkeyv(private$.tags[get_id, on = "id", nomatch = NULL], "id") result$assert_values = FALSE result$values = values[match(get_id, names(values), nomatch = 0)] @@ -744,7 +738,6 @@ ParamSet = R6Class("ParamSet", result = copy(private$.params) result[, .tags := list(self$tags)] result[private$.trafos, .trafo := list(trafo), on = "id"] - result[private$.aggrs, .aggr := list(aggr), on = "id"] result[self$deps, .requirements := transpose_list(.(on, cond)), on = "id"] vals = self$values result[, `:=`( @@ -852,7 +845,7 @@ ParamSet = R6Class("ParamSet", #' Note that this only refers to the `logscale` flag set during construction, e.g. `p_dbl(logscale = TRUE)`. #' If the parameter was set to logscale manually, e.g. through `p_dbl(trafo = exp)`, #' this `is_logscale` will be `FALSE`. - is_logscale = function() with(private$.params, set_names(cls %in% c("ParamDbl", "ParamInt") & cargo == "logscale", id)), + is_logscale = function() with(private$.params, set_names(cls %in% c("ParamDbl", "ParamInt") & map_lgl(cargo, function(x) isTRUE(x$logscale)), id)), ############################ # Per-Parameter class properties (S3 method call) @@ -903,7 +896,6 @@ ParamSet = R6Class("ParamSet", .tags = data.table(id = character(0L), tag = character(0), key = "id"), .deps = data.table(id = character(0L), on = character(0L), cond = list()), .trafos = data.table(id = character(0L), trafo = list(), key = "id"), - .aggrs = data.table(id = character(0L), aggr = list(), key = "id"), get_tune_ps = function(values) { values = keep(values, inherits, "TuneToken") @@ -915,7 +907,7 @@ ParamSet = R6Class("ParamSet", names(params) = names(values) # package-internal S3 fails if we don't call the function indirectly here - partsets = pmap(list(values, params), function(...) tunetoken_to_ps(..., param_set = param_set)) + partsets = pmap(list(values, params), function(...) tunetoken_to_ps(..., param_set = self)) pars = ps_union(partsets) # partsets does not have names here, wihch is what we want. diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index fbbfe8b4..2d604245 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -66,7 +66,6 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, private$.tags = paramtbl[, .(tag = unique(unlist(.tags))), keyby = "id"] private$.trafos = setkeyv(paramtbl[!map_lgl(.trafo, is.null), .(id, trafo = .trafo)], "id") - private$.aggrs = setkeyv(paramtbl[!map_lgl(.aggr, is.null), .(id, aggr = .aggr)], "id") private$.translation = paramtbl[, c("id", "original_id", "owner_ps_index", "owner_name"), with = FALSE] setkeyv(private$.translation, "id") @@ -126,10 +125,6 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, if (nrow(newtrafos)) { private$.trafos = setkeyv(rbind(private$.trafos, newtrafos), "id") } - newaggrs = paramtbl[!map_lgl(.aggr, is.null), .(id, trafo = .aggr)] - if (nrow(newaggrs)) { - private$.aggrs = setkeyv(rbind(private$.aggrs, newaggrs), "id") - } private$.translation = rbind(private$.translation, paramtbl[, c("id", "original_id", "owner_ps_index", "owner_name"), with = FALSE]) setkeyv(private$.translation, "id") diff --git a/R/ParamUty.R b/R/ParamUty.R index b487d7c6..fbf48ed0 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -3,6 +3,7 @@ #' @export p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL) { assert_function(custom_check, null.ok = TRUE) + assert_function(aggr, null.ok = TRUE, nargs = 1L) if (!is.null(custom_check)) { custom_check_result = custom_check(1) assert(check_true(custom_check_result), check_string(custom_check_result), .var.name = "The result of 'custom_check()'") @@ -12,7 +13,10 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t } else { "NoDefault" } - Domain(cls = "ParamUty", grouping = "ParamUty", cargo = list(custom_check = custom_check, repr = repr), special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init, aggr = aggr) + cargo = list(custom_check = custom_check, repr = repr) + cargo$aggr = aggr + + Domain(cls = "ParamUty", grouping = "ParamUty", cargo = cargo, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init) } #' @export diff --git a/R/helper.R b/R/helper.R index 70268885..53ca3f44 100644 --- a/R/helper.R +++ b/R/helper.R @@ -53,11 +53,3 @@ col_to_nl = function(dt, col = 1, idcol = 2) { names(data) = dt[[idcol]] data } - -default_aggr = function(x) { - if (!test_numeric(x[[1]], len = 1L)) { - stopf("Provide a custom aggregator for non-numeric and non-scalar parameters.") - } - ceiling(mean(unlist(x))) -} - diff --git a/R/to_tune.R b/R/to_tune.R index 75ef24c9..46959604 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -41,6 +41,12 @@ #' The `TuneToken` object's internals are subject to change and should not be relied upon. `TuneToken` objects should #' only be constructed via `to_tune()`, and should only be used by giving them to `$values` of a [`ParamSet`]. #' @param ... if given, restricts the range to be tuning over, as described above. +#' @param aggr (`function`)\cr +#' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. +#' If `NULL`, the default aggregation function of the parameter will be used.\ +#' @param inner (`logical(1)`)\cr +#' Whether to create an inner tuning token, i.e. the value will be optimized using the `Learner`-internal tuning +#' mechanism, such as early stopping for XGBoost. #' @return A `TuneToken` object. #' @examples #' params = ps( @@ -54,7 +60,8 @@ #' uty2 = p_uty(), #' uty3 = p_uty(), #' uty4 = p_uty(), -#' uty5 = p_uty() +#' uty5 = p_uty(), +#' p_inner = p_int(tags = "inner_tuning", aggr = function(x) round(mean(unlist(x)))) #' ) #' #' params$values = list( @@ -101,7 +108,10 @@ #' )), #' #' # not all values need to be tuned! -#' uty5 = 100 +#' uty5 = 100, +#' +#' # Fix value to 100, but use learner-internal tuning +#' p_inner = to_tune(p_fct(100), inner = TRUE)) #' ) #' #' print(params$values) @@ -132,7 +142,12 @@ #' @family ParamSet construction helpers #' @aliases TuneToken #' @export -to_tune = function(...) { +to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { + test_function(aggr, nargs = 1L, null.ok = TRUE) + assert_flag(inner) + if (!is.null(aggr)) { + assert_true(inner) + } call = sys.call() if (...length() > 3) { stop("to_tune() must have zero arguments (tune entire parameter range), one argument (a Domain/Param, or a vector/list of values to tune over), or up to three arguments (any of `lower`, `upper`, `logscale`).") @@ -180,24 +195,12 @@ to_tune = function(...) { content = list(logscale = FALSE) } - set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken")) -} + if (inner) { + type = c("InnerTuneToken", type) + } + if (!is.null(aggr)) content$aggr = aggr -#' @title Create an Inner Tuning Token -#' @description -#' Works just like [`to_tune()`], but marks the parameter for inner tuning. -#' See [`mlr3::Learner`] for more information. -#' @inheritParams to_tune -#' @param aggr (`function`)\cr -#' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. -#' The default is to average the values and round them up. -#' @export -in_tune = function(..., aggr = NULL) { - test_function(aggr, nargs = 1L, null.ok = TRUE) - tt = to_tune(...) - if (!is.null(aggr)) tt$content$aggr = aggr - tt = set_class(tt, classes = c("InnerTuneToken", class(tt))) - return(tt) + set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken")) } #' @export @@ -236,7 +239,7 @@ tunetoken_to_ps = function(tt, param, param_set) { } tunetoken_to_ps.InnerTuneToken = function(tt, param, param_set) { - tt$content$aggr = tt$content$aggr %??% get_private(param_set)$.aggrs[list(param$id), "aggr", on = "id"][[1L]][[1L]] + tt$content$aggr = tt$content$aggr %??% param_set$params[list(param$id), "cargo", on = "id"][[1L]][[1L]]$aggr if ("inner_tuning" %nin% param_set$tags[[param$id]]) { stopf("%s (%s): Parameter not eligible for inner tuning", tt$call, param$id) } diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index 165905cd..de796aa2 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -34,7 +34,7 @@ Compact representation as datatable. Col types are:\cr \item special_vals: list col of list \item default: list col \item storage_type: character -\item tags: list col of character vectors +\item tags: list col of character vectorssearch } } } @@ -354,7 +354,7 @@ Aggregate parameter values according to the aggregation rules. \describe{ \item{\code{x}}{(named \code{list()} of \code{list()}s)\cr The value(s) to be aggregated. Names are parameter values. -The aggregation function is selected accordingly for each parameter.} +The aggregation function is selected based on the parameter.} } \if{html}{\out{}} } diff --git a/man/in_tune.Rd b/man/in_tune.Rd deleted file mode 100644 index b0f5be51..00000000 --- a/man/in_tune.Rd +++ /dev/null @@ -1,19 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/to_tune.R -\name{in_tune} -\alias{in_tune} -\title{Create an Inner Tuning Token} -\usage{ -in_tune(..., aggr = NULL) -} -\arguments{ -\item{...}{if given, restricts the range to be tuning over, as described above.} - -\item{aggr}{(\code{function})\cr -The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. -The default is to average the values and round them up.} -} -\description{ -Works just like \code{\link[=to_tune]{to_tune()}}, but marks the parameter for inner tuning. -See \code{\link[mlr3:Learner]{mlr3::Learner}} for more information. -} diff --git a/man/to_tune.Rd b/man/to_tune.Rd index 0b74e7d0..73de0d87 100644 --- a/man/to_tune.Rd +++ b/man/to_tune.Rd @@ -5,10 +5,14 @@ \alias{TuneToken} \title{Indicate that a Parameter Value should be Tuned} \usage{ -to_tune(...) +to_tune(..., inner = !is.null(aggr), aggr = NULL) } \arguments{ \item{...}{if given, restricts the range to be tuning over, as described above.} + +\item{inner}{(\code{logical(1)})\cr +Whether to create an inner tuning token, i.e. the value will be optimized using the \code{Learner}-internal tuning +mechanism, such as early stopping for XGBoost.} } \value{ A \code{TuneToken} object. @@ -65,7 +69,8 @@ params = ps( uty2 = p_uty(), uty3 = p_uty(), uty4 = p_uty(), - uty5 = p_uty() + uty5 = p_uty(), + p_inner = p_int(tags = "inner_tuning", aggr = function(x) round(mean(unlist(x)))) ) params$values = list( @@ -112,7 +117,10 @@ params$values = list( )), # not all values need to be tuned! - uty5 = 100 + uty5 = 100, + + # Fix value to 100, but use learner-internal tuning + p_inner = to_tune(p_fct(100), inner = TRUE)) ) print(params$values) diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index a20ddea0..cf299bbd 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -448,16 +448,18 @@ test_that("aggr", { expect_error(param_set$aggr(1), "list") expect_error(param_set$aggr(list(1)), "list") expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list())), "permutation") - expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "More than one") + expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "At least one") }) -test_that("in_tune", { - param_set = ps(a = p_dbl(lower = 1, upper = 2)) +test_that("inner", { + param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "inner_tuning")) param_set$set_values( - a = in_tune(lower = 1, upper = 2, aggr = function(x) 1.5) + a = to_tune(lower = 1, upper = 2, aggr = function(x) 1.5) ) - ss = param_set$search_space() - ss$aggr(list(a = list(1, 2))) + expect_equal(ss$aggr(list(a = list(1, 2))), list(a = 1.5)) + + param_set1 = ps(a = p_dbl(lower = 1, upper = 2)) + expect_error(param_set1$set_values(a = to_tune(inner = TRUE)), "not eligible") }) diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index 08c9f20e..48c1b5f0 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -348,18 +348,18 @@ test_that("$extra_trafo flag works", { expect_false(search_space$has_extra_trafo) }) -test_that("in_tune", { - it = in_tune(1) +test_that("inner", { + it = to_tune(1, inner = TRUE) expect_class(it, "InnerTuneToken") expect_null(it$aggr) - tt = in_tune(1) + tt = to_tune(1, inner = TRUE) expect_equal(it$content, tt$content) - it1 = in_tune(aggr = function(x) min(unlist(x))) + it1 = to_tune(aggr = function(x) min(unlist(x))) expect_equal(it1$content$aggr(list(1, 2)), 1) param_set = ps( - a = p_dbl(1, 10, aggr = default_aggr, tags = "inner_tuning") + a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "inner_tuning") ) - param_set$set_values(a = in_tune()) + param_set$set_values(a = to_tune(inner = TRUE)) expect_class(param_set$values$a, "InnerTuneToken") }) diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index 0e8ae92f..d69cbb74 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -396,10 +396,5 @@ test_that("logscale in tunetoken", { expect_output(print(to_tune(lower = 1, logscale = TRUE)), "range \\[1, \\.\\.\\.] \\(log scale\\)") expect_output(print(to_tune(upper = 1, logscale = TRUE)), "range \\[\\.\\.\\., 1] \\(log scale\\)") expect_output(print(to_tune(lower = 0, upper = 1, logscale = TRUE)), "range \\[0, 1] \\(log scale\\)") - expect_output(print(in_tune()), "Inner") -}) - -test_that("inner tune", { - param_set = ps(a = p_int(1, 10)) - expect_error(param_set$set_values(a = in_tune(aggr = default_aggr)), "inner tuning") + expect_output(print(to_tune(inner = TRUE)), "Inner") }) From 4fa12b1d140d0796e0fa86fa7685c9efb3757918 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 22 Apr 2024 15:52:07 +0200 Subject: [PATCH 08/34] cleanup --- R/Design.R | 2 +- R/Domain.R | 2 +- R/ParamSet.R | 2 +- R/to_tune.R | 4 ++-- man/Domain.Rd | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/R/Design.R b/R/Design.R index af0348bf..3ca3b465 100644 --- a/R/Design.R +++ b/R/Design.R @@ -35,7 +35,7 @@ Design = R6Class("Design", # FIXME: this might also be problematic for LHS # do we still create an LHS like this? - imap(param_set$values, function(v, n) {set(data, j = n, value = v)}) + imap(param_set$values, function(v, n) set(data, j = n, value = v)) self$data = data if (param_set$has_deps) { private$set_deps_to_na() diff --git a/R/Domain.R b/R/Domain.R index 580abc1e..d9c4a31a 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -66,7 +66,7 @@ #' @param aggr (`function`)\cr #' Function with one argument, which is a list of parameter values. #' The function specifies how a list of parameter values is aggregated to form one parameter value. -#' This is used in the context of inner tuning. The default is to aggregate the values and round up. +#' This is used in the context of inner tuning, where the inner tuned values on the different resampling iterations might differ. #' #' @return A `Domain` object. #' diff --git a/R/ParamSet.R b/R/ParamSet.R index d3204894..beb8fc38 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -28,7 +28,7 @@ #' - special_vals: list col of list #' - default: list col #' - storage_type: character -#' - tags: list col of character vectorssearch +#' - tags: list col of character vectors #' @examples #' pset = ParamSet$new( #' params = list( diff --git a/R/to_tune.R b/R/to_tune.R index 46959604..c76ab33e 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -43,7 +43,7 @@ #' @param ... if given, restricts the range to be tuning over, as described above. #' @param aggr (`function`)\cr #' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. -#' If `NULL`, the default aggregation function of the parameter will be used.\ +#' If `NULL`, the default aggregation function of the parameter (if available) will be used. #' @param inner (`logical(1)`)\cr #' Whether to create an inner tuning token, i.e. the value will be optimized using the `Learner`-internal tuning #' mechanism, such as early stopping for XGBoost. @@ -110,7 +110,7 @@ #' # not all values need to be tuned! #' uty5 = 100, #' -#' # Fix value to 100, but use learner-internal tuning +#' # Fix value to 100, but use learner-internal tuning and default aggregation rule #' p_inner = to_tune(p_fct(100), inner = TRUE)) #' ) #' diff --git a/man/Domain.Rd b/man/Domain.Rd index d34e3f5a..c5ba2e17 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -147,7 +147,7 @@ value upon construction.} \item{aggr}{(\code{function})\cr Function with one argument, which is a list of parameter values. The function specifies how a list of parameter values is aggregated to form one parameter value. -This is used in the context of inner tuning. The default is to aggregate the values and round up.} +This is used in the context of inner tuning, where the inner tuned values on the different resampling iterations might differ.} \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr Allowed categorical values of the parameter. If this is not a \code{character}, then a \code{trafo} is generated that From 8f7cd3682d03ae0e697414a514baef9840a4f1e9 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 22 Apr 2024 15:57:43 +0200 Subject: [PATCH 09/34] fix example --- R/to_tune.R | 2 +- man/ParamSet.Rd | 2 +- man/to_tune.Rd | 8 ++++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/R/to_tune.R b/R/to_tune.R index c76ab33e..465fa32b 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -111,7 +111,7 @@ #' uty5 = 100, #' #' # Fix value to 100, but use learner-internal tuning and default aggregation rule -#' p_inner = to_tune(p_fct(100), inner = TRUE)) +#' p_inner = to_tune(p_fct(100), inner = TRUE) #' ) #' #' print(params$values) diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index de796aa2..cfc3a121 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -34,7 +34,7 @@ Compact representation as datatable. Col types are:\cr \item special_vals: list col of list \item default: list col \item storage_type: character -\item tags: list col of character vectorssearch +\item tags: list col of character vectors } } } diff --git a/man/to_tune.Rd b/man/to_tune.Rd index 73de0d87..e447c05a 100644 --- a/man/to_tune.Rd +++ b/man/to_tune.Rd @@ -13,6 +13,10 @@ to_tune(..., inner = !is.null(aggr), aggr = NULL) \item{inner}{(\code{logical(1)})\cr Whether to create an inner tuning token, i.e. the value will be optimized using the \code{Learner}-internal tuning mechanism, such as early stopping for XGBoost.} + +\item{aggr}{(\code{function})\cr +The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. +If \code{NULL}, the default aggregation function of the parameter (if available) will be used.} } \value{ A \code{TuneToken} object. @@ -119,8 +123,8 @@ params$values = list( # not all values need to be tuned! uty5 = 100, - # Fix value to 100, but use learner-internal tuning - p_inner = to_tune(p_fct(100), inner = TRUE)) + # Fix value to 100, but use learner-internal tuning and default aggregation rule + p_inner = to_tune(p_fct(100), inner = TRUE) ) print(params$values) From f9c9ff1d6a664d23218ed98085683f0a1773bd19 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 22 Apr 2024 17:51:18 +0200 Subject: [PATCH 10/34] fix bug in objecttunetoken --- R/to_tune.R | 33 +++++++++++++++++++++++---------- tests/testthat/test_domain.R | 1 + 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/R/to_tune.R b/R/to_tune.R index 465fa32b..27ae7aaa 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -150,7 +150,7 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { } call = sys.call() if (...length() > 3) { - stop("to_tune() must have zero arguments (tune entire parameter range), one argument (a Domain/Param, or a vector/list of values to tune over), or up to three arguments (any of `lower`, `upper`, `logscale`).") + stop("to_tune() must have zero ... arguments (tune entire parameter range), one argument (a Domain/Param, or a vector/list of values to tune over), or up to three arguments (any of `lower`, `upper`, `logscale`) in addition to the inner and aggr arguments.") } args = list(...) if (...length() > 1 || any(names(args) %in% c("lower", "upper"))) { @@ -176,7 +176,12 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { check_list(content, names = "unique"), check_list(content, names = "unnamed") ) - content = p_fct(levels = content) + # for the printer + content = if (!is.null(aggr)) { + p_fct(levels = content, aggr = aggr) + } else { + p_fct(levels = content) + } } else { if (inherits(content, "Domain")) { bounded = domain_is_bounded(content) @@ -195,10 +200,10 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { content = list(logscale = FALSE) } + if (!is.null(aggr) && type != "ObjectTuneToken") content$aggr = aggr if (inner) { type = c("InnerTuneToken", type) } - if (!is.null(aggr)) content$aggr = aggr set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken")) } @@ -239,12 +244,19 @@ tunetoken_to_ps = function(tt, param, param_set) { } tunetoken_to_ps.InnerTuneToken = function(tt, param, param_set) { - tt$content$aggr = tt$content$aggr %??% param_set$params[list(param$id), "cargo", on = "id"][[1L]][[1L]]$aggr - if ("inner_tuning" %nin% param_set$tags[[param$id]]) { - stopf("%s (%s): Parameter not eligible for inner tuning", tt$call, param$id) - } - if (is.null(tt$content$aggr)) { - stopf("%s (%s): Provide an aggregation function for inner tuning.", tt$call, param$id) + if (!test_class(tt, "ObjectTuneToken")) { + tt$content$aggr = tt$content$aggr %??% param_set$params[list(param$id), "cargo", on = "id"][[1L]][[1L]]$aggr + if ("inner_tuning" %nin% param_set$tags[[param$id]]) { + stopf("%s (%s): Parameter not eligible for inner tuning", tt$call, param$id) + } + if (is.null(tt$content$aggr)) { + stopf("%s (%s): Provide an aggregation function for inner tuning.", tt$call, param$id) + } + } else { + if ("inner_tuning" %in% tt$content$.tags && "inner_tuning" %nin% param_set$tags[[param$id]]) { + stopf("%s (%s): Parameter not eligible for inner tuning", tt$call, param$id) + + } } ps = NextMethod() ps$tags = map(ps$tags, function(tags) union(tags, "inner_tuning")) @@ -293,7 +305,8 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param, param_set) { } tunetoken_to_ps.ObjectTuneToken = function(tt, param, param_set) { - pslike_to_ps(tt$content, tt$call, param) + x = pslike_to_ps(tt$content, tt$call, param) + return(x) } # Convert something that is `ParamSet`-like (ParamSet or Domain) to a `ParamSet`. diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index 48c1b5f0..11ec8774 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -361,5 +361,6 @@ test_that("inner", { a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "inner_tuning") ) param_set$set_values(a = to_tune(inner = TRUE)) + param_set$set_values(a = to_tune(p_fct(1.2), inner = TRUE)) expect_class(param_set$values$a, "InnerTuneToken") }) From 7c8684a1729f273178b19f780ac37b9c23caba6d Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 22 Apr 2024 18:13:33 +0200 Subject: [PATCH 11/34] cleanup --- R/to_tune.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/to_tune.R b/R/to_tune.R index 27ae7aaa..b4be91c9 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -305,8 +305,7 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param, param_set) { } tunetoken_to_ps.ObjectTuneToken = function(tt, param, param_set) { - x = pslike_to_ps(tt$content, tt$call, param) - return(x) + pslike_to_ps(tt$content, tt$call, param) } # Convert something that is `ParamSet`-like (ParamSet or Domain) to a `ParamSet`. From 431ba36d5346aa8c9357e67033fae6bea1bc23c2 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 23 Apr 2024 09:50:47 +0200 Subject: [PATCH 12/34] more tests --- tests/testthat/test_domain.R | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index 11ec8774..1214c6b4 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -361,6 +361,11 @@ test_that("inner", { a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "inner_tuning") ) param_set$set_values(a = to_tune(inner = TRUE)) + expect_class(param_set$values$a, "InnerTuneToken") param_set$set_values(a = to_tune(p_fct(1.2), inner = TRUE)) expect_class(param_set$values$a, "InnerTuneToken") + param_set$set_values(a = to_tune(1.2, 2.3, inner = TRUE)) + expect_class(param_set$values$a, "InnerTuneToken") + param_set$set_values(a = to_tune(1.2, 2.3, logscale = TRUE, inner = TRUE)) + expect_class(param_set$values$a, "InnerTuneToken") }) From beaf845edcb71fa809dcc67a8eb4d61a54d3b19a Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 26 Apr 2024 15:43:25 +0200 Subject: [PATCH 13/34] ... --- R/Domain.R | 3 +-- R/ParamInt.R | 4 ++- R/ParamSet.R | 7 +++-- R/to_tune.R | 40 ++++++++++++++++++++++------ man/Domain.Rd | 3 +-- man/to_tune.Rd | 11 +++++--- tests/testthat/test_ParamSet.R | 13 --------- tests/testthat/test_to_tune.R | 48 ++++++++++++++++++++++++++++++++++ 8 files changed, 98 insertions(+), 31 deletions(-) diff --git a/R/Domain.R b/R/Domain.R index d9c4a31a..e0bec8ee 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -65,8 +65,7 @@ #' value upon construction. #' @param aggr (`function`)\cr #' Function with one argument, which is a list of parameter values. -#' The function specifies how a list of parameter values is aggregated to form one parameter value. -#' This is used in the context of inner tuning, where the inner tuned values on the different resampling iterations might differ. +#' This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning. #' #' @return A `Domain` object. #' diff --git a/R/ParamInt.R b/R/ParamInt.R index 41d06d8c..07a4f9be 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -1,9 +1,10 @@ #' @rdname Domain #' @export -p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) { +p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, translator = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) assert_number(tolerance, lower = 0, upper = 0.5) + assert_function(translator, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) # assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower) if (!isTRUE(is.infinite(upper))) assert_int(upper, tol = 1e-300) else assert_number(upper) @@ -27,6 +28,7 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ cargo = list() if (logscale) cargo$logscale = TRUE cargo$aggr = aggr + if (!is.null(translator)) cargo$translator = translator Domain(cls = cls, grouping = cls, lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = storage_type, diff --git a/R/ParamSet.R b/R/ParamSet.R index beb8fc38..35c0fc0e 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -154,7 +154,8 @@ ParamSet = R6Class("ParamSet", #' @param any_tags (`character()`). See `$ids()`. #' @param type (`character(1)`)\cr #' Return values `"with_token"` (i.e. all values), - # `"without_token"` (all values that are not [`TuneToken`] objects) or `"only_token"` (only [`TuneToken`] objects)? + # `"without_token"` (all values that are not [`TuneToken`] objects), `"only_token"` (only [`TuneToken`] objects) + # or `"with_inner"` (all values that are no not `InnerTuneToken`)? #' @param check_required (`logical(1)`)\cr #' Check if all required parameters are set? #' @param remove_dependencies (`logical(1)`)\cr @@ -162,7 +163,7 @@ ParamSet = R6Class("ParamSet", #' @return Named `list()`. get_values = function(class = NULL, tags = NULL, any_tags = NULL, type = "with_token", check_required = TRUE, remove_dependencies = TRUE) { - assert_choice(type, c("with_token", "without_token", "only_token")) + assert_choice(type, c("with_token", "without_token", "only_token", "inner_or_without_token")) assert_flag(check_required) @@ -173,6 +174,8 @@ ParamSet = R6Class("ParamSet", values = discard(values, is, "TuneToken") } else if (type == "only_token") { values = keep(values, is, "TuneToken") + } else if (type == "with_inner") { + values = keep(values, is, "InnerTuneToken") } if (check_required) { diff --git a/R/to_tune.R b/R/to_tune.R index b4be91c9..a2a9fa9f 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -37,16 +37,21 @@ #' where a single evaluation-time parameter value (e.g. [`p_uty()`]) is constructed from multiple tuner-visible #' parameters (which may not be [`p_uty()`]). If not one-dimensional, the supplied [`ParamSet`] should always contain a `$extra_trafo` function, #' which must then always return a `list` with a single entry. +#' * **`to_tune(..., aggr = , inner = )`: Works like any of the above, but marks the parameter for +#' learner-internal (inner) tuning when `inner = TRUE`. This does not change the behavior of the learner, but does +#' affect the behavior of the tuner, which will then insert the aggregated internally tuned values into the tuning archive. +#' For the [`AutoTuner`][mlr3tuning::AutoTuner], this implies that the final model fit will use the internally +#' optimized values. This e.g. allows to combine XGBoost's earlys stopping with `mlr3`'s tuning packages. #' #' The `TuneToken` object's internals are subject to change and should not be relied upon. `TuneToken` objects should #' only be constructed via `to_tune()`, and should only be used by giving them to `$values` of a [`ParamSet`]. #' @param ... if given, restricts the range to be tuning over, as described above. #' @param aggr (`function`)\cr #' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. -#' If `NULL`, the default aggregation function of the parameter (if available) will be used. +#' If `NULL`, the default aggregation function of the parameter (if available) is used. #' @param inner (`logical(1)`)\cr -#' Whether to create an inner tuning token, i.e. the value will be optimized using the `Learner`-internal tuning -#' mechanism, such as early stopping for XGBoost. +#' Whether to create an inner tuning token. +#' Is set to `TRUE` by default if `aggr` is provided. #' @return A `TuneToken` object. #' @examples #' params = ps( @@ -150,9 +155,15 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { } call = sys.call() if (...length() > 3) { - stop("to_tune() must have zero ... arguments (tune entire parameter range), one argument (a Domain/Param, or a vector/list of values to tune over), or up to three arguments (any of `lower`, `upper`, `logscale`) in addition to the inner and aggr arguments.") + stop("to_tune() must have zero ... arguments (tune entire parameter range), one argument (a Domain/Param, or a vector/list of values to tune over), up to three arguments (any of `lower`, `upper`, `logscale`).") } args = list(...) + + if (isTRUE(args$logscale) && inner) { + # we could allow users to give an inverse transformation + stopf("Parameter transformations and inner tuning are currently not supported") + } + if (...length() > 1 || any(names(args) %in% c("lower", "upper"))) { # Two arguments: tune over a range type = "RangeTuneToken" @@ -169,7 +180,7 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { # one argument: tune over an object. that object can be something # that can be converted to a ParamSet (ParamSet itself, Param, or Domain), # otherwise it must be something that can be converted to a ParamFct Domain. - if (!test_multi_class(content, c("ParamSet", "Param", "Domain"))) { + if (!test_multi_class(content, c("ParamSet", "Domain"))) { assert( check_atomic_vector(content, names = "unnamed"), check_atomic_vector(content, names = "unique"), @@ -185,12 +196,17 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { } else { if (inherits(content, "Domain")) { bounded = domain_is_bounded(content) + has_trafo = !is.null(content$.trafo) } else { bounded = content$all_bounded + has_trafo = content$has_trafo } if (!bounded) { stop("tuning range must be bounded.") } + if (has_trafo && inner) { + stop("Parameter transformations and inner tuning are currently not supported") + } } type = "ObjectTuneToken" } @@ -200,10 +216,9 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { content = list(logscale = FALSE) } + # for object tune token, the aggr was already consumed in the p_fct() call above if (!is.null(aggr) && type != "ObjectTuneToken") content$aggr = aggr - if (inner) { - type = c("InnerTuneToken", type) - } + if (inner) type = c("InnerTuneToken", type) set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken")) } @@ -269,8 +284,17 @@ tunetoken_to_ps.FullTuneToken = function(tt, param, param_set) { } if (isTRUE(tt$content$logscale)) { if (!domain_is_number(param)) stop("%s (%s): logscale only valid for numeric / integer parameters.", tt$call, param$id) + tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale, aggr = tt$content$aggr), tt$call), param, param_set) } else { + if (!is.null(tt$content$aggr)) { + # https://github.com/Rdatatable/data.table/issues/6104 + param$cargo[[1L]][[1L]] = if (is.null(param$cargo[[1L]])) { + list(aggr = tt$content$aggr) + } else { + insert_named(param$cargo[[1L]], list(aggr = tt$content$aggr)) + } + } pslike_to_ps(param, tt$call, param, param_set) } } diff --git a/man/Domain.Rd b/man/Domain.Rd index c5ba2e17..79b62042 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -146,8 +146,7 @@ value upon construction.} \item{aggr}{(\code{function})\cr Function with one argument, which is a list of parameter values. -The function specifies how a list of parameter values is aggregated to form one parameter value. -This is used in the context of inner tuning, where the inner tuned values on the different resampling iterations might differ.} +This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning.} \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr Allowed categorical values of the parameter. If this is not a \code{character}, then a \code{trafo} is generated that diff --git a/man/to_tune.Rd b/man/to_tune.Rd index e447c05a..b49f2239 100644 --- a/man/to_tune.Rd +++ b/man/to_tune.Rd @@ -11,12 +11,12 @@ to_tune(..., inner = !is.null(aggr), aggr = NULL) \item{...}{if given, restricts the range to be tuning over, as described above.} \item{inner}{(\code{logical(1)})\cr -Whether to create an inner tuning token, i.e. the value will be optimized using the \code{Learner}-internal tuning -mechanism, such as early stopping for XGBoost.} +Whether to create an inner tuning token. +Is set to \code{TRUE} by default if \code{aggr} is provided.} \item{aggr}{(\code{function})\cr The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. -If \code{NULL}, the default aggregation function of the parameter (if available) will be used.} +If \code{NULL}, the default aggregation function of the parameter (if available) is used.} } \value{ A \code{TuneToken} object. @@ -56,6 +56,11 @@ the range which should be tuned over. The supplied \code{trafo} function is used where a single evaluation-time parameter value (e.g. \code{\link[=p_uty]{p_uty()}}) is constructed from multiple tuner-visible parameters (which may not be \code{\link[=p_uty]{p_uty()}}). If not one-dimensional, the supplied \code{\link{ParamSet}} should always contain a \verb{$extra_trafo} function, which must then always return a \code{list} with a single entry. +\item **\verb{to_tune(..., aggr = , inner = )}: Works like any of the above, but marks the parameter for +learner-internal (inner) tuning when \code{inner = TRUE}. This does not change the behavior of the learner, but does +affect the behavior of the tuner, which will then insert the aggregated internally tuned values into the tuning archive. +For the \code{\link[mlr3tuning:AutoTuner]{AutoTuner}}, this implies that the final model fit will use the internally +optimized values. This e.g. allows to combine XGBoost's earlys stopping with \code{mlr3}'s tuning packages. } The \code{TuneToken} object's internals are subject to change and should not be relied upon. \code{TuneToken} objects should diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index cf299bbd..5fd4d794 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -450,16 +450,3 @@ test_that("aggr", { expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list())), "permutation") expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "At least one") }) - -test_that("inner", { - param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "inner_tuning")) - param_set$set_values( - a = to_tune(lower = 1, upper = 2, aggr = function(x) 1.5) - ) - ss = param_set$search_space() - - expect_equal(ss$aggr(list(a = list(1, 2))), list(a = 1.5)) - - param_set1 = ps(a = p_dbl(lower = 1, upper = 2)) - expect_error(param_set1$set_values(a = to_tune(inner = TRUE)), "not eligible") -}) diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index d69cbb74..8d6403b8 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -398,3 +398,51 @@ test_that("logscale in tunetoken", { expect_output(print(to_tune(lower = 0, upper = 1, logscale = TRUE)), "range \\[0, 1] \\(log scale\\)") expect_output(print(to_tune(inner = TRUE)), "Inner") }) + + +test_that("inner and aggr", { + # no default aggregation function + param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "inner_tuning")) + + # correct errors + expect_error(param_set$set_values(a = to_tune(inner = TRUE)), "Provide an aggregation") + expect_error(param_set$set_values(a = to_tune(inner = FALSE, aggr = function(x) 1))) + + + # full tune token + inner + expect_equal( + param_set$set_values(a = to_tune(aggr = function(x) -99))$search_space()$aggr(list(a = list(1, 2, 3))), + list(a = -99) + ) + + # logscale + inner: now allowed + expect_error( + param_set$set_values(a = to_tune(logscale = TRUE, aggr = function(x) -99)), + "inner tuning" + ) + + # other trafos + inner: not allowed + expect_error( + param_set$set_values(a = to_tune(ps(a = p_dbl(0, 1), .extra_trafo = function(x) 1L), aggr = function(x) -99)), + "inner tuning" + ) + + # param set + inner + + # range + inner + param_set$set_values(a = to_tune(lower = 1, upper = 2, aggr = function(x) 1.5)) + expect_equal(s$search_space()$aggr(list(a = list(1, 2))), list(a = 1.5)) + + # full + inner + + # domain + inner + + + ## with default aggregation function + + # default aggregation function is used when not overwritten + + + # can overwrite existing aggregation function + # check all cases +}) From f5dbe0d7a484d82954d88aa324b9e7dfcf2e09c9 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 4 May 2024 08:30:38 +0200 Subject: [PATCH 14/34] add in_tune_fn --- R/Domain.R | 5 -- R/ParamDbl.R | 7 +- R/ParamFct.R | 10 ++- R/ParamInt.R | 9 ++- R/ParamLgl.R | 11 ++- R/ParamSet.R | 27 +++++++- R/ParamUty.R | 7 +- R/to_tune.R | 118 +++++++++++---------------------- man/Domain.Rd | 22 +++--- man/ParamSet.Rd | 12 ++++ man/ParamSetCollection.Rd | 1 + man/to_tune.Rd | 22 ++---- tests/testthat/test_ParamSet.R | 25 +++++++ tests/testthat/test_domain.R | 21 +++--- tests/testthat/test_to_tune.R | 10 +-- 15 files changed, 168 insertions(+), 139 deletions(-) diff --git a/R/Domain.R b/R/Domain.R index e0bec8ee..81af4c5a 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -63,11 +63,6 @@ #' @param init (`any`)\cr #' Initial value. When this is given, then the corresponding entry in `ParamSet$values` is initialized with this #' value upon construction. -#' @param aggr (`function`)\cr -#' Function with one argument, which is a list of parameter values. -#' This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning. -#' -#' @return A `Domain` object. #' #' @details #' Although the `levels` values of a constructed `p_fct()` will always be `character`-valued, the `p_fct` function admits diff --git a/R/ParamDbl.R b/R/ParamDbl.R index 495da692..c757e358 100644 --- a/R/ParamDbl.R +++ b/R/ParamDbl.R @@ -1,7 +1,11 @@ #' @rdname Domain #' @export -p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL) { +p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) + assert_function(in_tune_fn, null.ok = "inner_tuning" %nin% tags, args = c("domain", "param_set"), nargs = 2L) + if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { + stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + } assert_number(tolerance, lower = 0) assert_number(lower) assert_number(upper) @@ -21,6 +25,7 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ cargo = list() if (logscale) cargo$logscale = TRUE cargo$aggr = aggr + if (!is.null(in_tune_fn)) cargo$in_tune_fn = in_tune_fn Domain(cls = "ParamDbl", grouping = "ParamDbl", lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = "numeric", depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) diff --git a/R/ParamFct.R b/R/ParamFct.R index 13e986c4..d6504fe9 100644 --- a/R/ParamFct.R +++ b/R/ParamFct.R @@ -1,9 +1,13 @@ #' @rdname Domain #' @export -p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) { +p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) constargs = as.list(match.call()[-1]) levels = eval.parent(constargs$levels) + assert_function(in_tune_fn, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) + if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { + stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + } if (!is.character(levels)) { # if the "levels" argument is not a character vector, then # we add a trafo. @@ -22,8 +26,10 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact } # group p_fct by levels, so the group can be checked in a vectorized fashion. # We escape '"' and '\' to '\"' and '\\', respectively. + cargo = c(aggr = aggr, in_tune_fn = in_tune_fn) + cargo = if (length(cargo)) cargo grouping = str_collapse(gsub("([\\\\\"])", "\\\\\\1", sort(real_levels)), quote = '"', sep = ",") - Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, cargo = if (!is.null(aggr)) list(aggr = aggr)) + Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, cargo = cargo) } #' @export diff --git a/R/ParamInt.R b/R/ParamInt.R index 07a4f9be..3099b39a 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -1,10 +1,13 @@ #' @rdname Domain #' @export -p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, translator = NULL) { +p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) assert_number(tolerance, lower = 0, upper = 0.5) - assert_function(translator, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) + assert_function(in_tune_fn, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) + if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { + stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + } # assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower) if (!isTRUE(is.infinite(upper))) assert_int(upper, tol = 1e-300) else assert_number(upper) @@ -28,7 +31,7 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ cargo = list() if (logscale) cargo$logscale = TRUE cargo$aggr = aggr - if (!is.null(translator)) cargo$translator = translator + if (!is.null(in_tune_fn)) cargo$in_tune_fn = in_tune_fn Domain(cls = cls, grouping = cls, lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = storage_type, diff --git a/R/ParamLgl.R b/R/ParamLgl.R index b46d7cb5..7ef9c9d8 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -1,9 +1,16 @@ #' @rdname Domain #' @export -p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL) { +p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) + assert_function(in_tune_fn, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) + if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { + stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + } + + cargo = c(aggr = aggr, in_tune_fn = in_tune_fn) + cargo = if (length(cargo)) cargo Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default, - tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = if (!is.null(aggr)) list(aggr = aggr)) + tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = cargo) } #' @export diff --git a/R/ParamSet.R b/R/ParamSet.R index 35c0fc0e..bb9cb1d4 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -163,7 +163,7 @@ ParamSet = R6Class("ParamSet", #' @return Named `list()`. get_values = function(class = NULL, tags = NULL, any_tags = NULL, type = "with_token", check_required = TRUE, remove_dependencies = TRUE) { - assert_choice(type, c("with_token", "without_token", "only_token", "inner_or_without_token")) + assert_choice(type, c("with_token", "without_token", "only_token", "with_inner")) assert_flag(check_required) @@ -282,6 +282,18 @@ ParamSet = R6Class("ParamSet", }) }, + #' @description + #' Convert all `InnerTuneToken`s to specific parameter values. + #' These transformations are defined by the `in_tune_fn` arguments of the [`Domain`] objects. + convert_inner_tune_tokens = function() { + inner_tune_tokens = self$get_values(type = "with_inner") + inner_tune_ps = private$get_tune_ps(inner_tune_tokens) + + imap(inner_tune_ps$domains, function(token, .id) { + converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn(token) + }) + }, + #' @description #' \pkg{checkmate}-like test-function. Takes a named list. #' Return `FALSE` if the given `$constraint` is not satisfied, `TRUE` otherwise. @@ -350,6 +362,17 @@ ParamSet = R6Class("ParamSet", if (!isTRUE(tunecheck)) return(tunecheck) } + xs_innertune = keep(xs, is, "InnerTuneToken") + walk(names(xs_innertune), function(pid) { + if ("inner_tuning" %nin% self$tags[[pid]]) { + stopf("Trying to assign InnerTuneToken to parameter '%s' which is not tagged with 'inner_tuning'.", pid) + } + if (is.null(xs[[pid]]$content$aggr) && is.null(private$.params[pid, "cargo", on = "id"][[1L]][[1L]]$aggr)) { + stopf("Trying to set parameter '%s' to InnerTuneToken, but no aggregation function is available.", pid) + } + }) + + # check each parameter group's feasibility xs_nontune = discard(xs, inherits, "TuneToken") @@ -910,7 +933,7 @@ ParamSet = R6Class("ParamSet", names(params) = names(values) # package-internal S3 fails if we don't call the function indirectly here - partsets = pmap(list(values, params), function(...) tunetoken_to_ps(..., param_set = self)) + partsets = pmap(list(values, params), function(...) tunetoken_to_ps(...)) pars = ps_union(partsets) # partsets does not have names here, wihch is what we want. diff --git a/R/ParamUty.R b/R/ParamUty.R index fbf48ed0..48b18149 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -1,8 +1,12 @@ #' @rdname Domain #' @export -p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL) { +p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL, in_tune_fn = NULL) { assert_function(custom_check, null.ok = TRUE) + assert_function(in_tune_fn, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) + if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { + stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + } assert_function(aggr, null.ok = TRUE, nargs = 1L) if (!is.null(custom_check)) { custom_check_result = custom_check(1) @@ -15,6 +19,7 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t } cargo = list(custom_check = custom_check, repr = repr) cargo$aggr = aggr + cargo$inner_tune_fn = in_tune_fn Domain(cls = "ParamUty", grouping = "ParamUty", cargo = cargo, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init) } diff --git a/R/to_tune.R b/R/to_tune.R index a2a9fa9f..6e84afbe 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -37,21 +37,17 @@ #' where a single evaluation-time parameter value (e.g. [`p_uty()`]) is constructed from multiple tuner-visible #' parameters (which may not be [`p_uty()`]). If not one-dimensional, the supplied [`ParamSet`] should always contain a `$extra_trafo` function, #' which must then always return a `list` with a single entry. -#' * **`to_tune(..., aggr = , inner = )`: Works like any of the above, but marks the parameter for -#' learner-internal (inner) tuning when `inner = TRUE`. This does not change the behavior of the learner, but does -#' affect the behavior of the tuner, which will then insert the aggregated internally tuned values into the tuning archive. -#' For the [`AutoTuner`][mlr3tuning::AutoTuner], this implies that the final model fit will use the internally -#' optimized values. This e.g. allows to combine XGBoost's earlys stopping with `mlr3`'s tuning packages. #' #' The `TuneToken` object's internals are subject to change and should not be relied upon. `TuneToken` objects should #' only be constructed via `to_tune()`, and should only be used by giving them to `$values` of a [`ParamSet`]. #' @param ... if given, restricts the range to be tuning over, as described above. -#' @param aggr (`function`)\cr -#' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. -#' If `NULL`, the default aggregation function of the parameter (if available) is used. #' @param inner (`logical(1)`)\cr -#' Whether to create an inner tuning token. -#' Is set to `TRUE` by default if `aggr` is provided. +#' Whether to create an `InnerTuneToken`. +#' This is only available for parameters tagged with `"inner_tuning"`. +#' @param aggr (`function`)\cr +#' Function with one argument, which is a list of parameter values. +#' This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning. +#' If none specified, the default aggregation function of the parameter will be used. #' @return A `TuneToken` object. #' @examples #' params = ps( @@ -65,8 +61,7 @@ #' uty2 = p_uty(), #' uty3 = p_uty(), #' uty4 = p_uty(), -#' uty5 = p_uty(), -#' p_inner = p_int(tags = "inner_tuning", aggr = function(x) round(mean(unlist(x)))) +#' uty5 = p_uty() #' ) #' #' params$values = list( @@ -113,10 +108,7 @@ #' )), #' #' # not all values need to be tuned! -#' uty5 = 100, -#' -#' # Fix value to 100, but use learner-internal tuning and default aggregation rule -#' p_inner = to_tune(p_fct(100), inner = TRUE) +#' uty5 = 100 #' ) #' #' print(params$values) @@ -148,22 +140,16 @@ #' @aliases TuneToken #' @export to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { - test_function(aggr, nargs = 1L, null.ok = TRUE) assert_flag(inner) if (!is.null(aggr)) { assert_true(inner) } + assert_function(aggr, nargs = 1L, null.ok = TRUE) call = sys.call() if (...length() > 3) { - stop("to_tune() must have zero ... arguments (tune entire parameter range), one argument (a Domain/Param, or a vector/list of values to tune over), up to three arguments (any of `lower`, `upper`, `logscale`).") + stop("to_tune() must have zero arguments (tune entire parameter range), one argument (a Domain/Param, or a vector/list of values to tune over), or up to three arguments (any of `lower`, `upper`, `logscale`).") } args = list(...) - - if (isTRUE(args$logscale) && inner) { - # we could allow users to give an inverse transformation - stopf("Parameter transformations and inner tuning are currently not supported") - } - if (...length() > 1 || any(names(args) %in% c("lower", "upper"))) { # Two arguments: tune over a range type = "RangeTuneToken" @@ -180,33 +166,23 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { # one argument: tune over an object. that object can be something # that can be converted to a ParamSet (ParamSet itself, Param, or Domain), # otherwise it must be something that can be converted to a ParamFct Domain. - if (!test_multi_class(content, c("ParamSet", "Domain"))) { + if (!test_multi_class(content, c("ParamSet", "Param", "Domain"))) { assert( check_atomic_vector(content, names = "unnamed"), check_atomic_vector(content, names = "unique"), check_list(content, names = "unique"), check_list(content, names = "unnamed") ) - # for the printer - content = if (!is.null(aggr)) { - p_fct(levels = content, aggr = aggr) - } else { - p_fct(levels = content) - } + content = p_fct(levels = content) } else { if (inherits(content, "Domain")) { bounded = domain_is_bounded(content) - has_trafo = !is.null(content$.trafo) } else { bounded = content$all_bounded - has_trafo = content$has_trafo } if (!bounded) { stop("tuning range must be bounded.") } - if (has_trafo && inner) { - stop("Parameter transformations and inner tuning are currently not supported") - } } type = "ObjectTuneToken" } @@ -216,9 +192,16 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { content = list(logscale = FALSE) } - # for object tune token, the aggr was already consumed in the p_fct() call above - if (!is.null(aggr) && type != "ObjectTuneToken") content$aggr = aggr - if (inner) type = c("InnerTuneToken", type) + if (inner) { + if (type == "ObjectTuneToken") { + stop("Inner tuning can currently not be combined with ParamSet or Domain object.") + } + if (isTRUE(content$logscale)) { + stop("Cannot combine logscale transformation with inner tuning.") + } + type = c("InnerTuneToken", type) + content$aggr = aggr + } set_class(list(content = content, call = deparse1(call)), c(type, "TuneToken")) } @@ -229,6 +212,13 @@ print.FullTuneToken = function(x, ...) { if (isTRUE(x$content$logscale)) " (log scale)" else "") } +#' @export +print.InnerTuneToken = function(x, ...) { + cat("Inner ") + NextMethod() +} + + #' @export print.RangeTuneToken = function(x, ...) { catf("Tuning over:\nrange [%s, %s]%s\n", x$content$lower %??% "...", x$content$upper %??% "...", @@ -241,12 +231,6 @@ print.ObjectTuneToken = function(x, ...) { print(x$content) } -#' @export -print.InnerTuneToken = function(x, ...) { - cat("Inner ") - NextMethod() -} - # tunetoken_to_ps: Convert a `TuneToken` to a `ParamSet` that tunes over this. # Needs the corresponding `Domain` to which the `TuneToken` refers, both to # get the range (e.g. if `to_tune()` was used) and to verify that the `TuneToken` @@ -254,53 +238,27 @@ print.InnerTuneToken = function(x, ...) { # # Makes liberal use to `pslike_to_ps` (converting Param, ParamSet, Domain to ParamSet) # param is a data.table that is potentially modified by reference using data.table set() methods. -tunetoken_to_ps = function(tt, param, param_set) { +tunetoken_to_ps = function(tt, param) { UseMethod("tunetoken_to_ps") } -tunetoken_to_ps.InnerTuneToken = function(tt, param, param_set) { - if (!test_class(tt, "ObjectTuneToken")) { - tt$content$aggr = tt$content$aggr %??% param_set$params[list(param$id), "cargo", on = "id"][[1L]][[1L]]$aggr - if ("inner_tuning" %nin% param_set$tags[[param$id]]) { - stopf("%s (%s): Parameter not eligible for inner tuning", tt$call, param$id) - } - if (is.null(tt$content$aggr)) { - stopf("%s (%s): Provide an aggregation function for inner tuning.", tt$call, param$id) - } - } else { - if ("inner_tuning" %in% tt$content$.tags && "inner_tuning" %nin% param_set$tags[[param$id]]) { - stopf("%s (%s): Parameter not eligible for inner tuning", tt$call, param$id) - - } - } - ps = NextMethod() - ps$tags = map(ps$tags, function(tags) union(tags, "inner_tuning")) - return(ps) -} - -tunetoken_to_ps.FullTuneToken = function(tt, param, param_set) { +tunetoken_to_ps.FullTuneToken = function(tt, param) { if (!domain_is_bounded(param)) { stopf("%s must give a range for unbounded parameter %s.", tt$call, param$id) } if (isTRUE(tt$content$logscale)) { if (!domain_is_number(param)) stop("%s (%s): logscale only valid for numeric / integer parameters.", tt$call, param$id) - - tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale, aggr = tt$content$aggr), tt$call), param, param_set) + tunetoken_to_ps.RangeTuneToken(list(content = list(logscale = tt$content$logscale), tt$call), param) } else { if (!is.null(tt$content$aggr)) { # https://github.com/Rdatatable/data.table/issues/6104 - param$cargo[[1L]][[1L]] = if (is.null(param$cargo[[1L]])) { - list(aggr = tt$content$aggr) - } else { - insert_named(param$cargo[[1L]], list(aggr = tt$content$aggr)) - } + param$cargo[[1L]] = list(insert_named(param$cargo[[1L]], list(aggr = tt$content$aggr))) } - pslike_to_ps(param, tt$call, param, param_set) + pslike_to_ps(param, tt$call, param) } } - -tunetoken_to_ps.RangeTuneToken = function(tt, param, param_set) { +tunetoken_to_ps.RangeTuneToken = function(tt, param) { if (!domain_is_number(param)) { stopf("%s for non-numeric param must have zero or one argument.", tt$call) } @@ -322,13 +280,11 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param, param_set) { # create p_int / p_dbl object. Doesn't work if there is a numeric param class that we don't know about :-/ constructor = switch(param$cls, ParamInt = p_int, ParamDbl = p_dbl, stopf("%s: logscale for parameter %s of class %s not supported", tt$call, param$id, param$class)) - content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale, - aggr = tt$content$aggr) - + content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale, aggr = tt$content$aggr) pslike_to_ps(content, tt$call, param) } -tunetoken_to_ps.ObjectTuneToken = function(tt, param, param_set) { +tunetoken_to_ps.ObjectTuneToken = function(tt, param) { pslike_to_ps(tt$content, tt$call, param) } diff --git a/man/Domain.Rd b/man/Domain.Rd index 79b62042..f0123a45 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -21,7 +21,8 @@ p_dbl( trafo = NULL, logscale = FALSE, init, - aggr = NULL + aggr = NULL, + in_tune_fn = NULL ) p_fct( @@ -32,7 +33,8 @@ p_fct( depends = NULL, trafo = NULL, init, - aggr = NULL + aggr = NULL, + in_tune_fn = NULL ) p_int( @@ -46,7 +48,8 @@ p_int( trafo = NULL, logscale = FALSE, init, - aggr = NULL + aggr = NULL, + in_tune_fn = NULL ) p_lgl( @@ -56,7 +59,8 @@ p_lgl( depends = NULL, trafo = NULL, init, - aggr = NULL + aggr = NULL, + in_tune_fn = NULL ) p_uty( @@ -68,7 +72,8 @@ p_uty( trafo = NULL, repr = substitute(default), init, - aggr = NULL + aggr = NULL, + in_tune_fn = NULL ) } \arguments{ @@ -144,10 +149,6 @@ defining domains or hyperparameter ranges of learning algorithms, because these Initial value. When this is given, then the corresponding entry in \code{ParamSet$values} is initialized with this value upon construction.} -\item{aggr}{(\code{function})\cr -Function with one argument, which is a list of parameter values. -This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning.} - \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr Allowed categorical values of the parameter. If this is not a \code{character}, then a \code{trafo} is generated that converts the names (if not given: \code{as.character()} of the values) of the \code{levels} argument to the values. @@ -164,9 +165,6 @@ Defaults to \code{NULL}, which means that no check is performed.} Symbol to use to represent the value given in \code{default}. The \code{deparse()} of this object is used when printing the domain, in some cases.} } -\value{ -A \code{Domain} object. -} \description{ A \code{Domain} object is a representation of a single dimension of a \code{\link{ParamSet}}. \code{Domain} objects are used to construct \code{\link{ParamSet}}s, either through the \code{\link[=ps]{ps()}} short form, through the \code{\link{ParamSet}} constructor itself, diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index cfc3a121..bbae3627 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -173,6 +173,7 @@ Named with param IDs.} \item \href{#method-ParamSet-set_values}{\code{ParamSet$set_values()}} \item \href{#method-ParamSet-trafo}{\code{ParamSet$trafo()}} \item \href{#method-ParamSet-aggr}{\code{ParamSet$aggr()}} +\item \href{#method-ParamSet-convert_inner_tune_tokens}{\code{ParamSet$convert_inner_tune_tokens()}} \item \href{#method-ParamSet-test_constraint}{\code{ParamSet$test_constraint()}} \item \href{#method-ParamSet-test_constraint_dt}{\code{ParamSet$test_constraint_dt()}} \item \href{#method-ParamSet-check}{\code{ParamSet$check()}} @@ -361,6 +362,17 @@ The aggregation function is selected based on the parameter.} \subsection{Returns}{ (named \code{list()}) } +} +\if{html}{\out{
    }} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-convert_inner_tune_tokens}{}}} +\subsection{Method \code{convert_inner_tune_tokens()}}{ +Convert all \code{InnerTuneToken}s to specific parameter values. +These transformations are defined by the \code{in_tune_fn} arguments of the \code{\link{Domain}} objects. +\subsection{Usage}{ +\if{html}{\out{
    }}\preformatted{ParamSet$convert_inner_tune_tokens()}\if{html}{\out{
    }} +} + } \if{html}{\out{
    }} \if{html}{\out{}} diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd index 3a863f06..6ab971bf 100644 --- a/man/ParamSetCollection.Rd +++ b/man/ParamSetCollection.Rd @@ -85,6 +85,7 @@ This field provides direct references to the \code{\link{ParamSet}} objects.}
  • paradox::ParamSet$check()
  • paradox::ParamSet$check_dependencies()
  • paradox::ParamSet$check_dt()
  • +
  • paradox::ParamSet$convert_inner_tune_tokens()
  • paradox::ParamSet$flatten()
  • paradox::ParamSet$format()
  • paradox::ParamSet$get_domain()
  • diff --git a/man/to_tune.Rd b/man/to_tune.Rd index b49f2239..d6fad7c8 100644 --- a/man/to_tune.Rd +++ b/man/to_tune.Rd @@ -11,12 +11,13 @@ to_tune(..., inner = !is.null(aggr), aggr = NULL) \item{...}{if given, restricts the range to be tuning over, as described above.} \item{inner}{(\code{logical(1)})\cr -Whether to create an inner tuning token. -Is set to \code{TRUE} by default if \code{aggr} is provided.} +Whether to create an \code{InnerTuneToken}. +This is only available for parameters tagged with \code{"inner_tuning"}.} \item{aggr}{(\code{function})\cr -The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. -If \code{NULL}, the default aggregation function of the parameter (if available) is used.} +Function with one argument, which is a list of parameter values. +This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning. +If none specified, the default aggregation function of the parameter will be used.} } \value{ A \code{TuneToken} object. @@ -56,11 +57,6 @@ the range which should be tuned over. The supplied \code{trafo} function is used where a single evaluation-time parameter value (e.g. \code{\link[=p_uty]{p_uty()}}) is constructed from multiple tuner-visible parameters (which may not be \code{\link[=p_uty]{p_uty()}}). If not one-dimensional, the supplied \code{\link{ParamSet}} should always contain a \verb{$extra_trafo} function, which must then always return a \code{list} with a single entry. -\item **\verb{to_tune(..., aggr = , inner = )}: Works like any of the above, but marks the parameter for -learner-internal (inner) tuning when \code{inner = TRUE}. This does not change the behavior of the learner, but does -affect the behavior of the tuner, which will then insert the aggregated internally tuned values into the tuning archive. -For the \code{\link[mlr3tuning:AutoTuner]{AutoTuner}}, this implies that the final model fit will use the internally -optimized values. This e.g. allows to combine XGBoost's earlys stopping with \code{mlr3}'s tuning packages. } The \code{TuneToken} object's internals are subject to change and should not be relied upon. \code{TuneToken} objects should @@ -78,8 +74,7 @@ params = ps( uty2 = p_uty(), uty3 = p_uty(), uty4 = p_uty(), - uty5 = p_uty(), - p_inner = p_int(tags = "inner_tuning", aggr = function(x) round(mean(unlist(x)))) + uty5 = p_uty() ) params$values = list( @@ -126,10 +121,7 @@ params$values = list( )), # not all values need to be tuned! - uty5 = 100, - - # Fix value to 100, but use learner-internal tuning and default aggregation rule - p_inner = to_tune(p_fct(100), inner = TRUE) + uty5 = 100 ) print(params$values) diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index 5fd4d794..f6981dff 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -450,3 +450,28 @@ test_that("aggr", { expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list())), "permutation") expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "At least one") }) + +test_that("convert_inner_tune_tokens", { + param_set = ps( + a = p_int(lower = 1, upper = 100, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper, + aggr = function(x) round(mean(unlist(x)))) + ) + param_set$set_values(a = to_tune(inner = TRUE)) + expect_identical(param_set$convert_inner_tune_tokens(), list(a = 100)) + param_set$set_values(a = to_tune(inner = TRUE, upper = 99)) + expect_identical(param_set$convert_inner_tune_tokens(), list(a = 99)) + + param_set$set_values(a = to_tune(inner = FALSE)) + expect_identical(param_set$convert_inner_tune_tokens(), named_list()) +}) + +test_that("get_values works with inner_tune", { + param_set = ps( + a = p_int(lower = 1, upper = 100, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper, + aggr = function(x) round(mean(unlist(x)))) + ) + param_set$set_values(a = to_tune(inner = TRUE)) + expect_list(param_set$get_values(type = "with_inner"), len = 1L) + param_set$set_values(a = to_tune()) + expect_list(param_set$get_values(type = "with_inner"), len = 0L) +}) diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index 1214c6b4..0ebb7b92 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -349,23 +349,24 @@ test_that("$extra_trafo flag works", { }) test_that("inner", { - it = to_tune(1, inner = TRUE) + expect_error(to_tune(1, inner = TRUE), "can currently") + it = to_tune(upper = 1, inner = TRUE) expect_class(it, "InnerTuneToken") + expect_class(it, "RangeTuneToken") expect_null(it$aggr) - tt = to_tune(1, inner = TRUE) - expect_equal(it$content, tt$content) + + it1 = to_tune(upper = 1, inner = TRUE) + expect_equal(it1$content, it$content) it1 = to_tune(aggr = function(x) min(unlist(x))) expect_equal(it1$content$aggr(list(1, 2)), 1) param_set = ps( - a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "inner_tuning") + a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper) ) param_set$set_values(a = to_tune(inner = TRUE)) expect_class(param_set$values$a, "InnerTuneToken") - param_set$set_values(a = to_tune(p_fct(1.2), inner = TRUE)) - expect_class(param_set$values$a, "InnerTuneToken") - param_set$set_values(a = to_tune(1.2, 2.3, inner = TRUE)) - expect_class(param_set$values$a, "InnerTuneToken") - param_set$set_values(a = to_tune(1.2, 2.3, logscale = TRUE, inner = TRUE)) - expect_class(param_set$values$a, "InnerTuneToken") + expect_error(param_set$set_values(a = to_tune(inner = TRUE, logscale = TRUE)), "Cannot combine") + + expect_error(p_dbl(lower = 1, upper = 2, tags = "inner_tuning", "in_tune_fn")) }) + diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index 8d6403b8..43c33dad 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -402,10 +402,10 @@ test_that("logscale in tunetoken", { test_that("inner and aggr", { # no default aggregation function - param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "inner_tuning")) + param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper)) # correct errors - expect_error(param_set$set_values(a = to_tune(inner = TRUE)), "Provide an aggregation") + expect_error(param_set$set_values(a = to_tune(inner = TRUE)), "but no aggregation function is available") expect_error(param_set$set_values(a = to_tune(inner = FALSE, aggr = function(x) 1))) @@ -424,14 +424,14 @@ test_that("inner and aggr", { # other trafos + inner: not allowed expect_error( param_set$set_values(a = to_tune(ps(a = p_dbl(0, 1), .extra_trafo = function(x) 1L), aggr = function(x) -99)), - "inner tuning" + "can currently not be combined" ) # param set + inner # range + inner - param_set$set_values(a = to_tune(lower = 1, upper = 2, aggr = function(x) 1.5)) - expect_equal(s$search_space()$aggr(list(a = list(1, 2))), list(a = 1.5)) + param_set$set_values(a = to_tune(lower = 1.2, upper = 1.3, aggr = function(x) 1.5)) + expect_equal(param_set$search_space()$aggr(list(a = list(1, 2))), list(a = 1.5)) # full + inner From 41e7d812ffd5fbba78817dd4fe894fb64001a673 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 4 May 2024 09:39:59 +0200 Subject: [PATCH 15/34] cleanup --- NEWS.md | 2 +- R/Domain.R | 21 +++++++++++++++++++++ R/ParamDbl.R | 8 ++++---- R/ParamFct.R | 7 ++++--- R/ParamInt.R | 6 +++++- R/ParamLgl.R | 7 ++++--- R/ParamSet.R | 25 +++++++++++++++---------- R/ParamUty.R | 7 ++++--- R/to_tune.R | 4 ++-- man/Domain.Rd | 24 ++++++++++++++++++++++++ man/ParamSet.Rd | 8 +++++--- man/to_tune.Rd | 2 +- tests/testthat/test_ParamSet.R | 2 +- tests/testthat/test_domain.R | 2 +- tests/testthat/test_to_tune.R | 24 +++++++++++++++++++----- 15 files changed, 111 insertions(+), 38 deletions(-) diff --git a/NEWS.md b/NEWS.md index f3b1f9ab..66d60ef6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,6 @@ # dev -* feat: added support for `aggr`(egation function) which can be used for inner tuning. +* feat: added support for `InnerTuneToken`s # paradox 0.12.0 * Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes: diff --git a/R/Domain.R b/R/Domain.R index 81af4c5a..bf1bb00a 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -63,6 +63,14 @@ #' @param init (`any`)\cr #' Initial value. When this is given, then the corresponding entry in `ParamSet$values` is initialized with this #' value upon construction. +#' @param aggr (`function`)\cr +#' Default aggregation function for a parameter. Can only be given for parameters tagged with `"inner_tuning"`. +#' Function with one argument, which is a list of parameter values and that returns the aggregated parameter value. +#' @param in_tune_fn (`function(domain, param_set)`)\cr +#' Function that converters a `Domain` object into a parameter value. +#' Can onlye be given for parameters tagged with `"inner_tuning"`. +#' +#' @return A `Domain` object. #' #' @details #' Although the `levels` values of a constructed `p_fct()` will always be `character`-valued, the `p_fct` function admits @@ -115,6 +123,19 @@ #' # ... but get transformed to integers. #' print(grid$transpose()) #' +#' +#' # inner tuning +#' +#' param_set = ps( +#' iters = p_int(0, Inf, tags = "inner_tuning", aggr = function(x) round(mean(unlist(x))), +#' in_tune_fn = function(domain, param_set) domain$upper) +#' ) +#' param_set$set_values( +#' iters = to_tune(upper = 100, inner = TRUE) +#' ) +#' param_set$convert_inner_tune_tokens() +#' param_set$aggr(list(iters = list(1, 2, 3))) +#' #' @family ParamSet construction helpers #' @name Domain NULL diff --git a/R/ParamDbl.R b/R/ParamDbl.R index c757e358..7c010cc8 100644 --- a/R/ParamDbl.R +++ b/R/ParamDbl.R @@ -2,9 +2,10 @@ #' @export p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) - assert_function(in_tune_fn, null.ok = "inner_tuning" %nin% tags, args = c("domain", "param_set"), nargs = 2L) - if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { - stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + if ("inner_tuning" %in% tags) { + assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) + } else { + assert_true(is.null(in_tune_fn)) } assert_number(tolerance, lower = 0) assert_number(lower) @@ -26,7 +27,6 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ if (logscale) cargo$logscale = TRUE cargo$aggr = aggr if (!is.null(in_tune_fn)) cargo$in_tune_fn = in_tune_fn - Domain(cls = "ParamDbl", grouping = "ParamDbl", lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = "numeric", depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } diff --git a/R/ParamFct.R b/R/ParamFct.R index d6504fe9..5ef4ea6d 100644 --- a/R/ParamFct.R +++ b/R/ParamFct.R @@ -4,9 +4,10 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact assert_function(aggr, null.ok = TRUE, nargs = 1L) constargs = as.list(match.call()[-1]) levels = eval.parent(constargs$levels) - assert_function(in_tune_fn, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) - if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { - stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + if ("inner_tuning" %in% tags) { + assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) + } else { + assert_true(is.null(in_tune_fn)) } if (!is.character(levels)) { # if the "levels" argument is not a character vector, then diff --git a/R/ParamInt.R b/R/ParamInt.R index 3099b39a..38487142 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -4,7 +4,11 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) assert_number(tolerance, lower = 0, upper = 0.5) - assert_function(in_tune_fn, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) + if ("inner_tuning" %in% tags) { + assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) + } else { + assert_true(is.null(in_tune_fn)) + } if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") } diff --git a/R/ParamLgl.R b/R/ParamLgl.R index 7ef9c9d8..fda99dea 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -2,9 +2,10 @@ #' @export p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) - assert_function(in_tune_fn, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) - if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { - stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + if ("inner_tuning" %in% tags) { + assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) + } else { + assert_true(is.null(in_tune_fn)) } cargo = c(aggr = aggr, in_tune_fn = in_tune_fn) diff --git a/R/ParamSet.R b/R/ParamSet.R index bb9cb1d4..54c5e40b 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -260,37 +260,42 @@ ParamSet = R6Class("ParamSet", #' @description #' - #' Aggregate parameter values according to the aggregation rules. + #' Aggregate parameter values according to their aggregation rules. #' #' @param x (named `list()` of `list()`s)\cr #' The value(s) to be aggregated. Names are parameter values. #' The aggregation function is selected based on the parameter. + #' #' @return (named `list()`) aggr = function(x) { assert_list(x, types = "list") aggrs = private$.params[map_lgl(get("cargo"), function(cargo) is.function(cargo$aggr)), list(id = get("id"), aggr = map(get("cargo"), "aggr"))] assert_permutation(names(x), aggrs$id) - if (!(length(unique(lengths(x))) == 1L)) { - stopf("The same number of values are required for each parameter") - } - if (nrow(aggrs) && !length(x[[1L]])) { - stopf("At least one value is required to aggregate them") + if (!nrow(aggrs)) { + return(named_list()) } - imap(x, function(value, .id) { + if (!length(value)) { + stopf("Trying to aggregate values of parameters '%s', but there are no values", .id) + } aggr = aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value) }) }, #' @description - #' Convert all `InnerTuneToken`s to specific parameter values. - #' These transformations are defined by the `in_tune_fn` arguments of the [`Domain`] objects. + #' Convert all `InnerTuneToken`s to parameter values as is defined by their `in_tune_fn`. + #' + #' @return (named `list()`) convert_inner_tune_tokens = function() { inner_tune_tokens = self$get_values(type = "with_inner") inner_tune_ps = private$get_tune_ps(inner_tune_tokens) imap(inner_tune_ps$domains, function(token, .id) { - converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn(token) + converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn + if (!is.function(converter)) { + stopf("No converter exists for InnerTuneToken of parameters '%s'", .id) + } + converter(token) }) }, diff --git a/R/ParamUty.R b/R/ParamUty.R index 48b18149..896daacf 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -3,9 +3,10 @@ #' @export p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL, in_tune_fn = NULL) { assert_function(custom_check, null.ok = TRUE) - assert_function(in_tune_fn, null.ok = TRUE, args = c("domain", "param_set"), nargs = 2L) - if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { - stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + if ("inner_tuning" %in% tags) { + assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) + } else { + assert_true(is.null(in_tune_fn)) } assert_function(aggr, null.ok = TRUE, nargs = 1L) if (!is.null(custom_check)) { diff --git a/R/to_tune.R b/R/to_tune.R index 6e84afbe..00eb43db 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -45,7 +45,7 @@ #' Whether to create an `InnerTuneToken`. #' This is only available for parameters tagged with `"inner_tuning"`. #' @param aggr (`function`)\cr -#' Function with one argument, which is a list of parameter values. +#' Function with one argument, which is a list of parameter values and returns a single aggregated value (e.g. the mean). #' This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning. #' If none specified, the default aggregation function of the parameter will be used. #' @return A `TuneToken` object. @@ -194,7 +194,7 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { if (inner) { if (type == "ObjectTuneToken") { - stop("Inner tuning can currently not be combined with ParamSet or Domain object.") + stop("Inner tuning can currently not be combined with ParamSet or Domain object, specify lower and upper bounds, e.g. to_tune(upper = 100)") } if (isTRUE(content$logscale)) { stop("Cannot combine logscale transformation with inner tuning.") diff --git a/man/Domain.Rd b/man/Domain.Rd index f0123a45..7043a77f 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -149,6 +149,14 @@ defining domains or hyperparameter ranges of learning algorithms, because these Initial value. When this is given, then the corresponding entry in \code{ParamSet$values} is initialized with this value upon construction.} +\item{aggr}{(\code{function})\cr +Default aggregation function for a parameter. Can only be given for parameters tagged with \code{"inner_tuning"}. +Function with one argument, which is a list of parameter values and that returns the aggregated parameter value.} + +\item{in_tune_fn}{(\verb{function(domain, param_set)})\cr +Function that converters a \code{Domain} object into a parameter value. +Can onlye be given for parameters tagged with \code{"inner_tuning"}.} + \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr Allowed categorical values of the parameter. If this is not a \code{character}, then a \code{trafo} is generated that converts the names (if not given: \code{as.character()} of the values) of the \code{levels} argument to the values. @@ -165,6 +173,9 @@ Defaults to \code{NULL}, which means that no check is performed.} Symbol to use to represent the value given in \code{default}. The \code{deparse()} of this object is used when printing the domain, in some cases.} } +\value{ +A \code{Domain} object. +} \description{ A \code{Domain} object is a representation of a single dimension of a \code{\link{ParamSet}}. \code{Domain} objects are used to construct \code{\link{ParamSet}}s, either through the \code{\link[=ps]{ps()}} short form, through the \code{\link{ParamSet}} constructor itself, @@ -229,6 +240,19 @@ print(grid) # ... but get transformed to integers. print(grid$transpose()) + +# inner tuning + +param_set = ps( + iters = p_int(0, Inf, tags = "inner_tuning", aggr = function(x) round(mean(unlist(x))), + in_tune_fn = function(domain, param_set) domain$upper) +) +param_set$set_values( + iters = to_tune(upper = 100, inner = TRUE) +) +param_set$convert_inner_tune_tokens() +param_set$aggr(list(iters = list(1, 2, 3))) + } \seealso{ Other ParamSet construction helpers: diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index bbae3627..beb2ca0f 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -345,7 +345,7 @@ In almost all cases, the default \code{param_set = self} should be used.} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ParamSet-aggr}{}}} \subsection{Method \code{aggr()}}{ -Aggregate parameter values according to the aggregation rules. +Aggregate parameter values according to their aggregation rules. \subsection{Usage}{ \if{html}{\out{
    }}\preformatted{ParamSet$aggr(x)}\if{html}{\out{
    }} } @@ -367,12 +367,14 @@ The aggregation function is selected based on the parameter.} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ParamSet-convert_inner_tune_tokens}{}}} \subsection{Method \code{convert_inner_tune_tokens()}}{ -Convert all \code{InnerTuneToken}s to specific parameter values. -These transformations are defined by the \code{in_tune_fn} arguments of the \code{\link{Domain}} objects. +Convert all \code{InnerTuneToken}s to parameter values as is defined by their \code{in_tune_fn}. \subsection{Usage}{ \if{html}{\out{
    }}\preformatted{ParamSet$convert_inner_tune_tokens()}\if{html}{\out{
    }} } +\subsection{Returns}{ +(named \code{list()}) +} } \if{html}{\out{
    }} \if{html}{\out{}} diff --git a/man/to_tune.Rd b/man/to_tune.Rd index d6fad7c8..fb91d150 100644 --- a/man/to_tune.Rd +++ b/man/to_tune.Rd @@ -15,7 +15,7 @@ Whether to create an \code{InnerTuneToken}. This is only available for parameters tagged with \code{"inner_tuning"}.} \item{aggr}{(\code{function})\cr -Function with one argument, which is a list of parameter values. +Function with one argument, which is a list of parameter values and returns a single aggregated value (e.g. the mean). This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning. If none specified, the default aggregation function of the parameter will be used.} } diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index f6981dff..4e476525 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -448,7 +448,7 @@ test_that("aggr", { expect_error(param_set$aggr(1), "list") expect_error(param_set$aggr(list(1)), "list") expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list())), "permutation") - expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "At least one") + expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "but there are no") }) test_that("convert_inner_tune_tokens", { diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index 0ebb7b92..9cdeae8d 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -363,7 +363,7 @@ test_that("inner", { param_set = ps( a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper) ) - param_set$set_values(a = to_tune(inner = TRUE)) + param_set$set_values(a = to_tune(inner = TRUE, aggr = function(x) round(mean(unlist(x))))) expect_class(param_set$values$a, "InnerTuneToken") expect_error(param_set$set_values(a = to_tune(inner = TRUE, logscale = TRUE)), "Cannot combine") diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index 43c33dad..92ab79a3 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -427,22 +427,36 @@ test_that("inner and aggr", { "can currently not be combined" ) - # param set + inner - # range + inner param_set$set_values(a = to_tune(lower = 1.2, upper = 1.3, aggr = function(x) 1.5)) expect_equal(param_set$search_space()$aggr(list(a = list(1, 2))), list(a = 1.5)) + expect_equal(param_set$convert_inner_tune_tokens(), list(a = 1.3)) # full + inner + param_set$set_values(a = to_tune(inner = TRUE, aggr = function(x) 1.5)) + expect_equal(param_set$convert_inner_tune_tokens(), list(a = 2)) # domain + inner - + expect_error( + param_set$set_values(a = to_tune(p_dbl(1.21, 1.22), aggr = function(x) 1.5, inner = TRUE)), + "specify lower and upper" + ) ## with default aggregation function - # default aggregation function is used when not overwritten + # param set + inner + param_set = ps(a = p_int(lower = 1, upper = 10000, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper, + aggr = function(x) max(unlist(x)))) + # default aggregation function is used when not overwritten + param_set$set_values( + a = to_tune(inner = TRUE) + ) + expect_equal(param_set$search_space()$aggr(list(a = list(1, 2, 3))), list(a = 3)) # can overwrite existing aggregation function - # check all cases + param_set$set_values( + a = to_tune(inner = TRUE, aggr = function(x) -60) + ) + expect_equal(param_set$search_space()$aggr(list(a = list(1, 2, 3))), list(a = -60)) }) From 2a43aab9041bed3db78814e7fa203c33f6cb66bc Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 6 May 2024 09:25:38 +0200 Subject: [PATCH 16/34] fix: add tags to domains created for inner tuning --- R/to_tune.R | 19 +++++++++++++++---- tests/testthat/test_ParamSet.R | 13 +++++++++++++ tests/testthat/test_to_tune.R | 2 +- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/R/to_tune.R b/R/to_tune.R index 00eb43db..f5247a70 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -238,11 +238,11 @@ print.ObjectTuneToken = function(x, ...) { # # Makes liberal use to `pslike_to_ps` (converting Param, ParamSet, Domain to ParamSet) # param is a data.table that is potentially modified by reference using data.table set() methods. -tunetoken_to_ps = function(tt, param) { +tunetoken_to_ps = function(tt, param, ...) { UseMethod("tunetoken_to_ps") } -tunetoken_to_ps.FullTuneToken = function(tt, param) { +tunetoken_to_ps.FullTuneToken = function(tt, param, ...) { if (!domain_is_bounded(param)) { stopf("%s must give a range for unbounded parameter %s.", tt$call, param$id) } @@ -258,7 +258,18 @@ tunetoken_to_ps.FullTuneToken = function(tt, param) { } } -tunetoken_to_ps.RangeTuneToken = function(tt, param) { +tunetoken_to_ps.InnerTuneToken = function(tt, param, ...) { + # Calling NextMethod with additional arguments behaves weirdly, as the InnerTuneToken only works with ranges right now + # we just call it directly + aggr = if (!is.null(tt$content$aggr)) tt$content$aggr else param$cargo[[1L]]$aggr + if (is.null(aggr)) { + stopf("%s must specify a aggregation function for parameter %s", tt$call, param$id) + } + tunetoken_to_ps.RangeTuneToken(tt = tt, param = param, in_tune_fn = param$cargo[[1L]]$in_tune_fn, tags = "inner_tuning", + aggr = aggr) +} + +tunetoken_to_ps.RangeTuneToken = function(tt, param, args = list(), ...) { if (!domain_is_number(param)) { stopf("%s for non-numeric param must have zero or one argument.", tt$call) } @@ -280,7 +291,7 @@ tunetoken_to_ps.RangeTuneToken = function(tt, param) { # create p_int / p_dbl object. Doesn't work if there is a numeric param class that we don't know about :-/ constructor = switch(param$cls, ParamInt = p_int, ParamDbl = p_dbl, stopf("%s: logscale for parameter %s of class %s not supported", tt$call, param$id, param$class)) - content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale, aggr = tt$content$aggr) + content = constructor(lower = bound_lower, upper = bound_upper, logscale = tt$content$logscale, ...) pslike_to_ps(content, tt$call, param) } diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index 4e476525..4886036c 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -475,3 +475,16 @@ test_that("get_values works with inner_tune", { param_set$set_values(a = to_tune()) expect_list(param_set$get_values(type = "with_inner"), len = 0L) }) + +test_that("InnerTuneToken is translated to 'inner_tuning' tag when creating search space", { + param_set = ps( + a = p_int(0, Inf, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper, aggr = function(x) round(mean(unlist(x)))) + ) + + param_set$set_values( + a = to_tune(upper = 100, inner = TRUE) + ) + + ss = param_set$search_space() + expect_true("inner_tuning" %in% ss$tags$a) +}) diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index 92ab79a3..0c540302 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -405,7 +405,7 @@ test_that("inner and aggr", { param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper)) # correct errors - expect_error(param_set$set_values(a = to_tune(inner = TRUE)), "but no aggregation function is available") + expect_error(param_set$set_values(a = to_tune(inner = TRUE)), "aggregation") expect_error(param_set$set_values(a = to_tune(inner = FALSE, aggr = function(x) 1))) From 94f0f18310ee6164def807894050c9f1535566e3 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 6 May 2024 11:21:27 +0200 Subject: [PATCH 17/34] rename inner to internal --- NAMESPACE | 2 +- NEWS.md | 2 +- R/Domain.R | 12 +++++----- R/ParamDbl.R | 2 +- R/ParamFct.R | 2 +- R/ParamInt.R | 6 ++--- R/ParamLgl.R | 2 +- R/ParamSet.R | 30 ++++++++++++------------- R/ParamUty.R | 4 ++-- R/to_tune.R | 32 +++++++++++++-------------- man/Domain.Rd | 12 +++++----- man/ParamSet.Rd | 12 +++++----- man/ParamSetCollection.Rd | 2 +- man/to_tune.Rd | 10 ++++----- tests/testthat/test_ParamSet.R | 34 ++++++++++++++--------------- tests/testthat/test_domain.R | 20 ++++++++--------- tests/testthat/test_to_tune.R | 40 +++++++++++++++++----------------- 17 files changed, 112 insertions(+), 112 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 35a657b1..71371801 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -49,7 +49,7 @@ S3method(format,Condition) S3method(print,Condition) S3method(print,Domain) S3method(print,FullTuneToken) -S3method(print,InnerTuneToken) +S3method(print,InternalTuneToken) S3method(print,ObjectTuneToken) S3method(print,RangeTuneToken) S3method(rd_info,ParamSet) diff --git a/NEWS.md b/NEWS.md index 66d60ef6..d2e26c6f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,6 @@ # dev -* feat: added support for `InnerTuneToken`s +* feat: added support for `InternalTuneToken`s # paradox 0.12.0 * Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes: diff --git a/R/Domain.R b/R/Domain.R index bf1bb00a..f53c7ebf 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -64,11 +64,11 @@ #' Initial value. When this is given, then the corresponding entry in `ParamSet$values` is initialized with this #' value upon construction. #' @param aggr (`function`)\cr -#' Default aggregation function for a parameter. Can only be given for parameters tagged with `"inner_tuning"`. +#' Default aggregation function for a parameter. Can only be given for parameters tagged with `"internal_tuning"`. #' Function with one argument, which is a list of parameter values and that returns the aggregated parameter value. #' @param in_tune_fn (`function(domain, param_set)`)\cr #' Function that converters a `Domain` object into a parameter value. -#' Can onlye be given for parameters tagged with `"inner_tuning"`. +#' Can onlye be given for parameters tagged with `"internal_tuning"`. #' #' @return A `Domain` object. #' @@ -124,16 +124,16 @@ #' print(grid$transpose()) #' #' -#' # inner tuning +#' # internal tuning #' #' param_set = ps( -#' iters = p_int(0, Inf, tags = "inner_tuning", aggr = function(x) round(mean(unlist(x))), +#' iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), #' in_tune_fn = function(domain, param_set) domain$upper) #' ) #' param_set$set_values( -#' iters = to_tune(upper = 100, inner = TRUE) +#' iters = to_tune(upper = 100, internal = TRUE) #' ) -#' param_set$convert_inner_tune_tokens() +#' param_set$convert_internal_tune_tokens() #' param_set$aggr(list(iters = list(1, 2, 3))) #' #' @family ParamSet construction helpers diff --git a/R/ParamDbl.R b/R/ParamDbl.R index 7c010cc8..9f3075c6 100644 --- a/R/ParamDbl.R +++ b/R/ParamDbl.R @@ -2,7 +2,7 @@ #' @export p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) - if ("inner_tuning" %in% tags) { + if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { assert_true(is.null(in_tune_fn)) diff --git a/R/ParamFct.R b/R/ParamFct.R index 5ef4ea6d..422bcf61 100644 --- a/R/ParamFct.R +++ b/R/ParamFct.R @@ -4,7 +4,7 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact assert_function(aggr, null.ok = TRUE, nargs = 1L) constargs = as.list(match.call()[-1]) levels = eval.parent(constargs$levels) - if ("inner_tuning" %in% tags) { + if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { assert_true(is.null(in_tune_fn)) diff --git a/R/ParamInt.R b/R/ParamInt.R index 38487142..91f1c57b 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -4,13 +4,13 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) assert_number(tolerance, lower = 0, upper = 0.5) - if ("inner_tuning" %in% tags) { + if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { assert_true(is.null(in_tune_fn)) } - if ("inner_tuning" %nin% tags && !is.null(in_tune_fn)) { - stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'inner_tuning'") + if ("internal_tuning" %nin% tags && !is.null(in_tune_fn)) { + stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'internal_tuning'") } # assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower) diff --git a/R/ParamLgl.R b/R/ParamLgl.R index fda99dea..1a6330af 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -2,7 +2,7 @@ #' @export p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) - if ("inner_tuning" %in% tags) { + if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { assert_true(is.null(in_tune_fn)) diff --git a/R/ParamSet.R b/R/ParamSet.R index 54c5e40b..9cf8d0c4 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -155,7 +155,7 @@ ParamSet = R6Class("ParamSet", #' @param type (`character(1)`)\cr #' Return values `"with_token"` (i.e. all values), # `"without_token"` (all values that are not [`TuneToken`] objects), `"only_token"` (only [`TuneToken`] objects) - # or `"with_inner"` (all values that are no not `InnerTuneToken`)? + # or `"with_internal"` (all values that are no not `InternalTuneToken`)? #' @param check_required (`logical(1)`)\cr #' Check if all required parameters are set? #' @param remove_dependencies (`logical(1)`)\cr @@ -163,7 +163,7 @@ ParamSet = R6Class("ParamSet", #' @return Named `list()`. get_values = function(class = NULL, tags = NULL, any_tags = NULL, type = "with_token", check_required = TRUE, remove_dependencies = TRUE) { - assert_choice(type, c("with_token", "without_token", "only_token", "with_inner")) + assert_choice(type, c("with_token", "without_token", "only_token", "with_internal")) assert_flag(check_required) @@ -174,8 +174,8 @@ ParamSet = R6Class("ParamSet", values = discard(values, is, "TuneToken") } else if (type == "only_token") { values = keep(values, is, "TuneToken") - } else if (type == "with_inner") { - values = keep(values, is, "InnerTuneToken") + } else if (type == "with_internal") { + values = keep(values, is, "InternalTuneToken") } if (check_required) { @@ -283,17 +283,17 @@ ParamSet = R6Class("ParamSet", }, #' @description - #' Convert all `InnerTuneToken`s to parameter values as is defined by their `in_tune_fn`. + #' Convert all `InternalTuneToken`s to parameter values as is defined by their `in_tune_fn`. #' #' @return (named `list()`) - convert_inner_tune_tokens = function() { - inner_tune_tokens = self$get_values(type = "with_inner") - inner_tune_ps = private$get_tune_ps(inner_tune_tokens) + convert_internal_tune_tokens = function() { + internal_tune_tokens = self$get_values(type = "with_internal") + internal_tune_ps = private$get_tune_ps(internal_tune_tokens) - imap(inner_tune_ps$domains, function(token, .id) { + imap(internal_tune_ps$domains, function(token, .id) { converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn if (!is.function(converter)) { - stopf("No converter exists for InnerTuneToken of parameters '%s'", .id) + stopf("No converter exists for InternalTuneToken of parameters '%s'", .id) } converter(token) }) @@ -367,13 +367,13 @@ ParamSet = R6Class("ParamSet", if (!isTRUE(tunecheck)) return(tunecheck) } - xs_innertune = keep(xs, is, "InnerTuneToken") - walk(names(xs_innertune), function(pid) { - if ("inner_tuning" %nin% self$tags[[pid]]) { - stopf("Trying to assign InnerTuneToken to parameter '%s' which is not tagged with 'inner_tuning'.", pid) + xs_internaltune = keep(xs, is, "InternalTuneToken") + walk(names(xs_internaltune), function(pid) { + if ("internal_tuning" %nin% self$tags[[pid]]) { + stopf("Trying to assign InternalTuneToken to parameter '%s' which is not tagged with 'internal_tuning'.", pid) } if (is.null(xs[[pid]]$content$aggr) && is.null(private$.params[pid, "cargo", on = "id"][[1L]][[1L]]$aggr)) { - stopf("Trying to set parameter '%s' to InnerTuneToken, but no aggregation function is available.", pid) + stopf("Trying to set parameter '%s' to InternalTuneToken, but no aggregation function is available.", pid) } }) diff --git a/R/ParamUty.R b/R/ParamUty.R index 896daacf..0521cb88 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -3,7 +3,7 @@ #' @export p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL, in_tune_fn = NULL) { assert_function(custom_check, null.ok = TRUE) - if ("inner_tuning" %in% tags) { + if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { assert_true(is.null(in_tune_fn)) @@ -20,7 +20,7 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t } cargo = list(custom_check = custom_check, repr = repr) cargo$aggr = aggr - cargo$inner_tune_fn = in_tune_fn + cargo$internal_tune_fn = in_tune_fn Domain(cls = "ParamUty", grouping = "ParamUty", cargo = cargo, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init) } diff --git a/R/to_tune.R b/R/to_tune.R index f5247a70..ace4023e 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -41,12 +41,12 @@ #' The `TuneToken` object's internals are subject to change and should not be relied upon. `TuneToken` objects should #' only be constructed via `to_tune()`, and should only be used by giving them to `$values` of a [`ParamSet`]. #' @param ... if given, restricts the range to be tuning over, as described above. -#' @param inner (`logical(1)`)\cr -#' Whether to create an `InnerTuneToken`. -#' This is only available for parameters tagged with `"inner_tuning"`. +#' @param internal (`logical(1)`)\cr +#' Whether to create an `InternalTuneToken`. +#' This is only available for parameters tagged with `"internal_tuning"`. #' @param aggr (`function`)\cr #' Function with one argument, which is a list of parameter values and returns a single aggregated value (e.g. the mean). -#' This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning. +#' This specifies how multiple parameter values are aggregated to form a single value in the context of internal tuning. #' If none specified, the default aggregation function of the parameter will be used. #' @return A `TuneToken` object. #' @examples @@ -139,10 +139,10 @@ #' @family ParamSet construction helpers #' @aliases TuneToken #' @export -to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { - assert_flag(inner) +to_tune = function(..., internal = !is.null(aggr), aggr = NULL) { + assert_flag(internal) if (!is.null(aggr)) { - assert_true(inner) + assert_true(internal) } assert_function(aggr, nargs = 1L, null.ok = TRUE) call = sys.call() @@ -192,14 +192,14 @@ to_tune = function(..., inner = !is.null(aggr), aggr = NULL) { content = list(logscale = FALSE) } - if (inner) { + if (internal) { if (type == "ObjectTuneToken") { - stop("Inner tuning can currently not be combined with ParamSet or Domain object, specify lower and upper bounds, e.g. to_tune(upper = 100)") + stop("Internal tuning can currently not be combined with ParamSet or Domain object, specify lower and upper bounds, e.g. to_tune(upper = 100)") } if (isTRUE(content$logscale)) { - stop("Cannot combine logscale transformation with inner tuning.") + stop("Cannot combine logscale transformation with internal tuning.") } - type = c("InnerTuneToken", type) + type = c("InternalTuneToken", type) content$aggr = aggr } @@ -213,8 +213,8 @@ print.FullTuneToken = function(x, ...) { } #' @export -print.InnerTuneToken = function(x, ...) { - cat("Inner ") +print.InternalTuneToken = function(x, ...) { + cat("Internal ") NextMethod() } @@ -258,14 +258,14 @@ tunetoken_to_ps.FullTuneToken = function(tt, param, ...) { } } -tunetoken_to_ps.InnerTuneToken = function(tt, param, ...) { - # Calling NextMethod with additional arguments behaves weirdly, as the InnerTuneToken only works with ranges right now +tunetoken_to_ps.InternalTuneToken = function(tt, param, ...) { + # Calling NextMethod with additional arguments behaves weirdly, as the InternalTuneToken only works with ranges right now # we just call it directly aggr = if (!is.null(tt$content$aggr)) tt$content$aggr else param$cargo[[1L]]$aggr if (is.null(aggr)) { stopf("%s must specify a aggregation function for parameter %s", tt$call, param$id) } - tunetoken_to_ps.RangeTuneToken(tt = tt, param = param, in_tune_fn = param$cargo[[1L]]$in_tune_fn, tags = "inner_tuning", + tunetoken_to_ps.RangeTuneToken(tt = tt, param = param, in_tune_fn = param$cargo[[1L]]$in_tune_fn, tags = "internal_tuning", aggr = aggr) } diff --git a/man/Domain.Rd b/man/Domain.Rd index 7043a77f..c446e1d4 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -150,12 +150,12 @@ Initial value. When this is given, then the corresponding entry in \code{ParamSe value upon construction.} \item{aggr}{(\code{function})\cr -Default aggregation function for a parameter. Can only be given for parameters tagged with \code{"inner_tuning"}. +Default aggregation function for a parameter. Can only be given for parameters tagged with \code{"internal_tuning"}. Function with one argument, which is a list of parameter values and that returns the aggregated parameter value.} \item{in_tune_fn}{(\verb{function(domain, param_set)})\cr Function that converters a \code{Domain} object into a parameter value. -Can onlye be given for parameters tagged with \code{"inner_tuning"}.} +Can onlye be given for parameters tagged with \code{"internal_tuning"}.} \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr Allowed categorical values of the parameter. If this is not a \code{character}, then a \code{trafo} is generated that @@ -241,16 +241,16 @@ print(grid) print(grid$transpose()) -# inner tuning +# internal tuning param_set = ps( - iters = p_int(0, Inf, tags = "inner_tuning", aggr = function(x) round(mean(unlist(x))), + iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), in_tune_fn = function(domain, param_set) domain$upper) ) param_set$set_values( - iters = to_tune(upper = 100, inner = TRUE) + iters = to_tune(upper = 100, internal = TRUE) ) -param_set$convert_inner_tune_tokens() +param_set$convert_internal_tune_tokens() param_set$aggr(list(iters = list(1, 2, 3))) } diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index beb2ca0f..fea27ca5 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -173,7 +173,7 @@ Named with param IDs.} \item \href{#method-ParamSet-set_values}{\code{ParamSet$set_values()}} \item \href{#method-ParamSet-trafo}{\code{ParamSet$trafo()}} \item \href{#method-ParamSet-aggr}{\code{ParamSet$aggr()}} -\item \href{#method-ParamSet-convert_inner_tune_tokens}{\code{ParamSet$convert_inner_tune_tokens()}} +\item \href{#method-ParamSet-convert_internal_tune_tokens}{\code{ParamSet$convert_internal_tune_tokens()}} \item \href{#method-ParamSet-test_constraint}{\code{ParamSet$test_constraint()}} \item \href{#method-ParamSet-test_constraint_dt}{\code{ParamSet$test_constraint_dt()}} \item \href{#method-ParamSet-check}{\code{ParamSet$check()}} @@ -364,12 +364,12 @@ The aggregation function is selected based on the parameter.} } } \if{html}{\out{
    }} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-ParamSet-convert_inner_tune_tokens}{}}} -\subsection{Method \code{convert_inner_tune_tokens()}}{ -Convert all \code{InnerTuneToken}s to parameter values as is defined by their \code{in_tune_fn}. +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-convert_internal_tune_tokens}{}}} +\subsection{Method \code{convert_internal_tune_tokens()}}{ +Convert all \code{InternalTuneToken}s to parameter values as is defined by their \code{in_tune_fn}. \subsection{Usage}{ -\if{html}{\out{
    }}\preformatted{ParamSet$convert_inner_tune_tokens()}\if{html}{\out{
    }} +\if{html}{\out{
    }}\preformatted{ParamSet$convert_internal_tune_tokens()}\if{html}{\out{
    }} } \subsection{Returns}{ diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd index 6ab971bf..7cc2f18e 100644 --- a/man/ParamSetCollection.Rd +++ b/man/ParamSetCollection.Rd @@ -85,7 +85,7 @@ This field provides direct references to the \code{\link{ParamSet}} objects.}
  • paradox::ParamSet$check()
  • paradox::ParamSet$check_dependencies()
  • paradox::ParamSet$check_dt()
  • -
  • paradox::ParamSet$convert_inner_tune_tokens()
  • +
  • paradox::ParamSet$convert_internal_tune_tokens()
  • paradox::ParamSet$flatten()
  • paradox::ParamSet$format()
  • paradox::ParamSet$get_domain()
  • diff --git a/man/to_tune.Rd b/man/to_tune.Rd index fb91d150..83ba090c 100644 --- a/man/to_tune.Rd +++ b/man/to_tune.Rd @@ -5,18 +5,18 @@ \alias{TuneToken} \title{Indicate that a Parameter Value should be Tuned} \usage{ -to_tune(..., inner = !is.null(aggr), aggr = NULL) +to_tune(..., internal = !is.null(aggr), aggr = NULL) } \arguments{ \item{...}{if given, restricts the range to be tuning over, as described above.} -\item{inner}{(\code{logical(1)})\cr -Whether to create an \code{InnerTuneToken}. -This is only available for parameters tagged with \code{"inner_tuning"}.} +\item{internal}{(\code{logical(1)})\cr +Whether to create an \code{InternalTuneToken}. +This is only available for parameters tagged with \code{"internal_tuning"}.} \item{aggr}{(\code{function})\cr Function with one argument, which is a list of parameter values and returns a single aggregated value (e.g. the mean). -This specifies how multiple parameter values are aggregated to form a single value in the context of inner tuning. +This specifies how multiple parameter values are aggregated to form a single value in the context of internal tuning. If none specified, the default aggregation function of the parameter will be used.} } \value{ diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index 4886036c..5892db6a 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -451,40 +451,40 @@ test_that("aggr", { expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "but there are no") }) -test_that("convert_inner_tune_tokens", { +test_that("convert_internal_tune_tokens", { param_set = ps( - a = p_int(lower = 1, upper = 100, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper, + a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, aggr = function(x) round(mean(unlist(x)))) ) - param_set$set_values(a = to_tune(inner = TRUE)) - expect_identical(param_set$convert_inner_tune_tokens(), list(a = 100)) - param_set$set_values(a = to_tune(inner = TRUE, upper = 99)) - expect_identical(param_set$convert_inner_tune_tokens(), list(a = 99)) + param_set$set_values(a = to_tune(internal = TRUE)) + expect_identical(param_set$convert_internal_tune_tokens(), list(a = 100)) + param_set$set_values(a = to_tune(internal = TRUE, upper = 99)) + expect_identical(param_set$convert_internal_tune_tokens(), list(a = 99)) - param_set$set_values(a = to_tune(inner = FALSE)) - expect_identical(param_set$convert_inner_tune_tokens(), named_list()) + param_set$set_values(a = to_tune(internal = FALSE)) + expect_identical(param_set$convert_internal_tune_tokens(), named_list()) }) -test_that("get_values works with inner_tune", { +test_that("get_values works with internal_tune", { param_set = ps( - a = p_int(lower = 1, upper = 100, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper, + a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, aggr = function(x) round(mean(unlist(x)))) ) - param_set$set_values(a = to_tune(inner = TRUE)) - expect_list(param_set$get_values(type = "with_inner"), len = 1L) + param_set$set_values(a = to_tune(internal = TRUE)) + expect_list(param_set$get_values(type = "with_internal"), len = 1L) param_set$set_values(a = to_tune()) - expect_list(param_set$get_values(type = "with_inner"), len = 0L) + expect_list(param_set$get_values(type = "with_internal"), len = 0L) }) -test_that("InnerTuneToken is translated to 'inner_tuning' tag when creating search space", { +test_that("InternalTuneToken is translated to 'internal_tuning' tag when creating search space", { param_set = ps( - a = p_int(0, Inf, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper, aggr = function(x) round(mean(unlist(x)))) + a = p_int(0, Inf, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, aggr = function(x) round(mean(unlist(x)))) ) param_set$set_values( - a = to_tune(upper = 100, inner = TRUE) + a = to_tune(upper = 100, internal = TRUE) ) ss = param_set$search_space() - expect_true("inner_tuning" %in% ss$tags$a) + expect_true("internal_tuning" %in% ss$tags$a) }) diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index 9cdeae8d..4620a0ea 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -348,25 +348,25 @@ test_that("$extra_trafo flag works", { expect_false(search_space$has_extra_trafo) }) -test_that("inner", { - expect_error(to_tune(1, inner = TRUE), "can currently") - it = to_tune(upper = 1, inner = TRUE) - expect_class(it, "InnerTuneToken") +test_that("internal", { + expect_error(to_tune(1, internal = TRUE), "can currently") + it = to_tune(upper = 1, internal = TRUE) + expect_class(it, "InternalTuneToken") expect_class(it, "RangeTuneToken") expect_null(it$aggr) - it1 = to_tune(upper = 1, inner = TRUE) + it1 = to_tune(upper = 1, internal = TRUE) expect_equal(it1$content, it$content) it1 = to_tune(aggr = function(x) min(unlist(x))) expect_equal(it1$content$aggr(list(1, 2)), 1) param_set = ps( - a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper) + a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper) ) - param_set$set_values(a = to_tune(inner = TRUE, aggr = function(x) round(mean(unlist(x))))) - expect_class(param_set$values$a, "InnerTuneToken") - expect_error(param_set$set_values(a = to_tune(inner = TRUE, logscale = TRUE)), "Cannot combine") + param_set$set_values(a = to_tune(internal = TRUE, aggr = function(x) round(mean(unlist(x))))) + expect_class(param_set$values$a, "InternalTuneToken") + expect_error(param_set$set_values(a = to_tune(internal = TRUE, logscale = TRUE)), "Cannot combine") - expect_error(p_dbl(lower = 1, upper = 2, tags = "inner_tuning", "in_tune_fn")) + expect_error(p_dbl(lower = 1, upper = 2, tags = "internal_tuning", "in_tune_fn")) }) diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index 0c540302..f9a24a2f 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -396,67 +396,67 @@ test_that("logscale in tunetoken", { expect_output(print(to_tune(lower = 1, logscale = TRUE)), "range \\[1, \\.\\.\\.] \\(log scale\\)") expect_output(print(to_tune(upper = 1, logscale = TRUE)), "range \\[\\.\\.\\., 1] \\(log scale\\)") expect_output(print(to_tune(lower = 0, upper = 1, logscale = TRUE)), "range \\[0, 1] \\(log scale\\)") - expect_output(print(to_tune(inner = TRUE)), "Inner") + expect_output(print(to_tune(internal = TRUE)), "Internal") }) -test_that("inner and aggr", { +test_that("internal and aggr", { # no default aggregation function - param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper)) + param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper)) # correct errors - expect_error(param_set$set_values(a = to_tune(inner = TRUE)), "aggregation") - expect_error(param_set$set_values(a = to_tune(inner = FALSE, aggr = function(x) 1))) + expect_error(param_set$set_values(a = to_tune(internal = TRUE)), "aggregation") + expect_error(param_set$set_values(a = to_tune(internal = FALSE, aggr = function(x) 1))) - # full tune token + inner + # full tune token + internal expect_equal( param_set$set_values(a = to_tune(aggr = function(x) -99))$search_space()$aggr(list(a = list(1, 2, 3))), list(a = -99) ) - # logscale + inner: now allowed + # logscale + internal: now allowed expect_error( param_set$set_values(a = to_tune(logscale = TRUE, aggr = function(x) -99)), - "inner tuning" + "internal tuning" ) - # other trafos + inner: not allowed + # other trafos + internal: not allowed expect_error( param_set$set_values(a = to_tune(ps(a = p_dbl(0, 1), .extra_trafo = function(x) 1L), aggr = function(x) -99)), "can currently not be combined" ) - # range + inner + # range + internal param_set$set_values(a = to_tune(lower = 1.2, upper = 1.3, aggr = function(x) 1.5)) expect_equal(param_set$search_space()$aggr(list(a = list(1, 2))), list(a = 1.5)) - expect_equal(param_set$convert_inner_tune_tokens(), list(a = 1.3)) + expect_equal(param_set$convert_internal_tune_tokens(), list(a = 1.3)) - # full + inner - param_set$set_values(a = to_tune(inner = TRUE, aggr = function(x) 1.5)) - expect_equal(param_set$convert_inner_tune_tokens(), list(a = 2)) + # full + internal + param_set$set_values(a = to_tune(internal = TRUE, aggr = function(x) 1.5)) + expect_equal(param_set$convert_internal_tune_tokens(), list(a = 2)) - # domain + inner + # domain + internal expect_error( - param_set$set_values(a = to_tune(p_dbl(1.21, 1.22), aggr = function(x) 1.5, inner = TRUE)), + param_set$set_values(a = to_tune(p_dbl(1.21, 1.22), aggr = function(x) 1.5, internal = TRUE)), "specify lower and upper" ) ## with default aggregation function - # param set + inner - param_set = ps(a = p_int(lower = 1, upper = 10000, tags = "inner_tuning", in_tune_fn = function(domain, param_set) domain$upper, + # param set + internal + param_set = ps(a = p_int(lower = 1, upper = 10000, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, aggr = function(x) max(unlist(x)))) # default aggregation function is used when not overwritten param_set$set_values( - a = to_tune(inner = TRUE) + a = to_tune(internal = TRUE) ) expect_equal(param_set$search_space()$aggr(list(a = list(1, 2, 3))), list(a = 3)) # can overwrite existing aggregation function param_set$set_values( - a = to_tune(inner = TRUE, aggr = function(x) -60) + a = to_tune(internal = TRUE, aggr = function(x) -60) ) expect_equal(param_set$search_space()$aggr(list(a = list(1, 2, 3))), list(a = -60)) }) From 13dd1543e5536ec12d1d1c239d2a6a6c1e9b6ae9 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 6 May 2024 11:24:31 +0200 Subject: [PATCH 18/34] fix bug --- R/ParamUty.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/ParamUty.R b/R/ParamUty.R index 0521cb88..094b6634 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -20,7 +20,7 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t } cargo = list(custom_check = custom_check, repr = repr) cargo$aggr = aggr - cargo$internal_tune_fn = in_tune_fn + cargo$in_tune_fn = in_tune_fn Domain(cls = "ParamUty", grouping = "ParamUty", cargo = cargo, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init) } From 5fdb2a4c8c98806f9699e0ae1911708ec23a7e2c Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 29 May 2024 16:13:49 +0200 Subject: [PATCH 19/34] support disabling internal tuning --- R/Domain.R | 6 ++++-- R/ParamDbl.R | 6 ++++-- R/ParamFct.R | 4 +++- R/ParamInt.R | 6 ++++-- R/ParamLgl.R | 4 +++- R/ParamSet.R | 17 +++++++++++++++-- R/ParamSetCollection.R | 13 +++++++++++++ R/ParamUty.R | 4 +++- man/Domain.Rd | 21 +++++++++++++++------ man/ParamSet.Rd | 22 ++++++++++++++++++++++ man/ParamSetCollection.Rd | 10 ++++++++++ tests/testthat/test_Param.R | 14 ++++++++++++++ tests/testthat/test_ParamSet.R | 19 ++++++++++++++++--- tests/testthat/test_ParamSetCollection.R | 10 ++++++++++ 14 files changed, 136 insertions(+), 20 deletions(-) diff --git a/R/Domain.R b/R/Domain.R index f53c7ebf..425b5222 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -69,6 +69,9 @@ #' @param in_tune_fn (`function(domain, param_set)`)\cr #' Function that converters a `Domain` object into a parameter value. #' Can onlye be given for parameters tagged with `"internal_tuning"`. +#' @param disable_in_tune (named `list()`)\cr +#' The parameter values that need to be set in the `ParamSet` to disable the internal tuning for the parameter. +#' For `XGBoost` this would e.g. be `list(early_stopping_rounds = NULL)`. #' #' @return A `Domain` object. #' @@ -128,7 +131,7 @@ #' #' param_set = ps( #' iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), -#' in_tune_fn = function(domain, param_set) domain$upper) +#' in_tune_fn = function(domain, param_set) domain$upper, disable_in_tune = list(other_param = FALSE)) #' ) #' param_set$set_values( #' iters = to_tune(upper = 100, internal = TRUE) @@ -244,7 +247,6 @@ print.Domain = function(x, ...) { if (!is.null(repr)) { print(repr) } else { - plural_rows = classes = class(x) if ("Domain" %in% classes) { domainidx = which("Domain" == classes)[[1]] diff --git a/R/ParamDbl.R b/R/ParamDbl.R index 9f3075c6..699450a9 100644 --- a/R/ParamDbl.R +++ b/R/ParamDbl.R @@ -1,7 +1,8 @@ #' @rdname Domain #' @export -p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL) { +p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) + assert_list(disable_in_tune, null.ok = TRUE, names = "unique") if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { @@ -26,7 +27,8 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ cargo = list() if (logscale) cargo$logscale = TRUE cargo$aggr = aggr - if (!is.null(in_tune_fn)) cargo$in_tune_fn = in_tune_fn + cargo$in_tune_fn = in_tune_fn + cargo$disable_in_tune = disable_in_tune Domain(cls = "ParamDbl", grouping = "ParamDbl", lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = "numeric", depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } diff --git a/R/ParamFct.R b/R/ParamFct.R index 422bcf61..e2863eef 100644 --- a/R/ParamFct.R +++ b/R/ParamFct.R @@ -1,9 +1,10 @@ #' @rdname Domain #' @export -p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL) { +p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) constargs = as.list(match.call()[-1]) levels = eval.parent(constargs$levels) + assert_list(disable_in_tune, null.ok = TRUE, names = "unique") if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { @@ -29,6 +30,7 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact # We escape '"' and '\' to '\"' and '\\', respectively. cargo = c(aggr = aggr, in_tune_fn = in_tune_fn) cargo = if (length(cargo)) cargo + cargo$disable_in_tune = disable_in_tune grouping = str_collapse(gsub("([\\\\\"])", "\\\\\\1", sort(real_levels)), quote = '"', sep = ",") Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, cargo = cargo) } diff --git a/R/ParamInt.R b/R/ParamInt.R index 91f1c57b..2c9504ab 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -1,9 +1,10 @@ #' @rdname Domain #' @export -p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL) { +p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) assert_number(tolerance, lower = 0, upper = 0.5) + assert_list(disable_in_tune, null.ok = TRUE, names = "unique") if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { @@ -35,7 +36,8 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_ cargo = list() if (logscale) cargo$logscale = TRUE cargo$aggr = aggr - if (!is.null(in_tune_fn)) cargo$in_tune_fn = in_tune_fn + cargo$in_tune_fn = in_tune_fn + cargo$disable_in_tune = disable_in_tune Domain(cls = cls, grouping = cls, lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = storage_type, diff --git a/R/ParamLgl.R b/R/ParamLgl.R index 1a6330af..3a53fe8c 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -1,7 +1,8 @@ #' @rdname Domain #' @export -p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL) { +p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { assert_function(aggr, null.ok = TRUE, nargs = 1L) + assert_list(disable_in_tune, null.ok = TRUE, names = "unique") if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { @@ -9,6 +10,7 @@ p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), de } cargo = c(aggr = aggr, in_tune_fn = in_tune_fn) + cargo$disable_in_tune = disable_in_tune cargo = if (length(cargo)) cargo Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = cargo) diff --git a/R/ParamSet.R b/R/ParamSet.R index 9cf8d0c4..0f2cd4db 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -270,8 +270,8 @@ ParamSet = R6Class("ParamSet", aggr = function(x) { assert_list(x, types = "list") aggrs = private$.params[map_lgl(get("cargo"), function(cargo) is.function(cargo$aggr)), list(id = get("id"), aggr = map(get("cargo"), "aggr"))] - assert_permutation(names(x), aggrs$id) - if (!nrow(aggrs)) { + assert_subset(names(x), aggrs$id) + if (!length(x)) { return(named_list()) } imap(x, function(value, .id) { @@ -282,6 +282,19 @@ ParamSet = R6Class("ParamSet", }) }, + #' @description + #' + #' Get the parameter values that disable internal tuning for those parameters passed as `ids`. + #' + #' @param ids (`character()`)\cr + #' The ids of the parameters for which to disable internal tuning. + #' @return (named `list()`) + disable_internal_tuning = function(ids) { + assert_subset(ids, self$ids()) + pvs = Reduce(c, map(private$.params[ids, "cargo", on = "id"][[1]], "disable_in_tune")) + self$set_values(.values = pvs) + }, + #' @description #' Convert all `InternalTuneToken`s to parameter values as is defined by their `in_tune_fn`. #' diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 2d604245..9ef1d501 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -148,6 +148,19 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, entry = if (n == "") length(private$.sets) + 1 else n private$.sets[[n]] = p invisible(self) + }, + disable_internal_tuning = function(ids) { + assert_subset(ids, self$ids()) + + pvs = Reduce(c, map(ids, function(id_) { + info = private$.translation[id_, c("original_id", "owner_name"), on = "id"] + xs = get_private(private$.sets[[info$owner_name]])$.params[info$original_id, "cargo", on = "id"][[1L]][[1]]$disable_in_tune + + if (info$owner_name == "" || is.null(xs)) return(xs) + + set_names(xs, paste0(info$owner_name, ".", names(xs))) + })) %??% named_list() + self$set_values(.values = pvs) } ), diff --git a/R/ParamUty.R b/R/ParamUty.R index 094b6634..01a418ce 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -1,8 +1,9 @@ #' @rdname Domain #' @export -p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL, in_tune_fn = NULL) { +p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { assert_function(custom_check, null.ok = TRUE) + assert_list(disable_in_tune, null.ok = TRUE, names = "unique") if ("internal_tuning" %in% tags) { assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) } else { @@ -21,6 +22,7 @@ p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, t cargo = list(custom_check = custom_check, repr = repr) cargo$aggr = aggr cargo$in_tune_fn = in_tune_fn + cargo$disable_in_tune = disable_in_tune Domain(cls = "ParamUty", grouping = "ParamUty", cargo = cargo, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "list", depends_expr = substitute(depends), init = init) } diff --git a/man/Domain.Rd b/man/Domain.Rd index c446e1d4..2f716e53 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -22,7 +22,8 @@ p_dbl( logscale = FALSE, init, aggr = NULL, - in_tune_fn = NULL + in_tune_fn = NULL, + disable_in_tune = NULL ) p_fct( @@ -34,7 +35,8 @@ p_fct( trafo = NULL, init, aggr = NULL, - in_tune_fn = NULL + in_tune_fn = NULL, + disable_in_tune = NULL ) p_int( @@ -49,7 +51,8 @@ p_int( logscale = FALSE, init, aggr = NULL, - in_tune_fn = NULL + in_tune_fn = NULL, + disable_in_tune = NULL ) p_lgl( @@ -60,7 +63,8 @@ p_lgl( trafo = NULL, init, aggr = NULL, - in_tune_fn = NULL + in_tune_fn = NULL, + disable_in_tune = NULL ) p_uty( @@ -73,7 +77,8 @@ p_uty( repr = substitute(default), init, aggr = NULL, - in_tune_fn = NULL + in_tune_fn = NULL, + disable_in_tune = NULL ) } \arguments{ @@ -157,6 +162,10 @@ Function with one argument, which is a list of parameter values and that returns Function that converters a \code{Domain} object into a parameter value. Can onlye be given for parameters tagged with \code{"internal_tuning"}.} +\item{disable_in_tune}{(named \code{list()})\cr +The parameter values that need to be set in the \code{ParamSet} to disable the internal tuning for the parameter. +For \code{XGBoost} this would e.g. be \code{list(early_stopping_rounds = NULL)}.} + \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr Allowed categorical values of the parameter. If this is not a \code{character}, then a \code{trafo} is generated that converts the names (if not given: \code{as.character()} of the values) of the \code{levels} argument to the values. @@ -245,7 +254,7 @@ print(grid$transpose()) param_set = ps( iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), - in_tune_fn = function(domain, param_set) domain$upper) + in_tune_fn = function(domain, param_set) domain$upper, disable_in_tune = list(other_param = FALSE)) ) param_set$set_values( iters = to_tune(upper = 100, internal = TRUE) diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index fea27ca5..d23ecb2a 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -173,6 +173,7 @@ Named with param IDs.} \item \href{#method-ParamSet-set_values}{\code{ParamSet$set_values()}} \item \href{#method-ParamSet-trafo}{\code{ParamSet$trafo()}} \item \href{#method-ParamSet-aggr}{\code{ParamSet$aggr()}} +\item \href{#method-ParamSet-disable_internal_tuning}{\code{ParamSet$disable_internal_tuning()}} \item \href{#method-ParamSet-convert_internal_tune_tokens}{\code{ParamSet$convert_internal_tune_tokens()}} \item \href{#method-ParamSet-test_constraint}{\code{ParamSet$test_constraint()}} \item \href{#method-ParamSet-test_constraint_dt}{\code{ParamSet$test_constraint_dt()}} @@ -364,6 +365,27 @@ The aggregation function is selected based on the parameter.} } } \if{html}{\out{
    }} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-disable_internal_tuning}{}}} +\subsection{Method \code{disable_internal_tuning()}}{ +Get the parameter values that disable internal tuning for those parameters passed as \code{ids}. +\subsection{Usage}{ +\if{html}{\out{
    }}\preformatted{ParamSet$disable_internal_tuning(ids)}\if{html}{\out{
    }} +} + +\subsection{Arguments}{ +\if{html}{\out{
    }} +\describe{ +\item{\code{ids}}{(\code{character()})\cr +The ids of the parameters for which to disable internal tuning.} +} +\if{html}{\out{
    }} +} +\subsection{Returns}{ +(named \code{list()}) +} +} +\if{html}{\out{
    }} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ParamSet-convert_internal_tune_tokens}{}}} \subsection{Method \code{convert_internal_tune_tokens()}}{ diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd index 7cc2f18e..5e2cb6f6 100644 --- a/man/ParamSetCollection.Rd +++ b/man/ParamSetCollection.Rd @@ -72,6 +72,7 @@ This field provides direct references to the \code{\link{ParamSet}} objects.} \itemize{ \item \href{#method-ParamSetCollection-new}{\code{ParamSetCollection$new()}} \item \href{#method-ParamSetCollection-add}{\code{ParamSetCollection$add()}} +\item \href{#method-ParamSetCollection-disable_internal_tuning}{\code{ParamSetCollection$disable_internal_tuning()}} \item \href{#method-ParamSetCollection-clone}{\code{ParamSetCollection$clone()}} } } @@ -155,6 +156,15 @@ Whether to add tags of the form \code{"param_"} to each parameter with } \if{html}{\out{}} } +} +\if{html}{\out{
    }} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSetCollection-disable_internal_tuning}{}}} +\subsection{Method \code{disable_internal_tuning()}}{ +\subsection{Usage}{ +\if{html}{\out{
    }}\preformatted{ParamSetCollection$disable_internal_tuning(ids)}\if{html}{\out{
    }} +} + } \if{html}{\out{
    }} \if{html}{\out{}} diff --git a/tests/testthat/test_Param.R b/tests/testthat/test_Param.R index 68bcf90d..24c1d1f0 100644 --- a/tests/testthat/test_Param.R +++ b/tests/testthat/test_Param.R @@ -54,3 +54,17 @@ test_that("special_vals work for all Param subclasses", { test_that("we cannot create Params with non-strict R names", { expect_error(ParamInt$new(id = "$foo"), "does not comply") }) + +test_that("disable_in_tune works for all Param subclassedisable_in_tune$s", { + itfn = function(domain, param_set) 1L + expect_equal( + p_uty(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) + expect_equal( + p_lgl(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) + expect_equal( + p_int(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) + expect_equal( + p_lgl(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) + expect_equal( + p_uty(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) +}) diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index 5892db6a..f4503fe7 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -447,14 +447,14 @@ test_that("aggr", { expect_error(param_set$aggr(1), "list") expect_error(param_set$aggr(list(1)), "list") - expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list())), "permutation") + expect_error(param_set$aggr(list(y = list())), "subset") expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "but there are no") }) test_that("convert_internal_tune_tokens", { param_set = ps( a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, - aggr = function(x) round(mean(unlist(x)))) + aggr = function(x) round(mean(unlist(x))), disable_in_tune = list(a = 1)) ) param_set$set_values(a = to_tune(internal = TRUE)) expect_identical(param_set$convert_internal_tune_tokens(), list(a = 100)) @@ -468,7 +468,7 @@ test_that("convert_internal_tune_tokens", { test_that("get_values works with internal_tune", { param_set = ps( a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, - aggr = function(x) round(mean(unlist(x)))) + aggr = function(x) round(mean(unlist(x))), disable_in_tune = list(a = 1)) ) param_set$set_values(a = to_tune(internal = TRUE)) expect_list(param_set$get_values(type = "with_internal"), len = 1L) @@ -488,3 +488,16 @@ test_that("InternalTuneToken is translated to 'internal_tuning' tag when creatin ss = param_set$search_space() expect_true("internal_tuning" %in% ss$tags$a) }) + +test_that("disable internal tuning", { + param_set = ps( + a = p_dbl(tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, disable_in_tune = list(b = FALSE)), + b = p_lgl() + ) + + expect_equal(param_set$values$b, NULL) + param_set$disable_internal_tuning("a") + expect_equal(param_set$values$b, FALSE) + + expect_error(param_set$disable_internal_tuning("c")) +}) diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index 91325115..d2b5be33 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -236,3 +236,13 @@ test_that("set_id inference in values assignment works now", { expect_error(ParamSetCollection$new(list(a = pscol1, pstest)), "duplicated parameter.* a\\.c\\.paramc") }) + +test_that("disable internal tuning works", { + param_set = psc(prefix = ps( + a = p_dbl(tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, disable_in_tune = list(b = FALSE)), + b = p_lgl() + )) + + param_set$disable_internal_tuning("prefix.a") + expect_equal(param_set$values$prefix.b, FALSE) +}) From 776b318514e040d0215903e6bbe016a44d0ed721 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 29 May 2024 16:26:42 +0200 Subject: [PATCH 20/34] docs, edge cases --- R/ParamSet.R | 66 ++++++++++++++---------- R/ParamSetCollection.R | 9 +++- man/ParamSet.Rd | 4 +- man/ParamSetCollection.Rd | 12 +++++ tests/testthat/test_ParamSet.R | 1 + tests/testthat/test_ParamSetCollection.R | 3 ++ 6 files changed, 65 insertions(+), 30 deletions(-) diff --git a/R/ParamSet.R b/R/ParamSet.R index 0f2cd4db..0d6f98bc 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -92,7 +92,7 @@ ParamSet = R6Class("ParamSet", if (".requirements" %in% names(paramtbl)) { requirements = paramtbl$.requirements - private$.params = paramtbl # self$add_dep needs this + private$.params = paramtbl # self$add_dep needs this for (row in seq_len(nrow(paramtbl))) { for (req in requirements[[row]]) { invoke(self$add_dep, id = paramtbl$id[[row]], allow_dangling_dependencies = allow_dangling_dependencies, @@ -107,7 +107,7 @@ ParamSet = R6Class("ParamSet", setindexv(paramtbl, c("id", "cls", "grouping")) - private$.params = paramtbl # I am 99% sure this is not necessary, but maybe set() creates a copy when deleting too many cols? + private$.params = paramtbl # I am 99% sure this is not necessary, but maybe set() creates a copy when deleting too many cols? if (!is.null(initvalues)) self$values = initvalues }, @@ -284,14 +284,14 @@ ParamSet = R6Class("ParamSet", #' @description #' - #' Get the parameter values that disable internal tuning for those parameters passed as `ids`. + #' Set the parameter values so that internal tuning for the selected parameters is disabled. #' #' @param ids (`character()`)\cr #' The ids of the parameters for which to disable internal tuning. - #' @return (named `list()`) + #' @return `Self` disable_internal_tuning = function(ids) { - assert_subset(ids, self$ids()) - pvs = Reduce(c, map(private$.params[ids, "cargo", on = "id"][[1]], "disable_in_tune")) + assert_subset(ids, self$ids(tags = "internal_tuning")) + pvs = Reduce(c, map(private$.params[ids, "cargo", on = "id"][[1]], "disable_in_tune")) %??% named_list() self$set_values(.values = pvs) }, @@ -377,7 +377,9 @@ ParamSet = R6Class("ParamSet", private$get_tune_ps(xs) TRUE }, error = function(e) paste("tune token invalid:", conditionMessage(e))) - if (!isTRUE(tunecheck)) return(tunecheck) + if (!isTRUE(tunecheck)) { + return(tunecheck) + } } xs_internaltune = keep(xs, is, "InternalTuneToken") @@ -412,7 +414,9 @@ ParamSet = R6Class("ParamSet", ## if (length(required) > 0L) { ## return(sprintf("Missing required parameters: %s", str_collapse(required))) ## } - if (!self$test_constraint(xs, assert_value = FALSE)) return(sprintf("Constraint not fulfilled.")) + if (!self$test_constraint(xs, assert_value = FALSE)) { + return(sprintf("Constraint not fulfilled.")) + } return(self$check_dependencies(xs)) } @@ -427,29 +431,37 @@ ParamSet = R6Class("ParamSet", #' @return If successful `TRUE`, if not a string with an error message. check_dependencies = function(xs) { deps = self$deps - if (!nrow(deps)) return(TRUE) + if (!nrow(deps)) { + return(TRUE) + } params = private$.params ns = names(xs) errors = pmap(deps[id %in% ns], function(id, on, cond) { onval = xs[[on]] - if (inherits(xs[[id]], "TuneToken") || inherits(onval, "TuneToken")) return(NULL) + if (inherits(xs[[id]], "TuneToken") || inherits(onval, "TuneToken")) { + return(NULL) + } # we are ONLY ok if: # - if 'id' is there, then 'on' must be there, and cond must be true # - if 'id' is not there. but that is skipped (deps[id %in% ns] filter) - if (on %in% ns && condition_test(cond, onval)) return(NULL) + if (on %in% ns && condition_test(cond, onval)) { + return(NULL) + } msg = sprintf("%s: can only be set if the following condition is met '%s'.", id, condition_as_string(cond, on)) if (is.null(onval)) { msg = sprintf(paste("%s Instead the parameter value for '%s' is not set at all.", - "Try setting '%s' to a value that satisfies the condition"), msg, on, on) + "Try setting '%s' to a value that satisfies the condition"), msg, on, on) } else { msg = sprintf("%s Instead the current parameter value is: %s == %s", msg, on, as_short_string(onval)) } msg }) errors = unlist(errors) - if (!length(errors)) return(TRUE) + if (!length(errors)) { + return(TRUE) + } str_collapse(errors, sep = "\n") }, @@ -480,7 +492,7 @@ ParamSet = R6Class("ParamSet", #' Name of the checked object to print in error messages.\cr #' Defaults to the heuristic implemented in [vname][checkmate::vname]. #' @return If successful `xs` invisibly, if not an error message. - assert = function(xs, check_strict = TRUE, .var.name = vname(xs)) makeAssertion(xs, self$check(xs, check_strict = check_strict), .var.name, NULL), # nolint + assert = function(xs, check_strict = TRUE, .var.name = vname(xs)) makeAssertion(xs, self$check(xs, check_strict = check_strict), .var.name, NULL), # nolint #' @description #' \pkg{checkmate}-like check-function. Takes a [data.table::data.table] @@ -568,10 +580,9 @@ ParamSet = R6Class("ParamSet", paramrow[, `:=`( .tags = list(private$.tags[id, tag, nomatch = 0]), .trafo = private$.trafos[id, trafo], - .requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps + .requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps .init_given = id %in% names(vals), - .init = unname(vals[id])) - ] + .init = unname(vals[id]))] set_class(paramrow, c(paramrow$cls, "Domain", class(paramrow))) }, @@ -596,7 +607,7 @@ ParamSet = R6Class("ParamSet", pids_not_there = setdiff(parents, ids) if (length(pids_not_there) > 0L) { stopf(paste0("Subsetting so that dependencies on params exist which would be gone: %s.", - "\nIf you still want to subset, set allow_dangling_dependencies to TRUE."), str_collapse(pids_not_there)) + "\nIf you still want to subset, set allow_dangling_dependencies to TRUE."), str_collapse(pids_not_there)) } } result = ParamSet$new() @@ -652,7 +663,7 @@ ParamSet = R6Class("ParamSet", assert_list(values) assert_names(names(values), subset.of = self$ids()) pars = private$get_tune_ps(values) - on = NULL # pacify static code check + on = NULL # pacify static code check dangling_deps = pars$deps[!pars$ids(), on = "on"] if (nrow(dangling_deps)) { stopf("Dangling dependencies not allowed: Dependencies on %s dangling.", str_collapse(dangling_deps$on)) @@ -677,7 +688,7 @@ ParamSet = R6Class("ParamSet", stopf("A param cannot depend on itself!") } - if (on %in% ids) { # not necessarily true when allow_dangling_dependencies + if (on %in% ids) { # not necessarily true when allow_dangling_dependencies feasible_on_values = map_lgl(cond$rhs, function(x) domain_test(self$get_domain(on), list(x))) if (any(!feasible_on_values)) { stopf("Condition has infeasible values for %s: %s", on, str_collapse(cond$rhs[!feasible_on_values])) @@ -745,7 +756,7 @@ ParamSet = R6Class("ParamSet", } if (length(xs) == 0L) { xs = named_list() - } else if (self$assert_values) { # this only makes sense when we have asserts on + } else if (self$assert_values) { # this only makes sense when we have asserts on # convert all integer params really to storage type int, move doubles to within bounds etc. # solves issue #293, #317 nontt = discard(xs, inherits, "TuneToken") @@ -836,7 +847,7 @@ ParamSet = R6Class("ParamSet", assert_character(v$on, any.missing = FALSE) assert_list(v$cond, types = "Condition", any.missing = FALSE) } else { - v = data.table(id = character(0), on = character(0), cond = list()) # make sure we have the right columns + v = data.table(id = character(0), on = character(0), cond = list()) # make sure we have the right columns } private$.deps = v } @@ -943,7 +954,9 @@ ParamSet = R6Class("ParamSet", get_tune_ps = function(values) { values = keep(values, inherits, "TuneToken") - if (!length(values)) return(ParamSet$new()) + if (!length(values)) { + return(ParamSet$new()) + } params = map(names(values), function(pn) { domain = private$.params[pn, on = "id"] set_class(domain, c(domain$cls, "Domain", class(domain))) @@ -953,13 +966,13 @@ ParamSet = R6Class("ParamSet", # package-internal S3 fails if we don't call the function indirectly here partsets = pmap(list(values, params), function(...) tunetoken_to_ps(...)) - pars = ps_union(partsets) # partsets does not have names here, wihch is what we want. + pars = ps_union(partsets) # partsets does not have names here, wihch is what we want. names(partsets) = names(values) idmapping = map(partsets, function(x) x$ids()) # only add the dependencies that are also in the tuning PS - on = id = NULL # pacify static code check + on = id = NULL # pacify static code check pmap(self$deps[id %in% names(idmapping) & on %in% names(partsets), c("on", "id", "cond")], function(on, id, cond) { onpar = partsets[[on]] if (onpar$has_trafo || !identical(onpar$ids(), on)) { @@ -1029,7 +1042,7 @@ rd_info.ParamSet = function(obj, descriptions = character(), ...) { # nolint is_default = map_lgl(params$default, inherits, "NoDefault") is_uty = params$storage_type == "list" set(params, i = which(is_uty & !is_default), j = "default", - value = map(cargo[!is_default & is_uty], function(x) x$repr)) + value = map(cargo[!is_default & is_uty], function(x) x$repr)) set(params, i = which(is_uty), j = "storage_type", value = list("untyped")) set(params, i = which(is_default), j = "default", value = list("-")) @@ -1047,4 +1060,3 @@ rd_info.ParamSet = function(obj, descriptions = character(), ...) { # nolint x = c("", knitr::kable(params, col.names = capitalize(names(params)))) paste(x, collapse = "\n") } - diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 9ef1d501..4702b3b1 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -149,8 +149,15 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, private$.sets[[n]] = p invisible(self) }, + #' @description + #' + #' Set the parameter values so that internal tuning for the selected parameters is disabled. + #' + #' @param ids (`character()`)\cr + #' The ids of the parameters for which to disable internal tuning. + #' @return `Self` disable_internal_tuning = function(ids) { - assert_subset(ids, self$ids()) + assert_subset(ids, self$ids(tags = "internal_tuning")) pvs = Reduce(c, map(ids, function(id_) { info = private$.translation[id_, c("original_id", "owner_name"), on = "id"] diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index d23ecb2a..6410ade3 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -368,7 +368,7 @@ The aggregation function is selected based on the parameter.} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ParamSet-disable_internal_tuning}{}}} \subsection{Method \code{disable_internal_tuning()}}{ -Get the parameter values that disable internal tuning for those parameters passed as \code{ids}. +Set the parameter values so that internal tuning for the selected parameters is disabled. \subsection{Usage}{ \if{html}{\out{
    }}\preformatted{ParamSet$disable_internal_tuning(ids)}\if{html}{\out{
    }} } @@ -382,7 +382,7 @@ The ids of the parameters for which to disable internal tuning.} \if{html}{\out{}} } \subsection{Returns}{ -(named \code{list()}) +\code{Self} } } \if{html}{\out{
    }} diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd index 5e2cb6f6..a372f08e 100644 --- a/man/ParamSetCollection.Rd +++ b/man/ParamSetCollection.Rd @@ -161,10 +161,22 @@ Whether to add tags of the form \code{"param_"} to each parameter with \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ParamSetCollection-disable_internal_tuning}{}}} \subsection{Method \code{disable_internal_tuning()}}{ +Set the parameter values so that internal tuning for the selected parameters is disabled. \subsection{Usage}{ \if{html}{\out{
    }}\preformatted{ParamSetCollection$disable_internal_tuning(ids)}\if{html}{\out{
    }} } +\subsection{Arguments}{ +\if{html}{\out{
    }} +\describe{ +\item{\code{ids}}{(\code{character()})\cr +The ids of the parameters for which to disable internal tuning.} +} +\if{html}{\out{
    }} +} +\subsection{Returns}{ +\code{Self} +} } \if{html}{\out{
    }} \if{html}{\out{}} diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index f4503fe7..cf821d9e 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -500,4 +500,5 @@ test_that("disable internal tuning", { expect_equal(param_set$values$b, FALSE) expect_error(param_set$disable_internal_tuning("c")) + expect_error(param_set$disable_internal_tuning("b")) }) diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index d2b5be33..426ec5f2 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -245,4 +245,7 @@ test_that("disable internal tuning works", { param_set$disable_internal_tuning("prefix.a") expect_equal(param_set$values$prefix.b, FALSE) + expect_error(param_set$disable_internal_tuning("b")) + + expect_equal(named_list(), psc(ps())$disable_internal_tuning(character(0))$values) }) From 773d692728fd00041b0bfb11c73bbde515f152da Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 30 May 2024 08:20:07 +0200 Subject: [PATCH 21/34] bugfix --- R/ParamSet.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/ParamSet.R b/R/ParamSet.R index 0d6f98bc..38735053 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -300,7 +300,7 @@ ParamSet = R6Class("ParamSet", #' #' @return (named `list()`) convert_internal_tune_tokens = function() { - internal_tune_tokens = self$get_values(type = "with_internal") + internal_tune_tokens = self$get_values(type = "with_internal", check_required = FALSE) internal_tune_ps = private$get_tune_ps(internal_tune_tokens) imap(internal_tune_ps$domains, function(token, .id) { From 1b391f7d10b5aa59a00317854604314275a145f3 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 30 May 2024 10:51:11 +0200 Subject: [PATCH 22/34] some more changes --- R/Domain.R | 5 +++-- R/ParamSet.R | 15 +++++++-------- man/Domain.Rd | 5 +++-- man/ParamSet.Rd | 21 +++++++++++++++------ man/ParamSetCollection.Rd | 2 +- tests/testthat/test_ParamSet.R | 9 +++------ tests/testthat/test_to_tune.R | 4 ++-- 7 files changed, 34 insertions(+), 27 deletions(-) diff --git a/R/Domain.R b/R/Domain.R index 425b5222..c1201c05 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -131,12 +131,13 @@ #' #' param_set = ps( #' iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), -#' in_tune_fn = function(domain, param_set) domain$upper, disable_in_tune = list(other_param = FALSE)) +#' in_tune_fn = function(domain, param_set) domain$upper, +#' disable_in_tune = list(other_param = FALSE)) #' ) #' param_set$set_values( #' iters = to_tune(upper = 100, internal = TRUE) #' ) -#' param_set$convert_internal_tune_tokens() +#' param_set$convert_internal_search_space(param_set$search_space()) #' param_set$aggr(list(iters = list(1, 2, 3))) #' #' @family ParamSet construction helpers diff --git a/R/ParamSet.R b/R/ParamSet.R index 38735053..4fb923cc 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -296,17 +296,16 @@ ParamSet = R6Class("ParamSet", }, #' @description - #' Convert all `InternalTuneToken`s to parameter values as is defined by their `in_tune_fn`. - #' + #' Convert all parameters from the search space to parameter values using the transformation given by + #' `in_tune_fn`. + #' @param search_space ([`ParamSet`])\cr + #' The internal search space. #' @return (named `list()`) - convert_internal_tune_tokens = function() { - internal_tune_tokens = self$get_values(type = "with_internal", check_required = FALSE) - internal_tune_ps = private$get_tune_ps(internal_tune_tokens) - - imap(internal_tune_ps$domains, function(token, .id) { + convert_internal_search_space = function(search_space) { + imap(search_space$domains, function(token, .id) { converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn if (!is.function(converter)) { - stopf("No converter exists for InternalTuneToken of parameters '%s'", .id) + stopf("No converter exists for parameter '%s'", .id) } converter(token) }) diff --git a/man/Domain.Rd b/man/Domain.Rd index 2f716e53..e63cc3ec 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -254,12 +254,13 @@ print(grid$transpose()) param_set = ps( iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), - in_tune_fn = function(domain, param_set) domain$upper, disable_in_tune = list(other_param = FALSE)) + in_tune_fn = function(domain, param_set) domain$upper, + disable_in_tune = list(other_param = FALSE)) ) param_set$set_values( iters = to_tune(upper = 100, internal = TRUE) ) -param_set$convert_internal_tune_tokens() +param_set$convert_internal_search_space(param_set$search_space()) param_set$aggr(list(iters = list(1, 2, 3))) } diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index 6410ade3..e4cc954f 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -174,7 +174,7 @@ Named with param IDs.} \item \href{#method-ParamSet-trafo}{\code{ParamSet$trafo()}} \item \href{#method-ParamSet-aggr}{\code{ParamSet$aggr()}} \item \href{#method-ParamSet-disable_internal_tuning}{\code{ParamSet$disable_internal_tuning()}} -\item \href{#method-ParamSet-convert_internal_tune_tokens}{\code{ParamSet$convert_internal_tune_tokens()}} +\item \href{#method-ParamSet-convert_internal_search_space}{\code{ParamSet$convert_internal_search_space()}} \item \href{#method-ParamSet-test_constraint}{\code{ParamSet$test_constraint()}} \item \href{#method-ParamSet-test_constraint_dt}{\code{ParamSet$test_constraint_dt()}} \item \href{#method-ParamSet-check}{\code{ParamSet$check()}} @@ -386,14 +386,23 @@ The ids of the parameters for which to disable internal tuning.} } } \if{html}{\out{
    }} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-ParamSet-convert_internal_tune_tokens}{}}} -\subsection{Method \code{convert_internal_tune_tokens()}}{ -Convert all \code{InternalTuneToken}s to parameter values as is defined by their \code{in_tune_fn}. +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-convert_internal_search_space}{}}} +\subsection{Method \code{convert_internal_search_space()}}{ +Convert all parameters from the search space to parameter values using the transformation given by +\code{in_tune_fn}. \subsection{Usage}{ -\if{html}{\out{
    }}\preformatted{ParamSet$convert_internal_tune_tokens()}\if{html}{\out{
    }} +\if{html}{\out{
    }}\preformatted{ParamSet$convert_internal_search_space(search_space)}\if{html}{\out{
    }} } +\subsection{Arguments}{ +\if{html}{\out{
    }} +\describe{ +\item{\code{search_space}}{(\code{\link{ParamSet}})\cr +The internal search space.} +} +\if{html}{\out{
    }} +} \subsection{Returns}{ (named \code{list()}) } diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd index a372f08e..a1999df6 100644 --- a/man/ParamSetCollection.Rd +++ b/man/ParamSetCollection.Rd @@ -86,7 +86,7 @@ This field provides direct references to the \code{\link{ParamSet}} objects.}
  • paradox::ParamSet$check()
  • paradox::ParamSet$check_dependencies()
  • paradox::ParamSet$check_dt()
  • -
  • paradox::ParamSet$convert_internal_tune_tokens()
  • +
  • paradox::ParamSet$convert_internal_search_space()
  • paradox::ParamSet$flatten()
  • paradox::ParamSet$format()
  • paradox::ParamSet$get_domain()
  • diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index cf821d9e..5759b8d5 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -451,18 +451,15 @@ test_that("aggr", { expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "but there are no") }) -test_that("convert_internal_tune_tokens", { +test_that("convert_internal_search_space", { param_set = ps( a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, aggr = function(x) round(mean(unlist(x))), disable_in_tune = list(a = 1)) ) param_set$set_values(a = to_tune(internal = TRUE)) - expect_identical(param_set$convert_internal_tune_tokens(), list(a = 100)) + expect_identical(param_set$convert_internal_search_space(param_set$search_space()), list(a = 100)) param_set$set_values(a = to_tune(internal = TRUE, upper = 99)) - expect_identical(param_set$convert_internal_tune_tokens(), list(a = 99)) - - param_set$set_values(a = to_tune(internal = FALSE)) - expect_identical(param_set$convert_internal_tune_tokens(), named_list()) + expect_identical(param_set$convert_internal_search_space(param_set$search_space()), list(a = 99)) }) test_that("get_values works with internal_tune", { diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index f9a24a2f..55ba8e95 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -430,11 +430,11 @@ test_that("internal and aggr", { # range + internal param_set$set_values(a = to_tune(lower = 1.2, upper = 1.3, aggr = function(x) 1.5)) expect_equal(param_set$search_space()$aggr(list(a = list(1, 2))), list(a = 1.5)) - expect_equal(param_set$convert_internal_tune_tokens(), list(a = 1.3)) + expect_equal(param_set$convert_internal_search_space(param_set$search_space()), list(a = 1.3)) # full + internal param_set$set_values(a = to_tune(internal = TRUE, aggr = function(x) 1.5)) - expect_equal(param_set$convert_internal_tune_tokens(), list(a = 2)) + expect_equal(param_set$convert_internal_search_space(param_set$search_space()), list(a = 2)) # domain + internal expect_error( From 31784b452ccd5a29bdeae87ac525e84dc6ded206 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 31 May 2024 17:26:11 +0200 Subject: [PATCH 23/34] wip [skip ci] --- R/Domain.R | 25 ++++++- R/ParamDbl.R | 7 -- R/ParamFct.R | 15 ++--- R/ParamInt.R | 10 --- R/ParamLgl.R | 11 +-- R/ParamSet.R | 10 +-- R/ParamSetCollection.R | 45 ++++++++++++- R/ParamUty.R | 7 -- R/to_tune.R | 2 +- tests/testthat/test_Param.R | 14 ---- tests/testthat/test_ParamSet.R | 21 +++--- tests/testthat/test_ParamSetCollection.R | 86 +++++++++++++++++++++++- tests/testthat/test_domain.R | 3 +- tests/testthat/test_to_tune.R | 21 +++--- 14 files changed, 188 insertions(+), 89 deletions(-) diff --git a/R/Domain.R b/R/Domain.R index c1201c05..5183af66 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -66,9 +66,11 @@ #' @param aggr (`function`)\cr #' Default aggregation function for a parameter. Can only be given for parameters tagged with `"internal_tuning"`. #' Function with one argument, which is a list of parameter values and that returns the aggregated parameter value. -#' @param in_tune_fn (`function(domain, param_set)`)\cr +#' @param in_tune_fn (`function(domain, param_vals)`)\cr #' Function that converters a `Domain` object into a parameter value. #' Can onlye be given for parameters tagged with `"internal_tuning"`. +#' This function should also assert that the parameters required to enable internal tuning for the given `domain` are +#' set in `param_vals` (such as `early_stopping_rounds` for `XGBoost`). #' @param disable_in_tune (named `list()`)\cr #' The parameter values that need to be set in the `ParamSet` to disable the internal tuning for the parameter. #' For `XGBoost` this would e.g. be `list(early_stopping_rounds = NULL)`. @@ -131,14 +133,16 @@ #' #' param_set = ps( #' iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), -#' in_tune_fn = function(domain, param_set) domain$upper, +#' in_tune_fn = function(domain, param_vals) domain$upper, #' disable_in_tune = list(other_param = FALSE)) #' ) #' param_set$set_values( #' iters = to_tune(upper = 100, internal = TRUE) #' ) #' param_set$convert_internal_search_space(param_set$search_space()) -#' param_set$aggr(list(iters = list(1, 2, 3))) +#' param_set$aggr_internal_tuned_values( +#' list(iters = list(1, 2, 3)) +#' ) #' #' @family ParamSet construction helpers #' @name Domain @@ -159,6 +163,21 @@ Domain = function(cls, grouping, storage_type = "list", init) { + if ("internal_tuning" %in% tags) { + assert_true(!is.null(cargo$aggr), .var.name = "aggregation function exists") + } + assert_list(cargo$disable_in_tune, null.ok = TRUE, names = "unique") + assert_function(cargo$aggr, null.ok = TRUE) + assert_function(cargo$in_tune_fn, null.ok = TRUE) + if ((!is.null(cargo$in_tune_fn) || !is.null(cargo$disable_in_tune)) && "internal_tuning" %nin% tags) { + # we cannot check the reverse, as parameters in the search space can be tagged with 'internal_tuning' + # and not provide in_tune_fn or disable_in_tune + stopf("Arguments in_tune_fn and disable_in_tune require the tag 'internal_tuning' to be present.") + } + if ((is.null(cargo$in_tune_fn) + is.null(cargo$disable_in_tune)) == 1) { + stopf("Arguments in_tune_fn and disable_tune_fn must both be present") + } + assert_string(cls) assert_string(grouping) assert_number(lower, na.ok = TRUE) diff --git a/R/ParamDbl.R b/R/ParamDbl.R index 699450a9..6b355bf5 100644 --- a/R/ParamDbl.R +++ b/R/ParamDbl.R @@ -1,13 +1,6 @@ #' @rdname Domain #' @export p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { - assert_function(aggr, null.ok = TRUE, nargs = 1L) - assert_list(disable_in_tune, null.ok = TRUE, names = "unique") - if ("internal_tuning" %in% tags) { - assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) - } else { - assert_true(is.null(in_tune_fn)) - } assert_number(tolerance, lower = 0) assert_number(lower) assert_number(upper) diff --git a/R/ParamFct.R b/R/ParamFct.R index e2863eef..3adc69c7 100644 --- a/R/ParamFct.R +++ b/R/ParamFct.R @@ -4,12 +4,6 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact assert_function(aggr, null.ok = TRUE, nargs = 1L) constargs = as.list(match.call()[-1]) levels = eval.parent(constargs$levels) - assert_list(disable_in_tune, null.ok = TRUE, names = "unique") - if ("internal_tuning" %in% tags) { - assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) - } else { - assert_true(is.null(in_tune_fn)) - } if (!is.character(levels)) { # if the "levels" argument is not a character vector, then # we add a trafo. @@ -28,11 +22,14 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact } # group p_fct by levels, so the group can be checked in a vectorized fashion. # We escape '"' and '\' to '\"' and '\\', respectively. - cargo = c(aggr = aggr, in_tune_fn = in_tune_fn) - cargo = if (length(cargo)) cargo + cargo = list() cargo$disable_in_tune = disable_in_tune + cargo$aggr = aggr + cargo$in_tune_fn = in_tune_fn grouping = str_collapse(gsub("([\\\\\"])", "\\\\\\1", sort(real_levels)), quote = '"', sep = ",") - Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init, cargo = cargo) + Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, + default = default, tags = tags, trafo = trafo, storage_type = "character", + depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } #' @export diff --git a/R/ParamInt.R b/R/ParamInt.R index 2c9504ab..b0c644cb 100644 --- a/R/ParamInt.R +++ b/R/ParamInt.R @@ -2,17 +2,7 @@ #' @rdname Domain #' @export p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { - assert_function(aggr, null.ok = TRUE, nargs = 1L) assert_number(tolerance, lower = 0, upper = 0.5) - assert_list(disable_in_tune, null.ok = TRUE, names = "unique") - if ("internal_tuning" %in% tags) { - assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) - } else { - assert_true(is.null(in_tune_fn)) - } - if ("internal_tuning" %nin% tags && !is.null(in_tune_fn)) { - stopf("Cannot only provide 'in_tune_fn' when parameter is tagged with 'internal_tuning'") - } # assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower) if (!isTRUE(is.infinite(upper))) assert_int(upper, tol = 1e-300) else assert_number(upper) diff --git a/R/ParamLgl.R b/R/ParamLgl.R index 3a53fe8c..ca555271 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -1,19 +1,10 @@ #' @rdname Domain #' @export p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { - assert_function(aggr, null.ok = TRUE, nargs = 1L) - assert_list(disable_in_tune, null.ok = TRUE, names = "unique") - if ("internal_tuning" %in% tags) { - assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) - } else { - assert_true(is.null(in_tune_fn)) - } - cargo = c(aggr = aggr, in_tune_fn = in_tune_fn) cargo$disable_in_tune = disable_in_tune - cargo = if (length(cargo)) cargo Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default, - tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = cargo) + tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) } #' @export diff --git a/R/ParamSet.R b/R/ParamSet.R index 4fb923cc..2fb83fc4 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -267,7 +267,7 @@ ParamSet = R6Class("ParamSet", #' The aggregation function is selected based on the parameter. #' #' @return (named `list()`) - aggr = function(x) { + aggr_internal_tuned_values = function(x) { assert_list(x, types = "list") aggrs = private$.params[map_lgl(get("cargo"), function(cargo) is.function(cargo$aggr)), list(id = get("id"), aggr = map(get("cargo"), "aggr"))] assert_subset(names(x), aggrs$id) @@ -302,12 +302,15 @@ ParamSet = R6Class("ParamSet", #' The internal search space. #' @return (named `list()`) convert_internal_search_space = function(search_space) { + assert_class(search_space, "ParamSet") + param_vals = self$values + imap(search_space$domains, function(token, .id) { converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn if (!is.function(converter)) { stopf("No converter exists for parameter '%s'", .id) } - converter(token) + converter(token, param_vals) }) }, @@ -386,9 +389,6 @@ ParamSet = R6Class("ParamSet", if ("internal_tuning" %nin% self$tags[[pid]]) { stopf("Trying to assign InternalTuneToken to parameter '%s' which is not tagged with 'internal_tuning'.", pid) } - if (is.null(xs[[pid]]$content$aggr) && is.null(private$.params[pid, "cargo", on = "id"][[1L]][[1L]]$aggr)) { - stopf("Trying to set parameter '%s' to InternalTuneToken, but no aggregation function is available.", pid) - } }) diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 4702b3b1..76f06279 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -161,13 +161,56 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, pvs = Reduce(c, map(ids, function(id_) { info = private$.translation[id_, c("original_id", "owner_name"), on = "id"] - xs = get_private(private$.sets[[info$owner_name]])$.params[info$original_id, "cargo", on = "id"][[1L]][[1]]$disable_in_tune + xs = get_private(private$.sets[[info$owner_name]])$.params[ + info$original_id, "cargo", on = "id"][[1L]][[1]]$disable_in_tune if (info$owner_name == "" || is.null(xs)) return(xs) set_names(xs, paste0(info$owner_name, ".", names(xs))) })) %??% named_list() self$set_values(.values = pvs) + }, + + #' @description + #' Convert all parameters from the search space to parameter values using the transformation given by + #' `in_tune_fn`. + #' @param search_space ([`ParamSet`])\cr + #' The internal search space. + #' @return (named `list()`) + convert_internal_search_space = function(search_space) { + assert_class(search_space, "ParamSet") + # call it on the subsets, but pass the parameter with corrected names and add the prefixes afterwards again + }, + + #' @description + #' Create a `ParamSet` from this `ParamSetCollection`. + flatten = function() { + flatps = super$flatten() + + flatps$.__enclos_env__$private$.params[, let( + cargo = pmap(list(cargo = cargo, id_ = id), function(cargo, id_) { + if (is.null(cargo$disable_in_tune) || !length(cargo$disable_in_tune)) return(cargo) + + set_id = private$.translation[list(id_), "owner_name", on = "id"][[1L]] + if (set_id == "") return(cargo) + + disable_in_tune = cargo$disable_in_tune + cargo$in_tune_fn = crate(function(domain, param_vals) { + param_vals = set_named(param_vals, gsub(sprintf("^\\Q%s.\\E", set_id), "", names(param_vals))) + disabled_vals = disable_in_tune(param_vals) + set_names(disabled_vals, paste0(set_id, ".", names(disabled_vals))) + }, disable_in_tune, set_id) + + cargo$disable_in_tune = set_names( + cargo$disable_in_tune, + paste0(set_id, ".", names(cargo$disable_in_tune)) + ) + + cargo + }) + )] + + flatps } ), diff --git a/R/ParamUty.R b/R/ParamUty.R index 01a418ce..5acd8fad 100644 --- a/R/ParamUty.R +++ b/R/ParamUty.R @@ -3,13 +3,6 @@ #' @export p_uty = function(custom_check = NULL, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, repr = substitute(default), init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { assert_function(custom_check, null.ok = TRUE) - assert_list(disable_in_tune, null.ok = TRUE, names = "unique") - if ("internal_tuning" %in% tags) { - assert_function(in_tune_fn, null.ok = FALSE, args = c("domain", "param_set"), nargs = 2L) - } else { - assert_true(is.null(in_tune_fn)) - } - assert_function(aggr, null.ok = TRUE, nargs = 1L) if (!is.null(custom_check)) { custom_check_result = custom_check(1) assert(check_true(custom_check_result), check_string(custom_check_result), .var.name = "The result of 'custom_check()'") diff --git a/R/to_tune.R b/R/to_tune.R index ace4023e..03ef7927 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -265,7 +265,7 @@ tunetoken_to_ps.InternalTuneToken = function(tt, param, ...) { if (is.null(aggr)) { stopf("%s must specify a aggregation function for parameter %s", tt$call, param$id) } - tunetoken_to_ps.RangeTuneToken(tt = tt, param = param, in_tune_fn = param$cargo[[1L]]$in_tune_fn, tags = "internal_tuning", + tunetoken_to_ps.RangeTuneToken(tt = tt, param = param, tags = "internal_tuning", aggr = aggr) } diff --git a/tests/testthat/test_Param.R b/tests/testthat/test_Param.R index 24c1d1f0..68bcf90d 100644 --- a/tests/testthat/test_Param.R +++ b/tests/testthat/test_Param.R @@ -54,17 +54,3 @@ test_that("special_vals work for all Param subclasses", { test_that("we cannot create Params with non-strict R names", { expect_error(ParamInt$new(id = "$foo"), "does not comply") }) - -test_that("disable_in_tune works for all Param subclassedisable_in_tune$s", { - itfn = function(domain, param_set) 1L - expect_equal( - p_uty(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) - expect_equal( - p_lgl(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) - expect_equal( - p_int(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) - expect_equal( - p_lgl(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) - expect_equal( - p_uty(tags = "internal_tuning", in_tune_fn = itfn, disable_in_tune = list(a = 1))$cargo[[1]]$disable_in_tune$a, 1) -}) diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index 5759b8d5..101e8de8 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -442,18 +442,20 @@ test_that("aggr", { ) expect_class(param_set, "ParamSet") - vals = param_set$aggr(list(a = list(1), b = list(1), c = list(1), d = list(1), e = list(1))) + vals = param_set$aggr_internal_tuned_values( + list(a = list(1), b = list(1), c = list(1), d = list(1), e = list(1))) expect_equal(vals, list(a = "a", b = "b", c = "c", d = "d", e = "e")) - expect_error(param_set$aggr(1), "list") - expect_error(param_set$aggr(list(1)), "list") - expect_error(param_set$aggr(list(y = list())), "subset") - expect_error(param_set$aggr(list(a = list(), b = list(), c = list(), d = list(), e = list())), "but there are no") + expect_error(param_set$aggr_internal_tuned_values(1), "list") + expect_error(param_set$aggr_internal_tuned_values(list(1)), "list") + expect_error(param_set$aggr_internal_tuned_values(list(y = list())), "subset") + expect_error(param_set$aggr_internal_tuned_values( + list(a = list(), b = list(), c = list(), d = list(), e = list())), "but there are no") }) test_that("convert_internal_search_space", { param_set = ps( - a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, + a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, aggr = function(x) round(mean(unlist(x))), disable_in_tune = list(a = 1)) ) param_set$set_values(a = to_tune(internal = TRUE)) @@ -464,7 +466,7 @@ test_that("convert_internal_search_space", { test_that("get_values works with internal_tune", { param_set = ps( - a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, + a = p_int(lower = 1, upper = 100, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, aggr = function(x) round(mean(unlist(x))), disable_in_tune = list(a = 1)) ) param_set$set_values(a = to_tune(internal = TRUE)) @@ -475,7 +477,8 @@ test_that("get_values works with internal_tune", { test_that("InternalTuneToken is translated to 'internal_tuning' tag when creating search space", { param_set = ps( - a = p_int(0, Inf, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, aggr = function(x) round(mean(unlist(x)))) + a = p_int(0, Inf, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, aggr = function(x) round(mean(unlist(x)), aggr = function(x) 1), + disable_in_tune = list()) ) param_set$set_values( @@ -488,7 +491,7 @@ test_that("InternalTuneToken is translated to 'internal_tuning' tag when creatin test_that("disable internal tuning", { param_set = ps( - a = p_dbl(tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, disable_in_tune = list(b = FALSE)), + a = p_dbl(tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, disable_in_tune = list(b = FALSE), aggr = function(x) 1), b = p_lgl() ) diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index 426ec5f2..55de1555 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -239,7 +239,7 @@ test_that("set_id inference in values assignment works now", { test_that("disable internal tuning works", { param_set = psc(prefix = ps( - a = p_dbl(tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, disable_in_tune = list(b = FALSE)), + a = p_dbl(aggr = function(x) 1, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, disable_in_tune = list(b = FALSE)), b = p_lgl() )) @@ -249,3 +249,87 @@ test_that("disable internal tuning works", { expect_equal(named_list(), psc(ps())$disable_internal_tuning(character(0))$values) }) + +test_that("convert_internal_search_space: depends on other parameter", { + param_set = psc(a = ps( + b = p_int(tags = "internal_tuning", in_tune_fn = function(domain, param_vals) param_vals$c * domain$upper, + aggr = function(x) 1, disable_in_tune = list()), + c = p_int() + )) + param_set$values$c = -1 + + search_space = ps( + b = p_int(upper = 1000, tags = "internal_tuning", aggr = function(x) 1) + ) + + expect_equal( + param_set$convert_internal_search_space(search_space)$b, + -1000 + ) +}) + +test_that("convert_internal_search_space: nested collections", { + param_set = psc(a = psc(b = ps(param = p_int( + in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", disable_in_tune = list(), aggr = function(x) 1 + )))) + + search_space = ps( + a.b.param = p_int(upper = 99, tags = "internal_tuning", aggr = function(x) 1) + ) + + expect_equal( + param_set$convert_internal_search_space(search_space), + list(a.b.param = 99) + ) +}) + +test_that("convert_internal_search_space: flattening", { + param_set = psc(a = psc(b = ps( + param = p_int( + in_tune_fn = function(domain, param_vals) domain$upper * param_vals$other_param, tags = "internal_tuning", + disable_in_tune = list(), aggr = function(x) 1), + other_param = p_int() + ))) + + param_set$values$a.b.other_param = -1 + + search_space = ps( + a.b.param = p_int(upper = 99, tags = "internal_tuning", aggr = function(x) 1) + ) + + expect_equal( + param_set$flatten()$convert_internal_search_space(search_space), + list(a.b.param = -99) + ) +}) + +test_that("disable_in_tune: single collection", { + param_set = psc(a = ps( + b = p_int( + in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", + disable_in_tune = list(c = TRUE), aggr = function(x) 1 + ), + c = p_lgl() + )) + + param_set$disable_internal_tuning("a.b") + expect_equal(param_set$values$a.c, TRUE) +}) + +test_that("disable_in_tune: nested collection", { + param_set = ps( + a = p_int( + in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", + disable_in_tune = list(), aggr = function(x) 1 + ) + ) +}) + +test_that("disable_in_tune: flattening", { + param_set = ps( + a = p_int( + in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", + disable_in_tune = list(), aggr = function(x) 1 + ) + ) +}) diff --git a/tests/testthat/test_domain.R b/tests/testthat/test_domain.R index 4620a0ea..02402c00 100644 --- a/tests/testthat/test_domain.R +++ b/tests/testthat/test_domain.R @@ -361,7 +361,8 @@ test_that("internal", { it1 = to_tune(aggr = function(x) min(unlist(x))) expect_equal(it1$content$aggr(list(1, 2)), 1) param_set = ps( - a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper) + a = p_dbl(1, 10, aggr = function(x) mean(unlist(x)), tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, + disable_in_tune = list()) ) param_set$set_values(a = to_tune(internal = TRUE, aggr = function(x) round(mean(unlist(x))))) expect_class(param_set$values$a, "InternalTuneToken") diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index 55ba8e95..9b797277 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -402,16 +402,15 @@ test_that("logscale in tunetoken", { test_that("internal and aggr", { # no default aggregation function - param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper)) - - # correct errors - expect_error(param_set$set_values(a = to_tune(internal = TRUE)), "aggregation") - expect_error(param_set$set_values(a = to_tune(internal = FALSE, aggr = function(x) 1))) + param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, + disable_in_tune = list(), aggr = function(x) round(mean(unlist(x)))) + ) # full tune token + internal expect_equal( - param_set$set_values(a = to_tune(aggr = function(x) -99))$search_space()$aggr(list(a = list(1, 2, 3))), + param_set$set_values(a = to_tune(aggr = function(x) -99))$search_space()$aggr_internal_tuned_values( + list(a = list(1, 2, 3))), list(a = -99) ) @@ -429,7 +428,7 @@ test_that("internal and aggr", { # range + internal param_set$set_values(a = to_tune(lower = 1.2, upper = 1.3, aggr = function(x) 1.5)) - expect_equal(param_set$search_space()$aggr(list(a = list(1, 2))), list(a = 1.5)) + expect_equal(param_set$search_space()$aggr_internal_tuned_values(list(a = list(1, 2))), list(a = 1.5)) expect_equal(param_set$convert_internal_search_space(param_set$search_space()), list(a = 1.3)) # full + internal @@ -445,18 +444,18 @@ test_that("internal and aggr", { ## with default aggregation function # param set + internal - param_set = ps(a = p_int(lower = 1, upper = 10000, tags = "internal_tuning", in_tune_fn = function(domain, param_set) domain$upper, - aggr = function(x) max(unlist(x)))) + param_set = ps(a = p_int(lower = 1, upper = 10000, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, + aggr = function(x) max(unlist(x)), disable_in_tune = list())) # default aggregation function is used when not overwritten param_set$set_values( a = to_tune(internal = TRUE) ) - expect_equal(param_set$search_space()$aggr(list(a = list(1, 2, 3))), list(a = 3)) + expect_equal(param_set$search_space()$aggr_internal_tuned_values(list(a = list(1, 2, 3))), list(a = 3)) # can overwrite existing aggregation function param_set$set_values( a = to_tune(internal = TRUE, aggr = function(x) -60) ) - expect_equal(param_set$search_space()$aggr(list(a = list(1, 2, 3))), list(a = -60)) + expect_equal(param_set$search_space()$aggr_internal_tuned_values(list(a = list(1, 2, 3))), list(a = -60)) }) From 84823355c0267d1b461783c63fe643e4827ff365 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 31 May 2024 18:02:34 +0200 Subject: [PATCH 24/34] more wip --- R/ParamSetCollection.R | 31 ++++++++++++++++-- tests/testthat/test_ParamSetCollection.R | 41 ++++++++++++++++++------ 2 files changed, 60 insertions(+), 12 deletions(-) diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 76f06279..cedf6bcd 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -171,6 +171,26 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, self$set_values(.values = pvs) }, + # #' @description + # #' Convert all parameters from the search space to parameter values using the transformation given by + # #' `in_tune_fn`. + # #' @param search_space ([`ParamSet`])\cr + # #' The internal search space. + # #' @return (named `list()`) + # convert_internal_search_space = function(search_space) { + # assert_class(search_space, "ParamSet") + + + + + # param_vals = self$values + + # Reduce(c, imap(private$.sets, function(set, prefix) { + + # set$convert_internal_search_space() + + # })) %??% named_list() + # }, #' @description #' Convert all parameters from the search space to parameter values using the transformation given by #' `in_tune_fn`. @@ -179,7 +199,14 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, #' @return (named `list()`) convert_internal_search_space = function(search_space) { assert_class(search_space, "ParamSet") - # call it on the subsets, but pass the parameter with corrected names and add the prefixes afterwards again + imap(search_space$domains, function(token, .id) { + converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn + if (!is.function(converter)) { + stopf("No converter exists for parameter '%s'", .id) + } + set_index = private$.translation[list(.id), "owner_ps_index", on = "id"][[1L]] + converter(token, private$.sets[[set_index]]$values) + }) }, #' @description @@ -196,7 +223,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, disable_in_tune = cargo$disable_in_tune cargo$in_tune_fn = crate(function(domain, param_vals) { - param_vals = set_named(param_vals, gsub(sprintf("^\\Q%s.\\E", set_id), "", names(param_vals))) + param_vals = set_names(param_vals, gsub(sprintf("^\\Q%s.\\E", set_id), "", names(param_vals))) disabled_vals = disable_in_tune(param_vals) set_names(disabled_vals, paste0(set_id, ".", names(disabled_vals))) }, disable_in_tune, set_id) diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index 55de1555..8d8a30ed 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -256,14 +256,15 @@ test_that("convert_internal_search_space: depends on other parameter", { aggr = function(x) 1, disable_in_tune = list()), c = p_int() )) - param_set$values$c = -1 + param_set$values$a.c = -1 search_space = ps( - b = p_int(upper = 1000, tags = "internal_tuning", aggr = function(x) 1) + a.b = p_int(upper = 1000, tags = "internal_tuning", aggr = function(x) 1) ) + browser() expect_equal( - param_set$convert_internal_search_space(search_space)$b, + param_set$convert_internal_search_space(search_space)$a.b, -1000 ) }) @@ -303,7 +304,7 @@ test_that("convert_internal_search_space: flattening", { ) }) -test_that("disable_in_tune: single collection", { +test_that("disable internal tuning: single collection", { param_set = psc(a = ps( b = p_int( in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", @@ -316,7 +317,7 @@ test_that("disable_in_tune: single collection", { expect_equal(param_set$values$a.c, TRUE) }) -test_that("disable_in_tune: nested collection", { +test_that("disable internal tuning: nested collection", { param_set = ps( a = p_int( in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", @@ -325,11 +326,31 @@ test_that("disable_in_tune: nested collection", { ) }) -test_that("disable_in_tune: flattening", { - param_set = ps( - a = p_int( +test_that("disable internal tuning: flattening", { + param_set = psc(a = ps( + b = p_int( in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", - disable_in_tune = list(), aggr = function(x) 1 - ) + disable_in_tune = list(c = 1), aggr = function(x) 1 + ), + c = p_int() + ))$flatten() + + expect_equal( + param_set$disable_internal_tuning("a.b")$values$a.c, + 1 + ) + + # now with no set id + param_set = psc(ps( + b = p_int( + in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", + disable_in_tune = list(c = 1), aggr = function(x) 1 + ), + c = p_int() + ))$flatten() + + expect_equal( + param_set$disable_internal_tuning("b")$values$c, + 1 ) }) From c65cba454595f330abe143f245c5ce28b31773e6 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 1 Jun 2024 07:16:02 +0200 Subject: [PATCH 25/34] hopefully fix final bug --- R/ParamSetCollection.R | 51 +- tests/testthat/test_ParamSetCollection.R | 566 +++++++++++------------ 2 files changed, 319 insertions(+), 298 deletions(-) diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index cedf6bcd..26995f08 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -214,25 +214,48 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, flatten = function() { flatps = super$flatten() + recurse_prefix = function(id_, param_set, prefix = "") { + info = get_private(param_set)$.translation[list(id_), c("owner_name", "owner_ps_index"), on = "id"] + prefix = if (info$owner_name == "") { + prefix + } else if (prefix == "") { + info$owner_name + } else { + paste0(prefix, ".", info$owner_name) + } + subset = get_private(param_set)$.sets[[info$owner_ps_index]] + if (!test_class(subset, "ParamSetCollection")) { + return(list(prefix = prefix, ids = subset$ids())) + } + if (prefix != "") { + id_ = gsub(sprintf("^\\Q%s.\\E", prefix), "", id_) + } + recurse_prefix(id_, get_private(param_set)$.sets[[info$owner_ps_index]], prefix) + } + flatps$.__enclos_env__$private$.params[, let( cargo = pmap(list(cargo = cargo, id_ = id), function(cargo, id_) { - if (is.null(cargo$disable_in_tune) || !length(cargo$disable_in_tune)) return(cargo) - - set_id = private$.translation[list(id_), "owner_name", on = "id"][[1L]] - if (set_id == "") return(cargo) + if (all(map_lgl(cargo[c("disable_in_tune", "in_tune_fn")], is.null))) return(cargo) - disable_in_tune = cargo$disable_in_tune - cargo$in_tune_fn = crate(function(domain, param_vals) { - param_vals = set_names(param_vals, gsub(sprintf("^\\Q%s.\\E", set_id), "", names(param_vals))) - disabled_vals = disable_in_tune(param_vals) - set_names(disabled_vals, paste0(set_id, ".", names(disabled_vals))) - }, disable_in_tune, set_id) + info = recurse_prefix(id_, self) + prefix = info$prefix + if (prefix == "") return(cargo) - cargo$disable_in_tune = set_names( - cargo$disable_in_tune, - paste0(set_id, ".", names(cargo$disable_in_tune)) - ) + in_tune_fn = cargo$in_tune_fn + set_ids = info$ids + cargo$in_tune_fn = crate(function(domain, param_vals) { + param_vals = param_vals[names(param_vals) %in% paste0(prefix, ".", set_ids)] + names(param_vals) = gsub(sprintf("^\\Q%s.\\E", prefix), "", names(param_vals)) + in_tune_fn(domain, param_vals) + }, in_tune_fn, prefix, set_ids) + + if (length(cargo$disable_in_tune)) { + cargo$disable_in_tune = set_names( + cargo$disable_in_tune, + paste0(prefix, ".", names(cargo$disable_in_tune)) + ) + } cargo }) )] diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index 8d8a30ed..909592ac 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -1,288 +1,287 @@ -context("ParamSetCollection") - -test_that("ParamSet basic stuff works", { - ps1 = th_paramset_dbl1() - ps2 = th_paramset_full() - ps3 = th_paramset_dbl1() - psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2, ps3)) - - ps1clone = ps1$clone(deep = TRUE) - ps2clone = ps2$clone(deep = TRUE) - - my_c = function(xs1, xs2, xs3) { - # littler helper to join to ps-result and prefix names - ns = c(paste0("s1.", names(xs1)), paste0("s2.", names(xs2)), names(xs3)) - set_names(c(xs1, xs2, xs3), ns) - } - - expect_class(psc, "ParamSetCollection") - expect_equal(psc$length, ps1$length + ps2$length + ps3$length) - # check that param internally in collection is constructed correctly - p = psc$params[2L] - p$id = "th_param_int" - - expect_equal(p, ps2$params[1L]) - expect_equal(psc$ids(), c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids())) - expect_equal(psc$lower, my_c(ps1$lower, ps2$lower, ps3$lower)) - d = as.data.table(psc) - expect_data_table(d, nrows = 6) - expect_false(psc$has_deps) - expect_false(psc$has_trafo) - - d = as.data.table(psc) - expect_equal(d$id, c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids())) - - expect_true(psc$check(list(s1.th_param_dbl = 1, s2.th_param_int = 2))) - expect_string(psc$check(list(th_param_int = 2)), fixed = "not avail") - expect_true(psc$check(list(th_param_dbl = 1))) - - d = generate_design_random(psc, n = 10L) - expect_data_table(d$data, nrows = 10, ncols = 6L) - - psflat = psc$flatten() - psflat$extra_trafo = function(x, param_set) { - x$s2.th_param_int = 99 # nolint - return(x) - } - expect_true(psflat$has_trafo) - d = generate_design_random(psflat, n = 10L) - expect_data_table(d$data, nrows = 10, ncols = 6L) - xs = d$transpose(trafo = TRUE) - for (i in 1:10) { - x = xs[[i]] - expect_list(x, len = 6) - expect_names(names(x), permutation.of = psc$ids()) - expect_equal(x$s2.th_param_int, 99) - } - - # ps1 and ps2 should not be changed - expect_equal(ps1, ps1clone) - expect_equal(ps2, ps2clone) - - expect_output(print(psc), "s1\\.th_param_dbl.*s2\\.th_param_int.*s2\\.th_param_dbl.*s2\\.th_param_fct.*s2\\.th_param_lgl.*th_param_dbl") # nolint - - # ps1 and ps2 should not be changed by printing - expect_equal(ps1, ps1clone) - expect_equal(ps2, ps2clone) - - # adding a set - ps4 = ParamSet_legacy$new(list(ParamDbl$new("x"))) - psc = psc$add(ps4, n = "s4") - expect_equal(psc$length, ps1$length + ps2$length + ps3$length + ps4$length) - expect_equal(psc$ids(), c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids(), paste0("s4.", ps4$ids()))) -}) - -test_that("some operations are not allowed", { - ps1 = th_paramset_dbl1() - ps2 = th_paramset_full() - psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2)) - - expect_error(psc$subset("foo"), "Must be a subset of") -}) - -test_that("deps", { - ps1 = ParamSet_legacy$new(list( - ParamFct$new("f", levels = c("a", "b")), - ParamDbl$new("d") - )) - ps1$add_dep("d", on = "f", CondEqual("a")) - - ps2 = ParamSet_legacy$new(list( - ParamFct$new("f", levels = c("a", "b")), - ParamDbl$new("d") - )) - - ps1clone = ps1$clone(deep = TRUE) - ps2clone = ps2$clone(deep = TRUE) - - psc = ParamSetCollection$new(list(ps1 = ps1, ps2 = ps2)) - d = psc$deps - expect_data_table(d, nrows = 1, ncols = 3) - expect_equal(d$id, c("ps1.d")) - - # check deps across sets - psc$add_dep("ps2.d", on = "ps1.f", CondEqual("a")) - expect_data_table(psc$deps, nrows = 2, ncols = 3) - expect_true(psc$check(list(ps1.f = "a", ps1.d = 0, ps2.d = 0))) - expect_string(psc$check(list(ps2.d = 0), check_strict = TRUE)) - - # ps1 and ps2 should not be changed - expect_equal(ps1clone, ps1) - expect_equal(ps2clone, ps2) -}) - -test_that("values", { - ps1 = ParamSet_legacy$new(list( - ParamFct$new("f", levels = c("a", "b")), - ParamDbl$new("d", lower = 1, upper = 8) - )) - ps2 = ParamSet_legacy$new(list( - ParamFct$new("f", levels = c("a", "b")), - ParamDbl$new("d", lower = 1, upper = 8) - )) - ps3 = ParamSet_legacy$new(list( - ParamDbl$new("x", lower = 1, upper = 8) - )) - ps4 = ParamSet_legacy$new(list( - ParamDbl$new("y", lower = 1, upper = 8) - )) - - ps1clone = ps1$clone(deep = TRUE) - ps2clone = ps2$clone(deep = TRUE) - - pcs = ParamSetCollection$new(list(foo = ps1, bar = ps2, ps3, ps4)) - expect_equal(pcs$values, named_list()) - ps2$values = list(d = 3) - expect_equal(pcs$values, list(bar.d = 3)) - pcs$values = list(foo.d = 8) - expect_equal(pcs$values, list(foo.d = 8)) - expect_equal(ps1$values, list(d = 8)) - expect_equal(ps2$values, named_list()) - pcs$values = list(x = 1) - expect_equal(pcs$values, list(x = 1)) - expect_equal(ps3$values, list(x = 1)) - - ps1clone$values$d = 8 - pcs$values = list(foo.d = 8) - ps2$values = list() - - # data table adds indexes at will and comparisons fail because of that, so we have to remove them here. - setindex(ps1clone$deps, NULL) - setindex(ps2clone$deps, NULL) - setindex(ps1$deps, NULL) - setindex(ps2$deps, NULL) - - expect_equal(ps1clone, ps1) - expect_equal(ps2clone, ps2) - - # resetting pcs values - pcs$values = list() - expect_list(pcs$values, len = 0) -}) - -test_that("empty collections", { - # no paramsets - psc = ParamSetCollection$new(list()) - expect_equal(psc$length, 0L) - expect_equal(psc$subspaces(), named_list()) - expect_equal(psc$ids(), character(0L)) - expect_data_table(as.data.table(psc), nrows = 0L) - - # 1 empty paramset - psc = ParamSetCollection$new(list(ParamSet_legacy$new())) - expect_equal(psc$length, 0L) - expect_equal(psc$subspaces(), named_list()) - expect_equal(psc$ids(), character(0L)) - expect_data_table(as.data.table(psc), nrows = 0L) -}) - - -test_that("no problems if we name the list of sets", { - ps = ParamSet_legacy$new(list(ParamDbl$new("test1"))) - psc = ParamSetCollection$new(list(paramset = ps)) - expect_equal(names(psc$subspaces()), "paramset.test1") -}) - -test_that("no warning in printer, see issue 208", { - ps = ParamSet_legacy$new(list(ParamDbl$new("test1"))) - - psc = ParamSetCollection$new(list(paramset = ps)) - psc$values = list(paramset.test1 = 1) - expect_warning(capture_output(print(ps)), NA) -}) - -test_that("collection allows state-change setting of paramvals, see issue 205", { - ps1 = ParamSet_legacy$new(list(ParamDbl$new("d1"))) - ps2 = ParamSet_legacy$new(list(ParamDbl$new("d2"))) - ps3 = ParamSet_legacy$new(list(ParamDbl$new("d3"))) - - psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2, ps3)) - expect_equal(psc$values, named_list()) - psc$values$s1.d1 = 1 # nolint - expect_equal(psc$values, list(s1.d1 = 1)) - psc$values$s2.d2 = 2 # nolint - expect_equal(psc$values, list(s1.d1 = 1, s2.d2 = 2)) - psc$values$d3 = 3 - expect_equal(psc$values, list(s1.d1 = 1, s2.d2 = 2, d3 = 3)) -}) - -test_that("set_id inference in values assignment works now", { - psa = ParamSet_legacy$new(list(ParamDbl$new("parama"))) - - psb = ParamSet_legacy$new(list(ParamDbl$new("paramb"))) - - psc = ParamSet_legacy$new(list(ParamDbl$new("paramc"))) +# context("ParamSetCollection") + +# test_that("ParamSet basic stuff works", { +# ps1 = th_paramset_dbl1() +# ps2 = th_paramset_full() +# ps3 = th_paramset_dbl1() +# psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2, ps3)) + +# ps1clone = ps1$clone(deep = TRUE) +# ps2clone = ps2$clone(deep = TRUE) + +# my_c = function(xs1, xs2, xs3) { +# # littler helper to join to ps-result and prefix names +# ns = c(paste0("s1.", names(xs1)), paste0("s2.", names(xs2)), names(xs3)) +# set_names(c(xs1, xs2, xs3), ns) +# } + +# expect_class(psc, "ParamSetCollection") +# expect_equal(psc$length, ps1$length + ps2$length + ps3$length) +# # check that param internally in collection is constructed correctly +# p = psc$params[2L] +# p$id = "th_param_int" + +# expect_equal(p, ps2$params[1L]) +# expect_equal(psc$ids(), c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids())) +# expect_equal(psc$lower, my_c(ps1$lower, ps2$lower, ps3$lower)) +# d = as.data.table(psc) +# expect_data_table(d, nrows = 6) +# expect_false(psc$has_deps) +# expect_false(psc$has_trafo) + +# d = as.data.table(psc) +# expect_equal(d$id, c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids())) + +# expect_true(psc$check(list(s1.th_param_dbl = 1, s2.th_param_int = 2))) +# expect_string(psc$check(list(th_param_int = 2)), fixed = "not avail") +# expect_true(psc$check(list(th_param_dbl = 1))) + +# d = generate_design_random(psc, n = 10L) +# expect_data_table(d$data, nrows = 10, ncols = 6L) + +# psflat = psc$flatten() +# psflat$extra_trafo = function(x, param_set) { +# x$s2.th_param_int = 99 # nolint +# return(x) +# } +# expect_true(psflat$has_trafo) +# d = generate_design_random(psflat, n = 10L) +# expect_data_table(d$data, nrows = 10, ncols = 6L) +# xs = d$transpose(trafo = TRUE) +# for (i in 1:10) { +# x = xs[[i]] +# expect_list(x, len = 6) +# expect_names(names(x), permutation.of = psc$ids()) +# expect_equal(x$s2.th_param_int, 99) +# } + +# # ps1 and ps2 should not be changed +# expect_equal(ps1, ps1clone) +# expect_equal(ps2, ps2clone) + +# expect_output(print(psc), "s1\\.th_param_dbl.*s2\\.th_param_int.*s2\\.th_param_dbl.*s2\\.th_param_fct.*s2\\.th_param_lgl.*th_param_dbl") # nolint + +# # ps1 and ps2 should not be changed by printing +# expect_equal(ps1, ps1clone) +# expect_equal(ps2, ps2clone) + +# # adding a set +# ps4 = ParamSet_legacy$new(list(ParamDbl$new("x"))) +# psc = psc$add(ps4, n = "s4") +# expect_equal(psc$length, ps1$length + ps2$length + ps3$length + ps4$length) +# expect_equal(psc$ids(), c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids(), paste0("s4.", ps4$ids()))) +# }) + +# test_that("some operations are not allowed", { +# ps1 = th_paramset_dbl1() +# ps2 = th_paramset_full() +# psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2)) + +# expect_error(psc$subset("foo"), "Must be a subset of") +# }) + +# test_that("deps", { +# ps1 = ParamSet_legacy$new(list( +# ParamFct$new("f", levels = c("a", "b")), +# ParamDbl$new("d") +# )) +# ps1$add_dep("d", on = "f", CondEqual("a")) + +# ps2 = ParamSet_legacy$new(list( +# ParamFct$new("f", levels = c("a", "b")), +# ParamDbl$new("d") +# )) + +# ps1clone = ps1$clone(deep = TRUE) +# ps2clone = ps2$clone(deep = TRUE) + +# psc = ParamSetCollection$new(list(ps1 = ps1, ps2 = ps2)) +# d = psc$deps +# expect_data_table(d, nrows = 1, ncols = 3) +# expect_equal(d$id, c("ps1.d")) + +# # check deps across sets +# psc$add_dep("ps2.d", on = "ps1.f", CondEqual("a")) +# expect_data_table(psc$deps, nrows = 2, ncols = 3) +# expect_true(psc$check(list(ps1.f = "a", ps1.d = 0, ps2.d = 0))) +# expect_string(psc$check(list(ps2.d = 0), check_strict = TRUE)) + +# # ps1 and ps2 should not be changed +# expect_equal(ps1clone, ps1) +# expect_equal(ps2clone, ps2) +# }) + +# test_that("values", { +# ps1 = ParamSet_legacy$new(list( +# ParamFct$new("f", levels = c("a", "b")), +# ParamDbl$new("d", lower = 1, upper = 8) +# )) +# ps2 = ParamSet_legacy$new(list( +# ParamFct$new("f", levels = c("a", "b")), +# ParamDbl$new("d", lower = 1, upper = 8) +# )) +# ps3 = ParamSet_legacy$new(list( +# ParamDbl$new("x", lower = 1, upper = 8) +# )) +# ps4 = ParamSet_legacy$new(list( +# ParamDbl$new("y", lower = 1, upper = 8) +# )) + +# ps1clone = ps1$clone(deep = TRUE) +# ps2clone = ps2$clone(deep = TRUE) + +# pcs = ParamSetCollection$new(list(foo = ps1, bar = ps2, ps3, ps4)) +# expect_equal(pcs$values, named_list()) +# ps2$values = list(d = 3) +# expect_equal(pcs$values, list(bar.d = 3)) +# pcs$values = list(foo.d = 8) +# expect_equal(pcs$values, list(foo.d = 8)) +# expect_equal(ps1$values, list(d = 8)) +# expect_equal(ps2$values, named_list()) +# pcs$values = list(x = 1) +# expect_equal(pcs$values, list(x = 1)) +# expect_equal(ps3$values, list(x = 1)) + +# ps1clone$values$d = 8 +# pcs$values = list(foo.d = 8) +# ps2$values = list() + +# # data table adds indexes at will and comparisons fail because of that, so we have to remove them here. +# setindex(ps1clone$deps, NULL) +# setindex(ps2clone$deps, NULL) +# setindex(ps1$deps, NULL) +# setindex(ps2$deps, NULL) + +# expect_equal(ps1clone, ps1) +# expect_equal(ps2clone, ps2) + +# # resetting pcs values +# pcs$values = list() +# expect_list(pcs$values, len = 0) +# }) + +# test_that("empty collections", { +# # no paramsets +# psc = ParamSetCollection$new(list()) +# expect_equal(psc$length, 0L) +# expect_equal(psc$subspaces(), named_list()) +# expect_equal(psc$ids(), character(0L)) +# expect_data_table(as.data.table(psc), nrows = 0L) + +# # 1 empty paramset +# psc = ParamSetCollection$new(list(ParamSet_legacy$new())) +# expect_equal(psc$length, 0L) +# expect_equal(psc$subspaces(), named_list()) +# expect_equal(psc$ids(), character(0L)) +# expect_data_table(as.data.table(psc), nrows = 0L) +# }) + + +# test_that("no problems if we name the list of sets", { +# ps = ParamSet_legacy$new(list(ParamDbl$new("test1"))) +# psc = ParamSetCollection$new(list(paramset = ps)) +# expect_equal(names(psc$subspaces()), "paramset.test1") +# }) + +# test_that("no warning in printer, see issue 208", { +# ps = ParamSet_legacy$new(list(ParamDbl$new("test1"))) + +# psc = ParamSetCollection$new(list(paramset = ps)) +# psc$values = list(paramset.test1 = 1) +# expect_warning(capture_output(print(ps)), NA) +# }) + +# test_that("collection allows state-change setting of paramvals, see issue 205", { +# ps1 = ParamSet_legacy$new(list(ParamDbl$new("d1"))) +# ps2 = ParamSet_legacy$new(list(ParamDbl$new("d2"))) +# ps3 = ParamSet_legacy$new(list(ParamDbl$new("d3"))) + +# psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2, ps3)) +# expect_equal(psc$values, named_list()) +# psc$values$s1.d1 = 1 # nolint +# expect_equal(psc$values, list(s1.d1 = 1)) +# psc$values$s2.d2 = 2 # nolint +# expect_equal(psc$values, list(s1.d1 = 1, s2.d2 = 2)) +# psc$values$d3 = 3 +# expect_equal(psc$values, list(s1.d1 = 1, s2.d2 = 2, d3 = 3)) +# }) + +# test_that("set_id inference in values assignment works now", { +# psa = ParamSet_legacy$new(list(ParamDbl$new("parama"))) + +# psb = ParamSet_legacy$new(list(ParamDbl$new("paramb"))) + +# psc = ParamSet_legacy$new(list(ParamDbl$new("paramc"))) + +# pscol1 = ParamSetCollection$new(list(b = psb, c = psc)) + +# pscol2 = ParamSetCollection$new(list(a.b = psa, a = pscol1)) + +# pstest = ParamSet_legacy$new(list(ParamDbl$new("paramc"))) + +# expect_error(pscol2$add(pstest, n = "a.c"), "would lead to nameclashes.*a\\.c\\.paramc") + +# pstest = ParamSet_legacy$new(list(ParamDbl$new("a.c.paramc"))) + +# expect_error(pscol2$add(pstest), "would lead to nameclashes.*a\\.c\\.paramc") + +# pscol2$values = list(a.c.paramc = 3, a.b.parama = 1, a.b.paramb = 2) + +# expect_equal(psa$values, list(parama = 1)) +# expect_equal(psb$values, list(paramb = 2)) +# expect_equal(psc$values, list(paramc = 3)) +# expect_equal(pscol1$values, list(b.paramb = 2, c.paramc = 3)) +# expect_equal(pscol2$values, list(a.b.parama = 1, a.b.paramb = 2, a.c.paramc = 3)) + +# expect_error(ParamSetCollection$new(list(a = pscol1, pstest)), +# "duplicated parameter.* a\\.c\\.paramc") +# }) + +# test_that("disable internal tuning works", { +# param_set = psc(prefix = ps( +# a = p_dbl(aggr = function(x) 1, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, disable_in_tune = list(b = FALSE)), +# b = p_lgl() +# )) + +# param_set$disable_internal_tuning("prefix.a") +# expect_equal(param_set$values$prefix.b, FALSE) +# expect_error(param_set$disable_internal_tuning("b")) + +# expect_equal(named_list(), psc(ps())$disable_internal_tuning(character(0))$values) +# }) + +# test_that("convert_internal_search_space: depends on other parameter", { +# param_set = psc(a = ps( +# b = p_int(tags = "internal_tuning", in_tune_fn = function(domain, param_vals) param_vals$c * domain$upper, +# aggr = function(x) 1, disable_in_tune = list()), +# c = p_int() +# )) +# param_set$values$a.c = -1 + +# search_space = ps( +# a.b = p_int(upper = 1000, tags = "internal_tuning", aggr = function(x) 1) +# ) + +# expect_equal( +# param_set$convert_internal_search_space(search_space)$a.b, +# -1000 +# ) +# }) - pscol1 = ParamSetCollection$new(list(b = psb, c = psc)) +# test_that("convert_internal_search_space: nested collections", { +# param_set = psc(a = psc(b = ps(param = p_int( +# in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", disable_in_tune = list(), aggr = function(x) 1 +# )))) - pscol2 = ParamSetCollection$new(list(a.b = psa, a = pscol1)) +# search_space = ps( +# a.b.param = p_int(upper = 99, tags = "internal_tuning", aggr = function(x) 1) +# ) - pstest = ParamSet_legacy$new(list(ParamDbl$new("paramc"))) - - expect_error(pscol2$add(pstest, n = "a.c"), "would lead to nameclashes.*a\\.c\\.paramc") - - pstest = ParamSet_legacy$new(list(ParamDbl$new("a.c.paramc"))) - - expect_error(pscol2$add(pstest), "would lead to nameclashes.*a\\.c\\.paramc") - - pscol2$values = list(a.c.paramc = 3, a.b.parama = 1, a.b.paramb = 2) - - expect_equal(psa$values, list(parama = 1)) - expect_equal(psb$values, list(paramb = 2)) - expect_equal(psc$values, list(paramc = 3)) - expect_equal(pscol1$values, list(b.paramb = 2, c.paramc = 3)) - expect_equal(pscol2$values, list(a.b.parama = 1, a.b.paramb = 2, a.c.paramc = 3)) - - expect_error(ParamSetCollection$new(list(a = pscol1, pstest)), - "duplicated parameter.* a\\.c\\.paramc") -}) - -test_that("disable internal tuning works", { - param_set = psc(prefix = ps( - a = p_dbl(aggr = function(x) 1, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, disable_in_tune = list(b = FALSE)), - b = p_lgl() - )) - - param_set$disable_internal_tuning("prefix.a") - expect_equal(param_set$values$prefix.b, FALSE) - expect_error(param_set$disable_internal_tuning("b")) - - expect_equal(named_list(), psc(ps())$disable_internal_tuning(character(0))$values) -}) - -test_that("convert_internal_search_space: depends on other parameter", { - param_set = psc(a = ps( - b = p_int(tags = "internal_tuning", in_tune_fn = function(domain, param_vals) param_vals$c * domain$upper, - aggr = function(x) 1, disable_in_tune = list()), - c = p_int() - )) - param_set$values$a.c = -1 - - search_space = ps( - a.b = p_int(upper = 1000, tags = "internal_tuning", aggr = function(x) 1) - ) - - browser() - expect_equal( - param_set$convert_internal_search_space(search_space)$a.b, - -1000 - ) -}) - -test_that("convert_internal_search_space: nested collections", { - param_set = psc(a = psc(b = ps(param = p_int( - in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", disable_in_tune = list(), aggr = function(x) 1 - )))) - - search_space = ps( - a.b.param = p_int(upper = 99, tags = "internal_tuning", aggr = function(x) 1) - ) - - expect_equal( - param_set$convert_internal_search_space(search_space), - list(a.b.param = 99) - ) -}) +# expect_equal( +# param_set$convert_internal_search_space(search_space), +# list(a.b.param = 99) +# ) +# }) test_that("convert_internal_search_space: flattening", { param_set = psc(a = psc(b = ps( @@ -325,8 +324,7 @@ test_that("disable internal tuning: nested collection", { ) ) }) - -test_that("disable internal tuning: flattening", { +test_that("disable internal tuning: nested flattening", { param_set = psc(a = ps( b = p_int( in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", From 715a322abcb155a483603157891ef1bf59c081f8 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 1 Jun 2024 07:17:49 +0200 Subject: [PATCH 26/34] uncomment tests --- tests/testthat/test_ParamSetCollection.R | 562 +++++++++++------------ 1 file changed, 281 insertions(+), 281 deletions(-) diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index 909592ac..2f84d227 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -1,287 +1,287 @@ -# context("ParamSetCollection") - -# test_that("ParamSet basic stuff works", { -# ps1 = th_paramset_dbl1() -# ps2 = th_paramset_full() -# ps3 = th_paramset_dbl1() -# psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2, ps3)) - -# ps1clone = ps1$clone(deep = TRUE) -# ps2clone = ps2$clone(deep = TRUE) - -# my_c = function(xs1, xs2, xs3) { -# # littler helper to join to ps-result and prefix names -# ns = c(paste0("s1.", names(xs1)), paste0("s2.", names(xs2)), names(xs3)) -# set_names(c(xs1, xs2, xs3), ns) -# } - -# expect_class(psc, "ParamSetCollection") -# expect_equal(psc$length, ps1$length + ps2$length + ps3$length) -# # check that param internally in collection is constructed correctly -# p = psc$params[2L] -# p$id = "th_param_int" - -# expect_equal(p, ps2$params[1L]) -# expect_equal(psc$ids(), c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids())) -# expect_equal(psc$lower, my_c(ps1$lower, ps2$lower, ps3$lower)) -# d = as.data.table(psc) -# expect_data_table(d, nrows = 6) -# expect_false(psc$has_deps) -# expect_false(psc$has_trafo) - -# d = as.data.table(psc) -# expect_equal(d$id, c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids())) - -# expect_true(psc$check(list(s1.th_param_dbl = 1, s2.th_param_int = 2))) -# expect_string(psc$check(list(th_param_int = 2)), fixed = "not avail") -# expect_true(psc$check(list(th_param_dbl = 1))) - -# d = generate_design_random(psc, n = 10L) -# expect_data_table(d$data, nrows = 10, ncols = 6L) - -# psflat = psc$flatten() -# psflat$extra_trafo = function(x, param_set) { -# x$s2.th_param_int = 99 # nolint -# return(x) -# } -# expect_true(psflat$has_trafo) -# d = generate_design_random(psflat, n = 10L) -# expect_data_table(d$data, nrows = 10, ncols = 6L) -# xs = d$transpose(trafo = TRUE) -# for (i in 1:10) { -# x = xs[[i]] -# expect_list(x, len = 6) -# expect_names(names(x), permutation.of = psc$ids()) -# expect_equal(x$s2.th_param_int, 99) -# } - -# # ps1 and ps2 should not be changed -# expect_equal(ps1, ps1clone) -# expect_equal(ps2, ps2clone) - -# expect_output(print(psc), "s1\\.th_param_dbl.*s2\\.th_param_int.*s2\\.th_param_dbl.*s2\\.th_param_fct.*s2\\.th_param_lgl.*th_param_dbl") # nolint - -# # ps1 and ps2 should not be changed by printing -# expect_equal(ps1, ps1clone) -# expect_equal(ps2, ps2clone) - -# # adding a set -# ps4 = ParamSet_legacy$new(list(ParamDbl$new("x"))) -# psc = psc$add(ps4, n = "s4") -# expect_equal(psc$length, ps1$length + ps2$length + ps3$length + ps4$length) -# expect_equal(psc$ids(), c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids(), paste0("s4.", ps4$ids()))) -# }) - -# test_that("some operations are not allowed", { -# ps1 = th_paramset_dbl1() -# ps2 = th_paramset_full() -# psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2)) - -# expect_error(psc$subset("foo"), "Must be a subset of") -# }) - -# test_that("deps", { -# ps1 = ParamSet_legacy$new(list( -# ParamFct$new("f", levels = c("a", "b")), -# ParamDbl$new("d") -# )) -# ps1$add_dep("d", on = "f", CondEqual("a")) - -# ps2 = ParamSet_legacy$new(list( -# ParamFct$new("f", levels = c("a", "b")), -# ParamDbl$new("d") -# )) - -# ps1clone = ps1$clone(deep = TRUE) -# ps2clone = ps2$clone(deep = TRUE) - -# psc = ParamSetCollection$new(list(ps1 = ps1, ps2 = ps2)) -# d = psc$deps -# expect_data_table(d, nrows = 1, ncols = 3) -# expect_equal(d$id, c("ps1.d")) - -# # check deps across sets -# psc$add_dep("ps2.d", on = "ps1.f", CondEqual("a")) -# expect_data_table(psc$deps, nrows = 2, ncols = 3) -# expect_true(psc$check(list(ps1.f = "a", ps1.d = 0, ps2.d = 0))) -# expect_string(psc$check(list(ps2.d = 0), check_strict = TRUE)) - -# # ps1 and ps2 should not be changed -# expect_equal(ps1clone, ps1) -# expect_equal(ps2clone, ps2) -# }) - -# test_that("values", { -# ps1 = ParamSet_legacy$new(list( -# ParamFct$new("f", levels = c("a", "b")), -# ParamDbl$new("d", lower = 1, upper = 8) -# )) -# ps2 = ParamSet_legacy$new(list( -# ParamFct$new("f", levels = c("a", "b")), -# ParamDbl$new("d", lower = 1, upper = 8) -# )) -# ps3 = ParamSet_legacy$new(list( -# ParamDbl$new("x", lower = 1, upper = 8) -# )) -# ps4 = ParamSet_legacy$new(list( -# ParamDbl$new("y", lower = 1, upper = 8) -# )) - -# ps1clone = ps1$clone(deep = TRUE) -# ps2clone = ps2$clone(deep = TRUE) - -# pcs = ParamSetCollection$new(list(foo = ps1, bar = ps2, ps3, ps4)) -# expect_equal(pcs$values, named_list()) -# ps2$values = list(d = 3) -# expect_equal(pcs$values, list(bar.d = 3)) -# pcs$values = list(foo.d = 8) -# expect_equal(pcs$values, list(foo.d = 8)) -# expect_equal(ps1$values, list(d = 8)) -# expect_equal(ps2$values, named_list()) -# pcs$values = list(x = 1) -# expect_equal(pcs$values, list(x = 1)) -# expect_equal(ps3$values, list(x = 1)) - -# ps1clone$values$d = 8 -# pcs$values = list(foo.d = 8) -# ps2$values = list() - -# # data table adds indexes at will and comparisons fail because of that, so we have to remove them here. -# setindex(ps1clone$deps, NULL) -# setindex(ps2clone$deps, NULL) -# setindex(ps1$deps, NULL) -# setindex(ps2$deps, NULL) - -# expect_equal(ps1clone, ps1) -# expect_equal(ps2clone, ps2) - -# # resetting pcs values -# pcs$values = list() -# expect_list(pcs$values, len = 0) -# }) - -# test_that("empty collections", { -# # no paramsets -# psc = ParamSetCollection$new(list()) -# expect_equal(psc$length, 0L) -# expect_equal(psc$subspaces(), named_list()) -# expect_equal(psc$ids(), character(0L)) -# expect_data_table(as.data.table(psc), nrows = 0L) - -# # 1 empty paramset -# psc = ParamSetCollection$new(list(ParamSet_legacy$new())) -# expect_equal(psc$length, 0L) -# expect_equal(psc$subspaces(), named_list()) -# expect_equal(psc$ids(), character(0L)) -# expect_data_table(as.data.table(psc), nrows = 0L) -# }) - - -# test_that("no problems if we name the list of sets", { -# ps = ParamSet_legacy$new(list(ParamDbl$new("test1"))) -# psc = ParamSetCollection$new(list(paramset = ps)) -# expect_equal(names(psc$subspaces()), "paramset.test1") -# }) - -# test_that("no warning in printer, see issue 208", { -# ps = ParamSet_legacy$new(list(ParamDbl$new("test1"))) - -# psc = ParamSetCollection$new(list(paramset = ps)) -# psc$values = list(paramset.test1 = 1) -# expect_warning(capture_output(print(ps)), NA) -# }) - -# test_that("collection allows state-change setting of paramvals, see issue 205", { -# ps1 = ParamSet_legacy$new(list(ParamDbl$new("d1"))) -# ps2 = ParamSet_legacy$new(list(ParamDbl$new("d2"))) -# ps3 = ParamSet_legacy$new(list(ParamDbl$new("d3"))) - -# psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2, ps3)) -# expect_equal(psc$values, named_list()) -# psc$values$s1.d1 = 1 # nolint -# expect_equal(psc$values, list(s1.d1 = 1)) -# psc$values$s2.d2 = 2 # nolint -# expect_equal(psc$values, list(s1.d1 = 1, s2.d2 = 2)) -# psc$values$d3 = 3 -# expect_equal(psc$values, list(s1.d1 = 1, s2.d2 = 2, d3 = 3)) -# }) - -# test_that("set_id inference in values assignment works now", { -# psa = ParamSet_legacy$new(list(ParamDbl$new("parama"))) - -# psb = ParamSet_legacy$new(list(ParamDbl$new("paramb"))) - -# psc = ParamSet_legacy$new(list(ParamDbl$new("paramc"))) - -# pscol1 = ParamSetCollection$new(list(b = psb, c = psc)) - -# pscol2 = ParamSetCollection$new(list(a.b = psa, a = pscol1)) - -# pstest = ParamSet_legacy$new(list(ParamDbl$new("paramc"))) - -# expect_error(pscol2$add(pstest, n = "a.c"), "would lead to nameclashes.*a\\.c\\.paramc") - -# pstest = ParamSet_legacy$new(list(ParamDbl$new("a.c.paramc"))) - -# expect_error(pscol2$add(pstest), "would lead to nameclashes.*a\\.c\\.paramc") - -# pscol2$values = list(a.c.paramc = 3, a.b.parama = 1, a.b.paramb = 2) - -# expect_equal(psa$values, list(parama = 1)) -# expect_equal(psb$values, list(paramb = 2)) -# expect_equal(psc$values, list(paramc = 3)) -# expect_equal(pscol1$values, list(b.paramb = 2, c.paramc = 3)) -# expect_equal(pscol2$values, list(a.b.parama = 1, a.b.paramb = 2, a.c.paramc = 3)) - -# expect_error(ParamSetCollection$new(list(a = pscol1, pstest)), -# "duplicated parameter.* a\\.c\\.paramc") -# }) - -# test_that("disable internal tuning works", { -# param_set = psc(prefix = ps( -# a = p_dbl(aggr = function(x) 1, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, disable_in_tune = list(b = FALSE)), -# b = p_lgl() -# )) - -# param_set$disable_internal_tuning("prefix.a") -# expect_equal(param_set$values$prefix.b, FALSE) -# expect_error(param_set$disable_internal_tuning("b")) - -# expect_equal(named_list(), psc(ps())$disable_internal_tuning(character(0))$values) -# }) - -# test_that("convert_internal_search_space: depends on other parameter", { -# param_set = psc(a = ps( -# b = p_int(tags = "internal_tuning", in_tune_fn = function(domain, param_vals) param_vals$c * domain$upper, -# aggr = function(x) 1, disable_in_tune = list()), -# c = p_int() -# )) -# param_set$values$a.c = -1 - -# search_space = ps( -# a.b = p_int(upper = 1000, tags = "internal_tuning", aggr = function(x) 1) -# ) - -# expect_equal( -# param_set$convert_internal_search_space(search_space)$a.b, -# -1000 -# ) -# }) +context("ParamSetCollection") + +test_that("ParamSet basic stuff works", { + ps1 = th_paramset_dbl1() + ps2 = th_paramset_full() + ps3 = th_paramset_dbl1() + psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2, ps3)) + + ps1clone = ps1$clone(deep = TRUE) + ps2clone = ps2$clone(deep = TRUE) + + my_c = function(xs1, xs2, xs3) { + # littler helper to join to ps-result and prefix names + ns = c(paste0("s1.", names(xs1)), paste0("s2.", names(xs2)), names(xs3)) + set_names(c(xs1, xs2, xs3), ns) + } + + expect_class(psc, "ParamSetCollection") + expect_equal(psc$length, ps1$length + ps2$length + ps3$length) + # check that param internally in collection is constructed correctly + p = psc$params[2L] + p$id = "th_param_int" + + expect_equal(p, ps2$params[1L]) + expect_equal(psc$ids(), c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids())) + expect_equal(psc$lower, my_c(ps1$lower, ps2$lower, ps3$lower)) + d = as.data.table(psc) + expect_data_table(d, nrows = 6) + expect_false(psc$has_deps) + expect_false(psc$has_trafo) + + d = as.data.table(psc) + expect_equal(d$id, c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids())) + + expect_true(psc$check(list(s1.th_param_dbl = 1, s2.th_param_int = 2))) + expect_string(psc$check(list(th_param_int = 2)), fixed = "not avail") + expect_true(psc$check(list(th_param_dbl = 1))) + + d = generate_design_random(psc, n = 10L) + expect_data_table(d$data, nrows = 10, ncols = 6L) + + psflat = psc$flatten() + psflat$extra_trafo = function(x, param_set) { + x$s2.th_param_int = 99 # nolint + return(x) + } + expect_true(psflat$has_trafo) + d = generate_design_random(psflat, n = 10L) + expect_data_table(d$data, nrows = 10, ncols = 6L) + xs = d$transpose(trafo = TRUE) + for (i in 1:10) { + x = xs[[i]] + expect_list(x, len = 6) + expect_names(names(x), permutation.of = psc$ids()) + expect_equal(x$s2.th_param_int, 99) + } + + # ps1 and ps2 should not be changed + expect_equal(ps1, ps1clone) + expect_equal(ps2, ps2clone) + + expect_output(print(psc), "s1\\.th_param_dbl.*s2\\.th_param_int.*s2\\.th_param_dbl.*s2\\.th_param_fct.*s2\\.th_param_lgl.*th_param_dbl") # nolint + + # ps1 and ps2 should not be changed by printing + expect_equal(ps1, ps1clone) + expect_equal(ps2, ps2clone) + + # adding a set + ps4 = ParamSet_legacy$new(list(ParamDbl$new("x"))) + psc = psc$add(ps4, n = "s4") + expect_equal(psc$length, ps1$length + ps2$length + ps3$length + ps4$length) + expect_equal(psc$ids(), c(paste0("s1.", ps1$ids()), paste0("s2.", ps2$ids()), ps3$ids(), paste0("s4.", ps4$ids()))) +}) + +test_that("some operations are not allowed", { + ps1 = th_paramset_dbl1() + ps2 = th_paramset_full() + psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2)) + + expect_error(psc$subset("foo"), "Must be a subset of") +}) + +test_that("deps", { + ps1 = ParamSet_legacy$new(list( + ParamFct$new("f", levels = c("a", "b")), + ParamDbl$new("d") + )) + ps1$add_dep("d", on = "f", CondEqual("a")) + + ps2 = ParamSet_legacy$new(list( + ParamFct$new("f", levels = c("a", "b")), + ParamDbl$new("d") + )) + + ps1clone = ps1$clone(deep = TRUE) + ps2clone = ps2$clone(deep = TRUE) + + psc = ParamSetCollection$new(list(ps1 = ps1, ps2 = ps2)) + d = psc$deps + expect_data_table(d, nrows = 1, ncols = 3) + expect_equal(d$id, c("ps1.d")) + + # check deps across sets + psc$add_dep("ps2.d", on = "ps1.f", CondEqual("a")) + expect_data_table(psc$deps, nrows = 2, ncols = 3) + expect_true(psc$check(list(ps1.f = "a", ps1.d = 0, ps2.d = 0))) + expect_string(psc$check(list(ps2.d = 0), check_strict = TRUE)) + + # ps1 and ps2 should not be changed + expect_equal(ps1clone, ps1) + expect_equal(ps2clone, ps2) +}) + +test_that("values", { + ps1 = ParamSet_legacy$new(list( + ParamFct$new("f", levels = c("a", "b")), + ParamDbl$new("d", lower = 1, upper = 8) + )) + ps2 = ParamSet_legacy$new(list( + ParamFct$new("f", levels = c("a", "b")), + ParamDbl$new("d", lower = 1, upper = 8) + )) + ps3 = ParamSet_legacy$new(list( + ParamDbl$new("x", lower = 1, upper = 8) + )) + ps4 = ParamSet_legacy$new(list( + ParamDbl$new("y", lower = 1, upper = 8) + )) + + ps1clone = ps1$clone(deep = TRUE) + ps2clone = ps2$clone(deep = TRUE) + + pcs = ParamSetCollection$new(list(foo = ps1, bar = ps2, ps3, ps4)) + expect_equal(pcs$values, named_list()) + ps2$values = list(d = 3) + expect_equal(pcs$values, list(bar.d = 3)) + pcs$values = list(foo.d = 8) + expect_equal(pcs$values, list(foo.d = 8)) + expect_equal(ps1$values, list(d = 8)) + expect_equal(ps2$values, named_list()) + pcs$values = list(x = 1) + expect_equal(pcs$values, list(x = 1)) + expect_equal(ps3$values, list(x = 1)) + + ps1clone$values$d = 8 + pcs$values = list(foo.d = 8) + ps2$values = list() + + # data table adds indexes at will and comparisons fail because of that, so we have to remove them here. + setindex(ps1clone$deps, NULL) + setindex(ps2clone$deps, NULL) + setindex(ps1$deps, NULL) + setindex(ps2$deps, NULL) + + expect_equal(ps1clone, ps1) + expect_equal(ps2clone, ps2) + + # resetting pcs values + pcs$values = list() + expect_list(pcs$values, len = 0) +}) + +test_that("empty collections", { + # no paramsets + psc = ParamSetCollection$new(list()) + expect_equal(psc$length, 0L) + expect_equal(psc$subspaces(), named_list()) + expect_equal(psc$ids(), character(0L)) + expect_data_table(as.data.table(psc), nrows = 0L) + + # 1 empty paramset + psc = ParamSetCollection$new(list(ParamSet_legacy$new())) + expect_equal(psc$length, 0L) + expect_equal(psc$subspaces(), named_list()) + expect_equal(psc$ids(), character(0L)) + expect_data_table(as.data.table(psc), nrows = 0L) +}) + + +test_that("no problems if we name the list of sets", { + ps = ParamSet_legacy$new(list(ParamDbl$new("test1"))) + psc = ParamSetCollection$new(list(paramset = ps)) + expect_equal(names(psc$subspaces()), "paramset.test1") +}) + +test_that("no warning in printer, see issue 208", { + ps = ParamSet_legacy$new(list(ParamDbl$new("test1"))) + + psc = ParamSetCollection$new(list(paramset = ps)) + psc$values = list(paramset.test1 = 1) + expect_warning(capture_output(print(ps)), NA) +}) + +test_that("collection allows state-change setting of paramvals, see issue 205", { + ps1 = ParamSet_legacy$new(list(ParamDbl$new("d1"))) + ps2 = ParamSet_legacy$new(list(ParamDbl$new("d2"))) + ps3 = ParamSet_legacy$new(list(ParamDbl$new("d3"))) + + psc = ParamSetCollection$new(list(s1 = ps1, s2 = ps2, ps3)) + expect_equal(psc$values, named_list()) + psc$values$s1.d1 = 1 # nolint + expect_equal(psc$values, list(s1.d1 = 1)) + psc$values$s2.d2 = 2 # nolint + expect_equal(psc$values, list(s1.d1 = 1, s2.d2 = 2)) + psc$values$d3 = 3 + expect_equal(psc$values, list(s1.d1 = 1, s2.d2 = 2, d3 = 3)) +}) + +test_that("set_id inference in values assignment works now", { + psa = ParamSet_legacy$new(list(ParamDbl$new("parama"))) + + psb = ParamSet_legacy$new(list(ParamDbl$new("paramb"))) + + psc = ParamSet_legacy$new(list(ParamDbl$new("paramc"))) -# test_that("convert_internal_search_space: nested collections", { -# param_set = psc(a = psc(b = ps(param = p_int( -# in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", disable_in_tune = list(), aggr = function(x) 1 -# )))) + pscol1 = ParamSetCollection$new(list(b = psb, c = psc)) -# search_space = ps( -# a.b.param = p_int(upper = 99, tags = "internal_tuning", aggr = function(x) 1) -# ) + pscol2 = ParamSetCollection$new(list(a.b = psa, a = pscol1)) -# expect_equal( -# param_set$convert_internal_search_space(search_space), -# list(a.b.param = 99) -# ) -# }) + pstest = ParamSet_legacy$new(list(ParamDbl$new("paramc"))) + + expect_error(pscol2$add(pstest, n = "a.c"), "would lead to nameclashes.*a\\.c\\.paramc") + + pstest = ParamSet_legacy$new(list(ParamDbl$new("a.c.paramc"))) + + expect_error(pscol2$add(pstest), "would lead to nameclashes.*a\\.c\\.paramc") + + pscol2$values = list(a.c.paramc = 3, a.b.parama = 1, a.b.paramb = 2) + + expect_equal(psa$values, list(parama = 1)) + expect_equal(psb$values, list(paramb = 2)) + expect_equal(psc$values, list(paramc = 3)) + expect_equal(pscol1$values, list(b.paramb = 2, c.paramc = 3)) + expect_equal(pscol2$values, list(a.b.parama = 1, a.b.paramb = 2, a.c.paramc = 3)) + + expect_error(ParamSetCollection$new(list(a = pscol1, pstest)), + "duplicated parameter.* a\\.c\\.paramc") +}) + +test_that("disable internal tuning works", { + param_set = psc(prefix = ps( + a = p_dbl(aggr = function(x) 1, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, disable_in_tune = list(b = FALSE)), + b = p_lgl() + )) + + param_set$disable_internal_tuning("prefix.a") + expect_equal(param_set$values$prefix.b, FALSE) + expect_error(param_set$disable_internal_tuning("b")) + + expect_equal(named_list(), psc(ps())$disable_internal_tuning(character(0))$values) +}) + +test_that("convert_internal_search_space: depends on other parameter", { + param_set = psc(a = ps( + b = p_int(tags = "internal_tuning", in_tune_fn = function(domain, param_vals) param_vals$c * domain$upper, + aggr = function(x) 1, disable_in_tune = list()), + c = p_int() + )) + param_set$values$a.c = -1 + + search_space = ps( + a.b = p_int(upper = 1000, tags = "internal_tuning", aggr = function(x) 1) + ) + + expect_equal( + param_set$convert_internal_search_space(search_space)$a.b, + -1000 + ) +}) + +test_that("convert_internal_search_space: nested collections", { + param_set = psc(a = psc(b = ps(param = p_int( + in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", disable_in_tune = list(), aggr = function(x) 1 + )))) + + search_space = ps( + a.b.param = p_int(upper = 99, tags = "internal_tuning", aggr = function(x) 1) + ) + + expect_equal( + param_set$convert_internal_search_space(search_space), + list(a.b.param = 99) + ) +}) test_that("convert_internal_search_space: flattening", { param_set = psc(a = psc(b = ps( From 134d4fb86ad9f9221f0e6208c0275384eb8fb288 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 1 Jun 2024 07:18:36 +0200 Subject: [PATCH 27/34] remove dead comments --- R/ParamSetCollection.R | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 26995f08..79dce66a 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -171,26 +171,6 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, self$set_values(.values = pvs) }, - # #' @description - # #' Convert all parameters from the search space to parameter values using the transformation given by - # #' `in_tune_fn`. - # #' @param search_space ([`ParamSet`])\cr - # #' The internal search space. - # #' @return (named `list()`) - # convert_internal_search_space = function(search_space) { - # assert_class(search_space, "ParamSet") - - - - - # param_vals = self$values - - # Reduce(c, imap(private$.sets, function(set, prefix) { - - # set$convert_internal_search_space() - - # })) %??% named_list() - # }, #' @description #' Convert all parameters from the search space to parameter values using the transformation given by #' `in_tune_fn`. From b75465aabb5e510f28b80de43eb668e1fde4c875 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 1 Jun 2024 07:39:21 +0200 Subject: [PATCH 28/34] fix disable in tune for nested psc --- R/ParamSetCollection.R | 24 ++++++++++++++++++------ tests/testthat/test_ParamSetCollection.R | 14 +++++++++----- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 79dce66a..9131aede 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -149,6 +149,7 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, private$.sets[[n]] = p invisible(self) }, + #' @description #' #' Set the parameter values so that internal tuning for the selected parameters is disabled. @@ -159,14 +160,25 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, disable_internal_tuning = function(ids) { assert_subset(ids, self$ids(tags = "internal_tuning")) - pvs = Reduce(c, map(ids, function(id_) { - info = private$.translation[id_, c("original_id", "owner_name"), on = "id"] - xs = get_private(private$.sets[[info$owner_name]])$.params[ - info$original_id, "cargo", on = "id"][[1L]][[1]]$disable_in_tune + full_prefix = function(param_set, id_, prefix = "") { + info = get_private(param_set)$.translation[id_, c("owner_name", "original_id", "owner_ps_index"), on = "id"] + subset = get_private(param_set)$.sets[[info$owner_ps_index]] + prefix = if (info$owner_name == "") { + prefix + } else if (prefix == "") { + info$owner_name + } else { + paste0(prefix, ".", info$owner_name) + } - if (info$owner_name == "" || is.null(xs)) return(xs) + if (!test_class(subset, "ParamSetCollection")) return(prefix) - set_names(xs, paste0(info$owner_name, ".", names(xs))) + full_prefix(subset, info$original_id, prefix) + } + + pvs = Reduce(c, map(ids, function(id_) { + xs = private$.params[list(id_), "cargo", on = "id"][[1]][[1]]$disable_in_tune + set_names(xs, paste0(full_prefix(self, id_), ".", names(xs))) })) %??% named_list() self$set_values(.values = pvs) }, diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index 2f84d227..d75c5a62 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -317,12 +317,16 @@ test_that("disable internal tuning: single collection", { }) test_that("disable internal tuning: nested collection", { - param_set = ps( - a = p_int( + param_set = psc(alpha = psc(a = ps( + b = p_int( in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", - disable_in_tune = list(), aggr = function(x) 1 - ) - ) + disable_in_tune = list(c = TRUE), aggr = function(x) 1 + ), + c = p_lgl() + ))) + + param_set$disable_internal_tuning("alpha.a.b") + expect_equal(param_set$values$alpha.a.c, TRUE) }) test_that("disable internal tuning: nested flattening", { param_set = psc(a = ps( From fd6270a3212cb5d79d126489f6157df7aba2ed5f Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 1 Jun 2024 07:42:55 +0200 Subject: [PATCH 29/34] make tests less confusing --- tests/testthat/test_ParamSetCollection.R | 14 +++++++------- tests/testthat/test_to_tune.R | 3 --- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index d75c5a62..6c959461 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -239,7 +239,7 @@ test_that("set_id inference in values assignment works now", { test_that("disable internal tuning works", { param_set = psc(prefix = ps( - a = p_dbl(aggr = function(x) 1, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, disable_in_tune = list(b = FALSE)), + a = p_dbl(aggr = function(x) 10, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, disable_in_tune = list(b = FALSE)), b = p_lgl() )) @@ -253,28 +253,28 @@ test_that("disable internal tuning works", { test_that("convert_internal_search_space: depends on other parameter", { param_set = psc(a = ps( b = p_int(tags = "internal_tuning", in_tune_fn = function(domain, param_vals) param_vals$c * domain$upper, - aggr = function(x) 1, disable_in_tune = list()), + aggr = function(x) 10, disable_in_tune = list()), c = p_int() )) param_set$values$a.c = -1 search_space = ps( - a.b = p_int(upper = 1000, tags = "internal_tuning", aggr = function(x) 1) + a.b = p_int(upper = 1000, tags = "internal_tuning", aggr = function(x) 10) ) expect_equal( - param_set$convert_internal_search_space(search_space)$a.b, + param_set$convert_internal_search_space(search_space)$a.b, -1000 ) }) test_that("convert_internal_search_space: nested collections", { param_set = psc(a = psc(b = ps(param = p_int( - in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", disable_in_tune = list(), aggr = function(x) 1 + in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", disable_in_tune = list(), aggr = function(x) 10 )))) search_space = ps( - a.b.param = p_int(upper = 99, tags = "internal_tuning", aggr = function(x) 1) + a.b.param = p_int(upper = 99, tags = "internal_tuning", aggr = function(x) 10) ) expect_equal( @@ -287,7 +287,7 @@ test_that("convert_internal_search_space: flattening", { param_set = psc(a = psc(b = ps( param = p_int( in_tune_fn = function(domain, param_vals) domain$upper * param_vals$other_param, tags = "internal_tuning", - disable_in_tune = list(), aggr = function(x) 1), + disable_in_tune = list(), aggr = function(x) 10), other_param = p_int() ))) diff --git a/tests/testthat/test_to_tune.R b/tests/testthat/test_to_tune.R index 9b797277..1dca0c35 100644 --- a/tests/testthat/test_to_tune.R +++ b/tests/testthat/test_to_tune.R @@ -401,7 +401,6 @@ test_that("logscale in tunetoken", { test_that("internal and aggr", { - # no default aggregation function param_set = ps(a = p_dbl(lower = 1, upper = 2, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, disable_in_tune = list(), aggr = function(x) round(mean(unlist(x)))) ) @@ -441,8 +440,6 @@ test_that("internal and aggr", { "specify lower and upper" ) - ## with default aggregation function - # param set + internal param_set = ps(a = p_int(lower = 1, upper = 10000, tags = "internal_tuning", in_tune_fn = function(domain, param_vals) domain$upper, aggr = function(x) max(unlist(x)), disable_in_tune = list())) From 6fd11c095d901d9cd832ec3ac5e87db287a55332 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 1 Jun 2024 07:52:54 +0200 Subject: [PATCH 30/34] fix one more bug --- R/ParamLgl.R | 4 +++- R/ParamSetCollection.R | 2 ++ tests/testthat/test_ParamSetCollection.R | 13 +++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/R/ParamLgl.R b/R/ParamLgl.R index ca555271..016d4a7d 100644 --- a/R/ParamLgl.R +++ b/R/ParamLgl.R @@ -1,7 +1,9 @@ #' @rdname Domain #' @export p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) { - cargo = c(aggr = aggr, in_tune_fn = in_tune_fn) + cargo = list() + cargo$aggr = aggr + cargo$in_tune_fn = in_tune_fn cargo$disable_in_tune = disable_in_tune Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo) diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index 9131aede..48ee2a38 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -178,6 +178,8 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, pvs = Reduce(c, map(ids, function(id_) { xs = private$.params[list(id_), "cargo", on = "id"][[1]][[1]]$disable_in_tune + prefix = full_prefix(self, id_) + if (prefix == "") return(xs) set_names(xs, paste0(full_prefix(self, id_), ".", names(xs))) })) %??% named_list() self$set_values(.values = pvs) diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index 6c959461..89f7b713 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -356,3 +356,16 @@ test_that("disable internal tuning: nested flattening", { 1 ) }) + +test_that("disable internal tuning without set names", { + param_set = psc(ps( + a = p_int( + in_tune_fn = function(domain, param_vals) domain$upper, tags = "internal_tuning", + disable_in_tune = list(b = TRUE), aggr = function(x) 1 + ), + b = p_lgl() + )) + + param_set$disable_internal_tuning("a") + expect_equal(param_set$values$b, TRUE) +}) From f7a6754cb7e14b9cb92cc219530b105625812875 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 1 Jun 2024 08:01:29 +0200 Subject: [PATCH 31/34] better example --- R/Domain.R | 16 ++++++++++++---- man/Domain.Rd | 24 ++++++++++++++++++------ man/ParamSet.Rd | 10 +++++----- man/ParamSetCollection.Rd | 38 +++++++++++++++++++++++++++++++++++--- 4 files changed, 70 insertions(+), 18 deletions(-) diff --git a/R/Domain.R b/R/Domain.R index 5183af66..8a6ddd31 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -68,7 +68,7 @@ #' Function with one argument, which is a list of parameter values and that returns the aggregated parameter value. #' @param in_tune_fn (`function(domain, param_vals)`)\cr #' Function that converters a `Domain` object into a parameter value. -#' Can onlye be given for parameters tagged with `"internal_tuning"`. +#' Can only be given for parameters tagged with `"internal_tuning"`. #' This function should also assert that the parameters required to enable internal tuning for the given `domain` are #' set in `param_vals` (such as `early_stopping_rounds` for `XGBoost`). #' @param disable_in_tune (named `list()`)\cr @@ -133,17 +133,25 @@ #' #' param_set = ps( #' iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), -#' in_tune_fn = function(domain, param_vals) domain$upper, -#' disable_in_tune = list(other_param = FALSE)) +#' in_tune_fn = function(domain, param_vals) { +#' stopifnot(domain$lower <= 1) +#' stopifnot(param_vals$early_stopping == TRUE) +#' domain$upper +#' }, +#' disable_in_tune = list(early_stopping = FALSE)), +#' early_stopping = p_lgl() #' ) #' param_set$set_values( -#' iters = to_tune(upper = 100, internal = TRUE) +#' iters = to_tune(upper = 100, internal = TRUE), +#' early_stopping = TRUE #' ) #' param_set$convert_internal_search_space(param_set$search_space()) #' param_set$aggr_internal_tuned_values( #' list(iters = list(1, 2, 3)) #' ) #' +#' param_set$disable_internal_tuning("iters") +#' param_set$values$early_stopping #' @family ParamSet construction helpers #' @name Domain NULL diff --git a/man/Domain.Rd b/man/Domain.Rd index e63cc3ec..48c4e42f 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -158,9 +158,11 @@ value upon construction.} Default aggregation function for a parameter. Can only be given for parameters tagged with \code{"internal_tuning"}. Function with one argument, which is a list of parameter values and that returns the aggregated parameter value.} -\item{in_tune_fn}{(\verb{function(domain, param_set)})\cr +\item{in_tune_fn}{(\verb{function(domain, param_vals)})\cr Function that converters a \code{Domain} object into a parameter value. -Can onlye be given for parameters tagged with \code{"internal_tuning"}.} +Can only be given for parameters tagged with \code{"internal_tuning"}. +This function should also assert that the parameters required to enable internal tuning for the given \code{domain} are +set in \code{param_vals} (such as \code{early_stopping_rounds} for \code{XGBoost}).} \item{disable_in_tune}{(named \code{list()})\cr The parameter values that need to be set in the \code{ParamSet} to disable the internal tuning for the parameter. @@ -254,15 +256,25 @@ print(grid$transpose()) param_set = ps( iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))), - in_tune_fn = function(domain, param_set) domain$upper, - disable_in_tune = list(other_param = FALSE)) + in_tune_fn = function(domain, param_vals) { + stopifnot(domain$lower <= 1) + stopifnot(param_vals$early_stopping == TRUE) + domain$upper + }, + disable_in_tune = list(early_stopping = FALSE)), + early_stopping = p_lgl() ) param_set$set_values( - iters = to_tune(upper = 100, internal = TRUE) + iters = to_tune(upper = 100, internal = TRUE), + early_stopping = TRUE ) param_set$convert_internal_search_space(param_set$search_space()) -param_set$aggr(list(iters = list(1, 2, 3))) +param_set$aggr_internal_tuned_values( + list(iters = list(1, 2, 3)) +) +param_set$disable_internal_tuning("iters") +param_set$values$early_stopping } \seealso{ Other ParamSet construction helpers: diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index e4cc954f..a4ad34e7 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -172,7 +172,7 @@ Named with param IDs.} \item \href{#method-ParamSet-get_values}{\code{ParamSet$get_values()}} \item \href{#method-ParamSet-set_values}{\code{ParamSet$set_values()}} \item \href{#method-ParamSet-trafo}{\code{ParamSet$trafo()}} -\item \href{#method-ParamSet-aggr}{\code{ParamSet$aggr()}} +\item \href{#method-ParamSet-aggr_internal_tuned_values}{\code{ParamSet$aggr_internal_tuned_values()}} \item \href{#method-ParamSet-disable_internal_tuning}{\code{ParamSet$disable_internal_tuning()}} \item \href{#method-ParamSet-convert_internal_search_space}{\code{ParamSet$convert_internal_search_space()}} \item \href{#method-ParamSet-test_constraint}{\code{ParamSet$test_constraint()}} @@ -343,12 +343,12 @@ In almost all cases, the default \code{param_set = self} should be used.} } } \if{html}{\out{
    }} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-ParamSet-aggr}{}}} -\subsection{Method \code{aggr()}}{ +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-aggr_internal_tuned_values}{}}} +\subsection{Method \code{aggr_internal_tuned_values()}}{ Aggregate parameter values according to their aggregation rules. \subsection{Usage}{ -\if{html}{\out{
    }}\preformatted{ParamSet$aggr(x)}\if{html}{\out{
    }} +\if{html}{\out{
    }}\preformatted{ParamSet$aggr_internal_tuned_values(x)}\if{html}{\out{
    }} } \subsection{Arguments}{ diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd index a1999df6..01c10e7f 100644 --- a/man/ParamSetCollection.Rd +++ b/man/ParamSetCollection.Rd @@ -73,6 +73,8 @@ This field provides direct references to the \code{\link{ParamSet}} objects.} \item \href{#method-ParamSetCollection-new}{\code{ParamSetCollection$new()}} \item \href{#method-ParamSetCollection-add}{\code{ParamSetCollection$add()}} \item \href{#method-ParamSetCollection-disable_internal_tuning}{\code{ParamSetCollection$disable_internal_tuning()}} +\item \href{#method-ParamSetCollection-convert_internal_search_space}{\code{ParamSetCollection$convert_internal_search_space()}} +\item \href{#method-ParamSetCollection-flatten}{\code{ParamSetCollection$flatten()}} \item \href{#method-ParamSetCollection-clone}{\code{ParamSetCollection$clone()}} } } @@ -80,14 +82,12 @@ This field provides direct references to the \code{\link{ParamSet}} objects.}
    Inherited methods
    • paradox::ParamSet$add_dep()
    • -
    • paradox::ParamSet$aggr()
    • +
    • paradox::ParamSet$aggr_internal_tuned_values()
    • paradox::ParamSet$assert()
    • paradox::ParamSet$assert_dt()
    • paradox::ParamSet$check()
    • paradox::ParamSet$check_dependencies()
    • paradox::ParamSet$check_dt()
    • -
    • paradox::ParamSet$convert_internal_search_space()
    • -
    • paradox::ParamSet$flatten()
    • paradox::ParamSet$format()
    • paradox::ParamSet$get_domain()
    • paradox::ParamSet$get_values()
    • @@ -177,6 +177,38 @@ The ids of the parameters for which to disable internal tuning.} \subsection{Returns}{ \code{Self} } +} +\if{html}{\out{
      }} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSetCollection-convert_internal_search_space}{}}} +\subsection{Method \code{convert_internal_search_space()}}{ +Convert all parameters from the search space to parameter values using the transformation given by +\code{in_tune_fn}. +\subsection{Usage}{ +\if{html}{\out{
      }}\preformatted{ParamSetCollection$convert_internal_search_space(search_space)}\if{html}{\out{
      }} +} + +\subsection{Arguments}{ +\if{html}{\out{
      }} +\describe{ +\item{\code{search_space}}{(\code{\link{ParamSet}})\cr +The internal search space.} +} +\if{html}{\out{
      }} +} +\subsection{Returns}{ +(named \code{list()}) +} +} +\if{html}{\out{
      }} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSetCollection-flatten}{}}} +\subsection{Method \code{flatten()}}{ +Create a \code{ParamSet} from this \code{ParamSetCollection}. +\subsection{Usage}{ +\if{html}{\out{
      }}\preformatted{ParamSetCollection$flatten()}\if{html}{\out{
      }} +} + } \if{html}{\out{
      }} \if{html}{\out{}} From 949bbe84ca8dac27317292f8df54a1193f2ffbe7 Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 10 Jun 2024 19:52:38 +0200 Subject: [PATCH 32/34] NEWS.md --- NEWS.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/NEWS.md b/NEWS.md index d2e26c6f..6ce734e4 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,8 +1,5 @@ -# dev +# paradox 1.0.0 -* feat: added support for `InternalTuneToken`s - -# paradox 0.12.0 * Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes: * `ParamSet` now supports `extra_trafo` natively; it behaves like `.extra_trafo` of the `ps()` call. * `ParamSet` has `$constraint` @@ -11,6 +8,7 @@ * `Condition` objects are now S3 objects and can be constructed with `CondEqual()` and `CondAnyOf()`, instead of `CondXyz$new()`. (It is recommended to use the `Domain` interface for conditions, which has not changed) * `ParamSet` has new fields `$is_logscale`, `$has_trafo_param` (per-param), and `$has_trafo_param` (scalar for the whole set). * Added a vignette which was previously a chapter in the `mlr3book` +* feat: added support for `InternalTuneToken`s # paradox 0.11.1 From 18e40888455f06e2a7b02f617ace46c39d4a9423 Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 10 Jun 2024 20:04:27 +0200 Subject: [PATCH 33/34] keep diff small --- R/Design.R | 1 - R/Domain.R | 2 ++ R/ParamSet.R | 57 ++++++++++++++++++++-------------------------------- 3 files changed, 24 insertions(+), 36 deletions(-) diff --git a/R/Design.R b/R/Design.R index 3ca3b465..41a6b425 100644 --- a/R/Design.R +++ b/R/Design.R @@ -34,7 +34,6 @@ Design = R6Class("Design", # set fixed param vals to their constant values # FIXME: this might also be problematic for LHS # do we still create an LHS like this? - imap(param_set$values, function(v, n) set(data, j = n, value = v)) self$data = data if (param_set$has_deps) { diff --git a/R/Domain.R b/R/Domain.R index 8a6ddd31..b7dee1e8 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -197,6 +197,7 @@ Domain = function(cls, grouping, assert_character(tags, any.missing = FALSE, unique = TRUE) assert_function(trafo, null.ok = TRUE) + # depends may be an expression, but may also be quote() or expression() if (length(depends_expr) == 1) { depends_expr = eval(depends_expr, envir = parent.frame(2)) @@ -217,6 +218,7 @@ Domain = function(cls, grouping, .tags = list(tags), .trafo = list(trafo), .requirements = list(parse_depends(depends_expr, parent.frame(2))), + .init_given = !missing(init), .init = list(if (!missing(init)) init) ) diff --git a/R/ParamSet.R b/R/ParamSet.R index 9d6bb8fc..c7003361 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -92,7 +92,7 @@ ParamSet = R6Class("ParamSet", if (".requirements" %in% names(paramtbl)) { requirements = paramtbl$.requirements - private$.params = paramtbl # self$add_dep needs this + private$.params = paramtbl # self$add_dep needs this for (row in seq_len(nrow(paramtbl))) { for (req in requirements[[row]]) { invoke(self$add_dep, id = paramtbl$id[[row]], allow_dangling_dependencies = allow_dangling_dependencies, @@ -107,7 +107,7 @@ ParamSet = R6Class("ParamSet", setindexv(paramtbl, c("id", "cls", "grouping")) - private$.params = paramtbl # I am 99% sure this is not necessary, but maybe set() creates a copy when deleting too many cols? + private$.params = paramtbl # I am 99% sure this is not necessary, but maybe set() creates a copy when deleting too many cols? if (!is.null(initvalues)) self$values = initvalues }, @@ -379,9 +379,7 @@ ParamSet = R6Class("ParamSet", private$get_tune_ps(xs) TRUE }, error = function(e) paste("tune token invalid:", conditionMessage(e))) - if (!isTRUE(tunecheck)) { - return(tunecheck) - } + if (!isTRUE(tunecheck)) return(tunecheck) } xs_internaltune = keep(xs, is, "InternalTuneToken") @@ -413,9 +411,7 @@ ParamSet = R6Class("ParamSet", ## if (length(required) > 0L) { ## return(sprintf("Missing required parameters: %s", str_collapse(required))) ## } - if (!self$test_constraint(xs, assert_value = FALSE)) { - return(sprintf("Constraint not fulfilled.")) - } + if (!self$test_constraint(xs, assert_value = FALSE)) return(sprintf("Constraint not fulfilled.")) return(self$check_dependencies(xs)) } @@ -430,37 +426,29 @@ ParamSet = R6Class("ParamSet", #' @return If successful `TRUE`, if not a string with an error message. check_dependencies = function(xs) { deps = self$deps - if (!nrow(deps)) { - return(TRUE) - } + if (!nrow(deps)) return(TRUE) params = private$.params ns = names(xs) errors = pmap(deps[id %in% ns], function(id, on, cond) { onval = xs[[on]] - if (inherits(xs[[id]], "TuneToken") || inherits(onval, "TuneToken")) { - return(NULL) - } + if (inherits(xs[[id]], "TuneToken") || inherits(onval, "TuneToken")) return(NULL) # we are ONLY ok if: # - if 'id' is there, then 'on' must be there, and cond must be true # - if 'id' is not there. but that is skipped (deps[id %in% ns] filter) - if (on %in% ns && condition_test(cond, onval)) { - return(NULL) - } + if (on %in% ns && condition_test(cond, onval)) return(NULL) msg = sprintf("%s: can only be set if the following condition is met '%s'.", id, condition_as_string(cond, on)) if (is.null(onval)) { msg = sprintf(paste("%s Instead the parameter value for '%s' is not set at all.", - "Try setting '%s' to a value that satisfies the condition"), msg, on, on) + "Try setting '%s' to a value that satisfies the condition"), msg, on, on) } else { msg = sprintf("%s Instead the current parameter value is: %s == %s", msg, on, as_short_string(onval)) } msg }) errors = unlist(errors) - if (!length(errors)) { - return(TRUE) - } + if (!length(errors)) return(TRUE) str_collapse(errors, sep = "\n") }, @@ -491,7 +479,7 @@ ParamSet = R6Class("ParamSet", #' Name of the checked object to print in error messages.\cr #' Defaults to the heuristic implemented in [vname][checkmate::vname]. #' @return If successful `xs` invisibly, if not an error message. - assert = function(xs, check_strict = TRUE, .var.name = vname(xs)) makeAssertion(xs, self$check(xs, check_strict = check_strict), .var.name, NULL), # nolint + assert = function(xs, check_strict = TRUE, .var.name = vname(xs)) makeAssertion(xs, self$check(xs, check_strict = check_strict), .var.name, NULL), # nolint #' @description #' \pkg{checkmate}-like check-function. Takes a [data.table::data.table] @@ -579,9 +567,10 @@ ParamSet = R6Class("ParamSet", paramrow[, `:=`( .tags = list(private$.tags[id, tag, nomatch = 0]), .trafo = private$.trafos[id, trafo], - .requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps + .requirements = list(if (nrow(depstbl)) transpose_list(depstbl)), # NULL if no deps .init_given = id %in% names(vals), - .init = unname(vals[id]))] + .init = unname(vals[id])) + ] set_class(paramrow, c(paramrow$cls, "Domain", class(paramrow))) }, @@ -606,7 +595,7 @@ ParamSet = R6Class("ParamSet", pids_not_there = setdiff(parents, ids) if (length(pids_not_there) > 0L) { stopf(paste0("Subsetting so that dependencies on params exist which would be gone: %s.", - "\nIf you still want to subset, set allow_dangling_dependencies to TRUE."), str_collapse(pids_not_there)) + "\nIf you still want to subset, set allow_dangling_dependencies to TRUE."), str_collapse(pids_not_there)) } } result = ParamSet$new() @@ -662,7 +651,7 @@ ParamSet = R6Class("ParamSet", assert_list(values) assert_names(names(values), subset.of = self$ids()) pars = private$get_tune_ps(values) - on = NULL # pacify static code check + on = NULL # pacify static code check dangling_deps = pars$deps[!pars$ids(), on = "on"] if (nrow(dangling_deps)) { stopf("Dangling dependencies not allowed: Dependencies on %s dangling.", str_collapse(dangling_deps$on)) @@ -687,7 +676,7 @@ ParamSet = R6Class("ParamSet", stopf("A param cannot depend on itself!") } - if (on %in% ids) { # not necessarily true when allow_dangling_dependencies + if (on %in% ids) { # not necessarily true when allow_dangling_dependencies feasible_on_values = map_lgl(cond$rhs, function(x) domain_test(self$get_domain(on), list(x))) if (any(!feasible_on_values)) { stopf("Condition has infeasible values for %s: %s", on, str_collapse(cond$rhs[!feasible_on_values])) @@ -847,7 +836,7 @@ ParamSet = R6Class("ParamSet", assert_character(v$on, any.missing = FALSE) assert_list(v$cond, types = "Condition", any.missing = FALSE) } else { - v = data.table(id = character(0), on = character(0), cond = list()) # make sure we have the right columns + v = data.table(id = character(0), on = character(0), cond = list()) # make sure we have the right columns } private$.deps = v } @@ -954,9 +943,7 @@ ParamSet = R6Class("ParamSet", get_tune_ps = function(values) { values = keep(values, inherits, "TuneToken") - if (!length(values)) { - return(ParamSet$new()) - } + if (!length(values)) return(ParamSet$new()) params = map(names(values), function(pn) { domain = private$.params[pn, on = "id"] set_class(domain, c(domain$cls, "Domain", class(domain))) @@ -965,14 +952,13 @@ ParamSet = R6Class("ParamSet", # package-internal S3 fails if we don't call the function indirectly here partsets = pmap(list(values, params), function(...) tunetoken_to_ps(...)) - - pars = ps_union(partsets) # partsets does not have names here, wihch is what we want. + pars = ps_union(partsets) # partsets does not have names here, wihch is what we want. names(partsets) = names(values) idmapping = map(partsets, function(x) x$ids()) # only add the dependencies that are also in the tuning PS - on = id = NULL # pacify static code check + on = id = NULL # pacify static code check pmap(self$deps[id %in% names(idmapping) & on %in% names(partsets), c("on", "id", "cond")], function(on, id, cond) { onpar = partsets[[on]] if (onpar$has_trafo || !identical(onpar$ids(), on)) { @@ -1042,7 +1028,7 @@ rd_info.ParamSet = function(obj, descriptions = character(), ...) { # nolint is_default = map_lgl(params$default, inherits, "NoDefault") is_uty = params$storage_type == "list" set(params, i = which(is_uty & !is_default), j = "default", - value = map(cargo[!is_default & is_uty], function(x) x$repr)) + value = map(cargo[!is_default & is_uty], function(x) x$repr)) set(params, i = which(is_uty), j = "storage_type", value = list("untyped")) set(params, i = which(is_default), j = "default", value = list("-")) @@ -1060,3 +1046,4 @@ rd_info.ParamSet = function(obj, descriptions = character(), ...) { # nolint x = c("", knitr::kable(params, col.names = capitalize(names(params)))) paste(x, collapse = "\n") } + From a5e7ef8fdb901575d88458335efd1123972712c9 Mon Sep 17 00:00:00 2001 From: mb706 Date: Mon, 10 Jun 2024 20:05:59 +0200 Subject: [PATCH 34/34] keep diff small II --- R/ParamSet.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/ParamSet.R b/R/ParamSet.R index c7003361..a8ef87f0 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -744,7 +744,7 @@ ParamSet = R6Class("ParamSet", } if (length(xs) == 0L) { xs = named_list() - } else if (self$assert_values) { # this only makes sense when we have asserts on + } else if (self$assert_values) { # this only makes sense when we have asserts on # convert all integer params really to storage type int, move doubles to within bounds etc. # solves issue #293, #317 nontt = discard(xs, inherits, "TuneToken")