feat: support internal tuning #399

Merged on Jun 10, 2024 (36 commits)
Commits
441d754
feat: aggr function for inner tuning
sebffischer Apr 16, 2024
8fc9f5f
add aggr function to ParamSet, tests
sebffischer Apr 16, 2024
cc5c828
...
sebffischer Apr 16, 2024
874cd7f
fix bug
sebffischer Apr 16, 2024
d83abfb
fix bug
sebffischer Apr 16, 2024
5fd31c2
...
sebffischer Apr 17, 2024
b6fee52
aggr is now part of cargo
sebffischer Apr 22, 2024
4fa12b1
cleanup
sebffischer Apr 22, 2024
8f7cd36
fix example
sebffischer Apr 22, 2024
f9c9ff1
fix bug in objecttunetoken
sebffischer Apr 22, 2024
7c8684a
cleanup
sebffischer Apr 22, 2024
431ba36
more tests
sebffischer Apr 23, 2024
beaf845
...
sebffischer Apr 26, 2024
f5dbe0d
add in_tune_fn
sebffischer May 4, 2024
41e7d81
cleanup
sebffischer May 4, 2024
2a43aab
fix: add tags to domains created for inner tuning
sebffischer May 6, 2024
94f0f18
rename inner to internal
sebffischer May 6, 2024
13dd154
fix bug
sebffischer May 6, 2024
abb8856
Merge branch 'main' into feat/inner_valid
sebffischer May 29, 2024
5fdb2a4
support disabling internal tuning
sebffischer May 29, 2024
776b318
docs, edge cases
sebffischer May 29, 2024
773d692
bugfix
sebffischer May 30, 2024
1b391f7
some more changes
sebffischer May 30, 2024
31784b4
wip [skip ci]
sebffischer May 31, 2024
8482335
more wip
sebffischer May 31, 2024
ebb46f4
Merge branch 'main' into feat/inner_valid
sebffischer May 31, 2024
c65cba4
hopefully fix final bug
sebffischer Jun 1, 2024
715a322
uncomment tests
sebffischer Jun 1, 2024
134d4fb
remove dead comments
sebffischer Jun 1, 2024
b75465a
fix disable in tune for nested psc
sebffischer Jun 1, 2024
fd6270a
make tests less confusing
sebffischer Jun 1, 2024
6fd11c0
fix one more bug
sebffischer Jun 1, 2024
f7a6754
better example
sebffischer Jun 1, 2024
949bbe8
NEWS.md
mb706 Jun 10, 2024
18e4088
keep diff small
mb706 Jun 10, 2024
a5e7ef8
keep diff small II
mb706 Jun 10, 2024
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -61,7 +61,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
VignetteBuilder: knitr
Collate:
'Condition.R'
1 change: 1 addition & 0 deletions NAMESPACE
@@ -49,6 +49,7 @@ S3method(format,Condition)
S3method(print,Condition)
S3method(print,Domain)
S3method(print,FullTuneToken)
S3method(print,InternalTuneToken)
S3method(print,ObjectTuneToken)
S3method(print,RangeTuneToken)
S3method(rd_info,ParamSet)
4 changes: 3 additions & 1 deletion NEWS.md
@@ -1,4 +1,5 @@
# paradox 0.12.0
# paradox 1.0.0

* Removed `Param` objects. `ParamSet` now uses a `data.table` internally; individual parameters are more like `Domain` objects now. `ParamSets` should be constructed using the `ps()` shorthand and `Domain` objects. This entails the following major changes:
* `ParamSet` now supports `extra_trafo` natively; it behaves like `.extra_trafo` of the `ps()` call.
* `ParamSet` has `$constraint`
@@ -7,6 +8,7 @@
* `Condition` objects are now S3 objects and can be constructed with `CondEqual()` and `CondAnyOf()`, instead of `CondXyz$new()`. (It is recommended to use the `Domain` interface for conditions, which has not changed)
* `ParamSet` has new fields `$is_logscale`, `$has_trafo_param` (per-param), and `$has_trafo` (scalar for the whole set).
* Added a vignette which was previously a chapter in the `mlr3book`
* feat: added support for `InternalTuneToken`s

# paradox 0.11.1

51 changes: 50 additions & 1 deletion R/Domain.R
@@ -63,6 +63,17 @@
#' @param init (`any`)\cr
#' Initial value. When this is given, then the corresponding entry in `ParamSet$values` is initialized with this
#' value upon construction.
#' @param aggr (`function`)\cr
#' Default aggregation function for a parameter. Can only be given for parameters tagged with `"internal_tuning"`.
#' Function with one argument, which is a list of parameter values and that returns the aggregated parameter value.
#' @param in_tune_fn (`function(domain, param_vals)`)\cr
#' Function that converts a `Domain` object into a parameter value.
#' Can only be given for parameters tagged with `"internal_tuning"`.
#' This function should also assert that the parameters required to enable internal tuning for the given `domain` are
#' set in `param_vals` (such as `early_stopping_rounds` for `XGBoost`).
#' @param disable_in_tune (named `list()`)\cr
#' The parameter values that need to be set in the `ParamSet` to disable the internal tuning for the parameter.
#' For `XGBoost` this would e.g. be `list(early_stopping_rounds = NULL)`.
#'
#' @return A `Domain` object.
#'
@@ -117,6 +128,30 @@
#' # ... but get transformed to integers.
#' print(grid$transpose())
#'
#'
#' # internal tuning
#'
#' param_set = ps(
#' iters = p_int(0, Inf, tags = "internal_tuning", aggr = function(x) round(mean(unlist(x))),
#' in_tune_fn = function(domain, param_vals) {
#' stopifnot(domain$lower <= 1)
#' stopifnot(param_vals$early_stopping == TRUE)
#' domain$upper
#' },
#' disable_in_tune = list(early_stopping = FALSE)),
#' early_stopping = p_lgl()
#' )
#' param_set$set_values(
#' iters = to_tune(upper = 100, internal = TRUE),
#' early_stopping = TRUE
#' )
#' param_set$convert_internal_search_space(param_set$search_space())
#' param_set$aggr_internal_tuned_values(
#' list(iters = list(1, 2, 3))
#' )
#'
#' param_set$disable_internal_tuning("iters")
#' param_set$values$early_stopping
#' @family ParamSet construction helpers
#' @name Domain
NULL
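The `in_tune_fn` in the example above can be exercised on its own. The following is a minimal base R sketch, assuming a plain list with `lower` and `upper` entries as a stand-in for the `Domain` object that `convert_internal_search_space()` would pass in; the stand-in values are illustrative and not taken from the package.

```r
# Converter copied from the roxygen example above.
in_tune_fn = function(domain, param_vals) {
  stopifnot(domain$lower <= 1)
  stopifnot(param_vals$early_stopping == TRUE)
  domain$upper
}

# Assumption: a plain list stands in for the Domain of the search space.
domain = list(lower = 0, upper = 100)

# With early stopping enabled, the upper bound becomes the parameter value.
converted = in_tune_fn(domain, list(early_stopping = TRUE))

# With early stopping disabled, the converter refuses to proceed.
failed = inherits(
  try(in_tune_fn(domain, list(early_stopping = FALSE)), silent = TRUE),
  "try-error"
)
```

The converter doubles as a guard: it both maps the search-space bound to a value and asserts that the settings internal tuning depends on are present.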
@@ -136,6 +171,21 @@ Domain = function(cls, grouping,
storage_type = "list",
init) {

if ("internal_tuning" %in% tags) {
assert_true(!is.null(cargo$aggr), .var.name = "aggregation function exists")
}
assert_list(cargo$disable_in_tune, null.ok = TRUE, names = "unique")
assert_function(cargo$aggr, null.ok = TRUE)
assert_function(cargo$in_tune_fn, null.ok = TRUE)
if ((!is.null(cargo$in_tune_fn) || !is.null(cargo$disable_in_tune)) && "internal_tuning" %nin% tags) {
# we cannot check the reverse, as parameters in the search space can be tagged with 'internal_tuning'
# and not provide in_tune_fn or disable_in_tune
stopf("Arguments in_tune_fn and disable_in_tune require the tag 'internal_tuning' to be present.")
}
if ((is.null(cargo$in_tune_fn) + is.null(cargo$disable_in_tune)) == 1) {
stopf("Arguments in_tune_fn and disable_in_tune must both be present")
}

assert_string(cls)
assert_string(grouping)
assert_number(lower, na.ok = TRUE)
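The argument checks added to `Domain()` in this hunk amount to three rules. The sketch below restates them in base R as a stand-alone function; `check_internal_tuning_args` is a hypothetical name used only for illustration, and the type assertions from `checkmate` are left out.

```r
# Hypothetical restatement of the rules; not the package code.
check_internal_tuning_args = function(tags, aggr = NULL, in_tune_fn = NULL,
                                      disable_in_tune = NULL) {
  tagged = "internal_tuning" %in% tags
  # Rule 1: a tagged parameter needs an aggregation function.
  if (tagged && is.null(aggr)) {
    stop("Parameters tagged 'internal_tuning' need an aggregation function.")
  }
  # Rule 2: in_tune_fn / disable_in_tune are only allowed with the tag.
  if ((!is.null(in_tune_fn) || !is.null(disable_in_tune)) && !tagged) {
    stop("in_tune_fn and disable_in_tune require the tag 'internal_tuning'.")
  }
  # Rule 3: in_tune_fn and disable_in_tune come as a pair.
  if (is.null(in_tune_fn) != is.null(disable_in_tune)) {
    stop("in_tune_fn and disable_in_tune must be given together.")
  }
  invisible(TRUE)
}

ok = check_internal_tuning_args("internal_tuning", aggr = mean)
missing_aggr = inherits(
  try(check_internal_tuning_args("internal_tuning"), silent = TRUE), "try-error")
missing_tag = inherits(
  try(check_internal_tuning_args(character(), in_tune_fn = identity,
    disable_in_tune = list()), silent = TRUE), "try-error")
unpaired = inherits(
  try(check_internal_tuning_args("internal_tuning", aggr = mean,
    in_tune_fn = identity), silent = TRUE), "try-error")
```

The pairing test `is.null(a) != is.null(b)` is equivalent to the diff's `(is.null(a) + is.null(b)) == 1`.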
@@ -227,7 +277,6 @@ print.Domain = function(x, ...) {
if (!is.null(repr)) {
print(repr)
} else {
plural_rows =
classes = class(x)
if ("Domain" %in% classes) {
domainidx = which("Domain" == classes)[[1]]
9 changes: 7 additions & 2 deletions R/ParamDbl.R
@@ -1,6 +1,6 @@
#' @rdname Domain
#' @export
p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init) {
p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) {
assert_number(tolerance, lower = 0)
assert_number(lower)
assert_number(upper)
@@ -17,8 +17,13 @@ p_dbl = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_
real_upper = upper
}

cargo = list()
if (logscale) cargo$logscale = TRUE
cargo$aggr = aggr
cargo$in_tune_fn = in_tune_fn
cargo$disable_in_tune = disable_in_tune
Domain(cls = "ParamDbl", grouping = "ParamDbl", lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo, storage_type = "numeric",
depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale")
depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo)
}

#' @export
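The `cargo` construction used in the `p_*()` helpers relies on a base R detail: assigning `NULL` to a list element does not create that element. A minimal sketch of the pattern follows; `build_cargo` is a hypothetical helper name, not part of the package.

```r
# Hypothetical helper mirroring the cargo-building lines in the diff.
build_cargo = function(logscale = FALSE, aggr = NULL, in_tune_fn = NULL,
                       disable_in_tune = NULL) {
  cargo = list()
  if (logscale) cargo$logscale = TRUE
  # Assigning NULL leaves the list unchanged, so unset arguments add nothing.
  cargo$aggr = aggr
  cargo$in_tune_fn = in_tune_fn
  cargo$disable_in_tune = disable_in_tune
  # An `if` without `else` yields NULL when the condition is false.
  if (length(cargo)) cargo
}

plain = build_cargo()                   # NULL: nothing was set
with_flag = build_cargo(logscale = TRUE)
with_aggr = build_cargo(aggr = function(x) mean(unlist(x)))
```

Because of this, a parameter without any of the new arguments still gets `cargo = NULL`, while `logscale` moves from a bare string to a named list entry.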
11 changes: 9 additions & 2 deletions R/ParamFct.R
@@ -1,6 +1,7 @@
#' @rdname Domain
#' @export
p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init) {
p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) {
assert_function(aggr, null.ok = TRUE, nargs = 1L)
constargs = as.list(match.call()[-1])
levels = eval.parent(constargs$levels)
if (!is.character(levels)) {
@@ -21,8 +22,14 @@ p_fct = function(levels, special_vals = list(), default = NO_DEF, tags = charact
}
# group p_fct by levels, so the group can be checked in a vectorized fashion.
# We escape '"' and '\' to '\"' and '\\', respectively.
cargo = list()
cargo$disable_in_tune = disable_in_tune
cargo$aggr = aggr
cargo$in_tune_fn = in_tune_fn
grouping = str_collapse(gsub("([\\\\\"])", "\\\\\\1", sort(real_levels)), quote = '"', sep = ",")
Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals, default = default, tags = tags, trafo = trafo, storage_type = "character", depends_expr = substitute(depends), init = init)
Domain(cls = "ParamFct", grouping = grouping, levels = real_levels, special_vals = special_vals,
default = default, tags = tags, trafo = trafo, storage_type = "character",
depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo)
}

#' @export
10 changes: 8 additions & 2 deletions R/ParamInt.R
@@ -1,7 +1,7 @@

#' @rdname Domain
#' @export
p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init) {
p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps), depends = NULL, trafo = NULL, logscale = FALSE, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) {
assert_number(tolerance, lower = 0, upper = 0.5)
# assert_int will stop for `Inf` values, which we explicitly allow as lower / upper bound
if (!isTRUE(is.infinite(lower))) assert_int(lower, tol = 1e-300) else assert_number(lower)
@@ -23,9 +23,15 @@ p_int = function(lower = -Inf, upper = Inf, special_vals = list(), default = NO_
real_upper = upper
}

cargo = list()
if (logscale) cargo$logscale = TRUE
cargo$aggr = aggr
cargo$in_tune_fn = in_tune_fn
cargo$disable_in_tune = disable_in_tune

Domain(cls = cls, grouping = cls, lower = real_lower, upper = real_upper, special_vals = special_vals, default = default, tags = tags, tolerance = tolerance, trafo = trafo,
storage_type = storage_type,
depends_expr = substitute(depends), init = init, cargo = if (logscale) "logscale")
depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo)
}

#' @export
8 changes: 6 additions & 2 deletions R/ParamLgl.R
@@ -1,8 +1,12 @@
#' @rdname Domain
#' @export
p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init) {
p_lgl = function(special_vals = list(), default = NO_DEF, tags = character(), depends = NULL, trafo = NULL, init, aggr = NULL, in_tune_fn = NULL, disable_in_tune = NULL) {
cargo = list()
cargo$aggr = aggr
cargo$in_tune_fn = in_tune_fn
cargo$disable_in_tune = disable_in_tune
Domain(cls = "ParamLgl", grouping = "ParamLgl", levels = c(TRUE, FALSE), special_vals = special_vals, default = default,
tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init)
tags = tags, trafo = trafo, storage_type = "logical", depends_expr = substitute(depends), init = init, cargo = if (length(cargo)) cargo)
}

#' @export
75 changes: 71 additions & 4 deletions R/ParamSet.R
@@ -93,7 +93,7 @@ ParamSet = R6Class("ParamSet",
if (".requirements" %in% names(paramtbl)) {
requirements = paramtbl$.requirements
private$.params = paramtbl # self$add_dep needs this
for (row in seq_len(nrow(paramtbl))) {
for (row in seq_len(nrow(paramtbl))) {
for (req in requirements[[row]]) {
invoke(self$add_dep, id = paramtbl$id[[row]], allow_dangling_dependencies = allow_dangling_dependencies,
.args = req)
@@ -154,15 +154,16 @@ ParamSet = R6Class("ParamSet",
#' @param any_tags (`character()`). See `$ids()`.
#' @param type (`character(1)`)\cr
#' Return values `"with_token"` (i.e. all values),
# `"without_token"` (all values that are not [`TuneToken`] objects) or `"only_token"` (only [`TuneToken`] objects)?
#' `"without_token"` (all values that are not [`TuneToken`] objects), `"only_token"` (only [`TuneToken`] objects)
#' or `"with_internal"` (only `InternalTuneToken` objects)?
#' @param check_required (`logical(1)`)\cr
#' Check if all required parameters are set?
#' @param remove_dependencies (`logical(1)`)\cr
#' If `TRUE`, set values with dependencies that are not fulfilled to `NULL`.
#' @return Named `list()`.
get_values = function(class = NULL, tags = NULL, any_tags = NULL,
type = "with_token", check_required = TRUE, remove_dependencies = TRUE) {
assert_choice(type, c("with_token", "without_token", "only_token"))
assert_choice(type, c("with_token", "without_token", "only_token", "with_internal"))

assert_flag(check_required)

@@ -173,6 +174,8 @@ ParamSet = R6Class("ParamSet",
values = discard(values, is, "TuneToken")
} else if (type == "only_token") {
values = keep(values, is, "TuneToken")
} else if (type == "with_internal") {
values = keep(values, is, "InternalTuneToken")
}

if (check_required) {
@@ -255,6 +258,62 @@ ParamSet = R6Class("ParamSet",
x
},

#' @description
#'
#' Aggregate parameter values according to their aggregation rules.
#'
#' @param x (named `list()` of `list()`s)\cr
#' The value(s) to be aggregated. Names are parameter ids.
#' The aggregation function is selected based on the parameter.
#'
#' @return (named `list()`)
aggr_internal_tuned_values = function(x) {
assert_list(x, types = "list")
aggrs = private$.params[map_lgl(get("cargo"), function(cargo) is.function(cargo$aggr)), list(id = get("id"), aggr = map(get("cargo"), "aggr"))]
assert_subset(names(x), aggrs$id)
if (!length(x)) {
return(named_list())
}
imap(x, function(value, .id) {
if (!length(value)) {
stopf("Trying to aggregate values of parameter '%s', but there are no values", .id)
}
aggr = aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value)
})
},

#' @description
#'
#' Set the parameter values so that internal tuning for the selected parameters is disabled.
#'
#' @param ids (`character()`)\cr
#' The ids of the parameters for which to disable internal tuning.
#' @return `Self`
disable_internal_tuning = function(ids) {
assert_subset(ids, self$ids(tags = "internal_tuning"))
pvs = Reduce(c, map(private$.params[ids, "cargo", on = "id"][[1]], "disable_in_tune")) %??% named_list()
self$set_values(.values = pvs)
},

#' @description
#' Convert all parameters from the search space to parameter values using the transformation given by
#' `in_tune_fn`.
#' @param search_space ([`ParamSet`])\cr
#' The internal search space.
#' @return (named `list()`)
convert_internal_search_space = function(search_space) {
assert_class(search_space, "ParamSet")
param_vals = self$values

imap(search_space$domains, function(token, .id) {
converter = private$.params[list(.id), "cargo", on = "id"][[1L]][[1L]]$in_tune_fn
if (!is.function(converter)) {
stopf("No converter exists for parameter '%s'", .id)
}
converter(token, param_vals)
})
},

#' @description
#' \pkg{checkmate}-like test-function. Takes a named list.
#' Return `FALSE` if the given `$constraint` is not satisfied, `TRUE` otherwise.
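The logic of `aggr_internal_tuned_values()` and `disable_internal_tuning()` in this hunk can be traced without the package. The sketch below is a base R approximation, assuming a named list `cargo` in place of the `cargo` column of the internal `data.table`; `aggregate_values` and `disable_values` are hypothetical names, and `lapply` stands in for the `mlr3misc` helpers.

```r
# Assumption: one entry per parameter id, holding that parameter's cargo.
cargo = list(
  iters = list(
    aggr = function(x) round(mean(unlist(x))),
    disable_in_tune = list(early_stopping = FALSE)
  )
)

# Apply each parameter's own aggregation function to its list of values.
aggregate_values = function(x, cargo) {
  stopifnot(all(names(x) %in% names(cargo)))
  out = lapply(names(x), function(id) {
    if (!length(x[[id]])) {
      stop(sprintf("No values to aggregate for parameter '%s'", id))
    }
    cargo[[id]]$aggr(x[[id]])
  })
  stats::setNames(out, names(x))
}

# Concatenate the disable_in_tune lists of the selected parameters.
disable_values = function(ids, cargo) {
  pvs = Reduce(c, lapply(cargo[ids], function(cg) cg$disable_in_tune))
  # Reduce() returns NULL for empty input, hence the fallback to list().
  if (is.null(pvs)) list() else pvs
}

aggregated = aggregate_values(list(iters = list(1, 2, 3)), cargo)
disabled = disable_values("iters", cargo)
nothing = disable_values(character(), cargo)
```

The fallback in `disable_values` plays the role of `%??% named_list()` in the diff: without it, an empty `ids` would pass `NULL` on to `set_values()`.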
@@ -323,6 +382,14 @@ ParamSet = R6Class("ParamSet",
if (!isTRUE(tunecheck)) return(tunecheck)
}

xs_internaltune = keep(xs, is, "InternalTuneToken")
walk(names(xs_internaltune), function(pid) {
if ("internal_tuning" %nin% self$tags[[pid]]) {
stopf("Trying to assign InternalTuneToken to parameter '%s' which is not tagged with 'internal_tuning'.", pid)
}
})


# check each parameter group's feasibility
xs_nontune = discard(xs, inherits, "TuneToken")

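The new feasibility check rejects an `InternalTuneToken` on any parameter that lacks the `"internal_tuning"` tag. Below is a minimal base R sketch of that check. The token objects are stand-ins built with `structure()`, the class vectors are an assumption of this sketch, and `check_internal_tokens` is a hypothetical name.

```r
# Stand-in token constructor; the real tokens come from to_tune().
new_token = function(cls) structure(list(), class = c(cls, "TuneToken"))

xs = list(
  iters = new_token("InternalTuneToken"),
  eta = new_token("RangeTuneToken"),
  max_depth = 3L
)
tags = list(iters = "internal_tuning", eta = character(), max_depth = character())

check_internal_tokens = function(xs, tags) {
  # Keep only the values that are internal tune tokens.
  internal_ids = names(Filter(function(v) inherits(v, "InternalTuneToken"), xs))
  for (pid in internal_ids) {
    if (!("internal_tuning" %in% tags[[pid]])) {
      stop(sprintf(
        "InternalTuneToken assigned to parameter '%s', which is not tagged with 'internal_tuning'.",
        pid))
    }
  }
  internal_ids
}

accepted = check_internal_tokens(xs, tags)

# An internal token on an untagged parameter is rejected.
xs_bad = list(eta = new_token("InternalTuneToken"))
rejected = inherits(try(check_internal_tokens(xs_bad, tags), silent = TRUE), "try-error")
```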
@@ -822,7 +889,7 @@ ParamSet = R6Class("ParamSet",
#' Note that this only refers to the `logscale` flag set during construction, e.g. `p_dbl(logscale = TRUE)`.
#' If the parameter was set to logscale manually, e.g. through `p_dbl(trafo = exp)`,
#' this `is_logscale` will be `FALSE`.
is_logscale = function() with(private$.params, set_names(cls %in% c("ParamDbl", "ParamInt") & cargo == "logscale", id)),
is_logscale = function() with(private$.params, set_names(cls %in% c("ParamDbl", "ParamInt") & map_lgl(cargo, function(x) isTRUE(x$logscale)), id)),

############################
# Per-Parameter class properties (S3 method call)
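The `is_logscale` change follows from `cargo` now being a list per parameter rather than a bare string. The sketch below shows the new lookup in base R, with `vapply` standing in for `map_lgl` and made-up `cls` and `cargo` vectors in place of the internal table columns.

```r
# Assumption: illustrative columns, one entry per parameter.
cls = c("ParamDbl", "ParamFct", "ParamInt")
cargo = list(
  list(logscale = TRUE, aggr = mean),  # logscale parameter with an aggregator
  NULL,                                # no cargo at all
  list(aggr = mean)                    # cargo without the logscale flag
)

# isTRUE() treats both a NULL cargo and a missing logscale entry as FALSE.
is_logscale = cls %in% c("ParamDbl", "ParamInt") &
  vapply(cargo, function(x) isTRUE(x$logscale), logical(1))
```

Reading the flag through `isTRUE(x$logscale)` keeps the result a plain logical vector no matter which other entries share the `cargo` list.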