Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

detached extra_trafo #411

Merged
merged 6 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# paradox 1.0.1-9000

* `ParamSetCollection$flatten()` now detaches `$extra_trafo` completely from original ParamSetCollection.

# paradox 1.0.1

* Performance improvements.
Expand Down
1 change: 0 additions & 1 deletion R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,6 @@ ParamSet = R6Class("ParamSet",
result$assert_values = FALSE
result$deps = deps[ids, on = "id", nomatch = NULL]
if (keep_constraint) result$constraint = self$constraint
# TODO: ParamSetCollection trafo currently drags along the entire original paramset in its environment
result$extra_trafo = self$extra_trafo
# restrict to ids already in pvals
values = self$values
Expand Down
152 changes: 110 additions & 42 deletions R/ParamSetCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
setindexv(paramtbl, c("id", "cls", "grouping"))
private$.params = paramtbl

private$.children_with_trafos = which(!map_lgl(map(sets, "extra_trafo"), is.null))
private$.children_with_constraints = which(!map_lgl(map(sets, "constraint"), is.null))

private$.sets = sets
},

Expand Down Expand Up @@ -188,21 +185,28 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
private$.params = rbind(private$.params, paramtbl)
setindexv(private$.params, c("id", "cls", "grouping"))

if (!is.null(p$extra_trafo)) {
entry = if (n == "") length(private$.children_with_trafos) + 1 else n
private$.children_with_trafos[[entry]] = new_index
}

if (!is.null(p$constraint)) {
entry = if (n == "") length(private$.children_with_constraints) + 1 else n
private$.children_with_constraints[[entry]] = new_index
}

entry = if (n == "") length(private$.sets) + 1 else n
private$.sets[[n]] = p
invisible(self)
},

#' @description
#' Create a new `ParamSet` restricted to the passed IDs.
#' @param ids (`character()`).
#' @param allow_dangling_dependencies (`logical(1)`)\cr
#' Whether to allow subsets that cut across parameter dependencies.
#' Dependencies that point to dropped parameters are kept (but will be "dangling", i.e. their `"on"` will not be present).
#' @param keep_constraint (`logical(1)`)\cr
#' Whether to keep the `$constraint` function.
#' @return `ParamSet`.
subset = function(ids, allow_dangling_dependencies = FALSE, keep_constraint = TRUE) {
  # Let the parent class build the restricted set first.
  subset_ps = super$subset(
    ids,
    allow_dangling_dependencies = allow_dangling_dependencies,
    keep_constraint = keep_constraint
  )
  # Then overwrite constraint and extra_trafo with versions that are restricted to `ids` and
  # built by the "detached" helpers, so that the result does not refer back to this collection.
  if (keep_constraint) {
    subset_ps$constraint = private$.get_constraint_detached(ids)
  }
  subset_ps$extra_trafo = private$.get_extra_trafo_detached(ids)
  subset_ps
},

#' @description
#'
#' Set the parameter values so that internal tuning for the selected parameters is disabled.
Expand Down Expand Up @@ -261,6 +265,8 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
flatten = function() {
flatps = super$flatten()

# This function is a mistake. It should not have been written. Sorry for allowing it to be merged.

recurse_prefix = function(id_, param_set, prefix = "") {
info = get_private(param_set)$.translation[list(id_), c("owner_name", "owner_ps_index"), on = "id"]
prefix = if (info$owner_name == "") {
Expand Down Expand Up @@ -334,15 +340,16 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
#' @template field_extra_trafo
extra_trafo = function(f) {
if (!missing(f)) stop("extra_trafo is read-only in ParamSetCollection.")
if (!length(private$.children_with_trafos)) return(NULL)
private$.extra_trafo_explicit
if (!length(private$.children_with_trafos())) return(NULL)

# The reason why we don't crate a function here is that the extra_trafo of private$.sets could change.
private$.extra_trafo_explicit
},

#' @template field_constraint
constraint = function(f) {
if (!missing(f)) stop("constraint is read-only in ParamSetCollection.")
if (!length(private$.children_with_constraints)) return(NULL)
if (!length(private$.children_with_constraints())) return(NULL)
private$.constraint_explicit
},

Expand Down Expand Up @@ -376,36 +383,47 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
},
.sets = NULL,
.translation = data.table(id = character(0), original_id = character(0), owner_ps_index = integer(0), owner_name = character(0), key = "id"),
.children_with_trafos = NULL,
.children_with_constraints = NULL,
# Indices (into private$.sets) of the child sets whose `$extra_trafo` is not NULL.
# Computed on every call from the current state of private$.sets.
.children_with_trafos = function() {
  child_trafos = map(private$.sets, "extra_trafo")
  has_trafo = !map_lgl(child_trafos, is.null)
  which(has_trafo)
},
# Indices (into private$.sets) of the child sets whose `$constraint` is not NULL.
# Computed on every call from the current state of private$.sets.
.children_with_constraints = function() {
  child_constraints = map(private$.sets, "constraint")
  has_constraint = !map_lgl(child_constraints, is.null)
  which(has_constraint)
},
.extra_trafo_explicit = function(x) {
changed = unlist(lapply(private$.children_with_trafos, function(set_index) {
changing_ids = private$.translation[J(set_index), id, on = "owner_ps_index"]
trafo = private$.sets[[set_index]]$extra_trafo
changing_values_in = x[names(x) %in% changing_ids]
names(changing_values_in) = private$.translation[names(changing_values_in), original_id]
# input of trafo() must not be changed after the call; otherwise the trafo would have to `force()` it in
# some circumstances.
changing_values = trafo(changing_values_in)
prefix = names(private$.sets)[[set_index]]
if (prefix != "") {
names(changing_values) = sprintf("%s.%s", prefix, names(changing_values))
}
changing_values
}), recursive = FALSE)
unchanged_ids = private$.translation[!J(private$.children_with_trafos), id, on = "owner_ps_index"]
unchanged = x[names(x) %in% unchanged_ids]
c(unchanged, changed)
children_with_trafos = private$.children_with_trafos()
sets_with_trafos = private$.sets[children_with_trafos]
translation = private$.translation
psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation)
},
# Build an extra_trafo function that holds no reference to the PSC object or to any of the sets it contains.
# Needed for subsetting / flattening, where the result must stand on its own.
# `ids`: subset of params to consider; NULL (the default) means all params.
# Returns NULL if none of the relevant child sets has an extra_trafo.
.get_extra_trafo_detached = function(ids = NULL) {
  restrict = !is.null(ids)
  # either a copy of the full translation table, or a new table restricted to `ids`
  translation = if (restrict) private$.translation[id %in% ids] else copy(private$.translation)
  # plain integer vector of set indices, so nothing to detach here
  children_with_trafos = private$.children_with_trafos()
  if (restrict) {
    children_with_trafos = intersect(children_with_trafos, translation$owner_ps_index)
  }
  if (length(children_with_trafos) == 0L) return(NULL)
  # deep clones: new objects that are detached from the PSC
  sets_with_trafos = lapply(private$.sets[children_with_trafos], function(child) child$clone(deep = TRUE))
  # crate() receives exactly the objects the function needs; their names must match the ones used inside
  crate(
    function(x) psc_extra_trafo(x, children_with_trafos, sets_with_trafos, translation),
    children_with_trafos, sets_with_trafos, translation, psc_extra_trafo
  )
},
.constraint_explicit = function(x) {
for (set_index in private$.children_with_constraints) {
constraining_ids = private$.translation[J(set_index), id, on = "owner_ps_index"]
constraint = private$.sets[[set_index]]$constraint
constraining_values = x[names(x) %in% constraining_ids]
names(constraining_values) = private$.translation[names(constraining_values), original_id]
if (!constraint(x)) return(FALSE)
children_with_constraints = private$.children_with_constraints()
sets_with_constraints = private$.sets[children_with_constraints]
translation = private$.translation
psc_constraint(x, children_with_constraints, sets_with_constraints, translation)
},
# same as with extra_trafo above
.get_constraint_detached = function(ids = NULL) {
translation = if (is.null(ids)) copy(private$.translation) else private$.translation[id %in% ids]
children_with_constraints = private$.children_with_constraints()
if (!is.null(ids)) {
children_with_constraints = intersect(children_with_constraints, translation$owner_ps_index)
}
TRUE
if (!length(children_with_constraints)) return(NULL)
sets_with_constraints = lapply(private$.sets[children_with_constraints], function(x) x$clone(deep = TRUE))
crate(function(x) psc_constraint(x, children_with_constraints, sets_with_constraints, translation), children_with_constraints, sets_with_constraints, translation, psc_constraint)
},
deep_clone = function(name, value) {
switch(name,
Expand All @@ -418,3 +436,53 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
}
)
)

# extra_trafo function for ParamSetCollection
# This function is used as extra_trafo for ParamSetCollection, in the case that any of its children has an extra_trafo.
# Arguments:
# - x: named collection of parameter values, named by the PSC's ids (i.e. the `id` column of `translation`)
# - children_with_trafos: set-indices (i.e. index inside PSC's private$.sets, and inside `translation`) of children with extra_trafo
# - sets_with_trafos: subset of PSC's private$.sets of children with extra_trafo
# - translation: PSC's private$.translation
# Returns: the values of all children without extra_trafo (unchanged), followed by the transformed values of
# the children with extra_trafo, renamed back to the PSC's prefixed ids.
#
# We have this function outside of the ParamSetCollection class, because we anticipate that PSC can be "flattened", i.e. turned into
# a normal ParamSet. In that case, the resulting ParamSet's extra_trafo should be a function that can stand on its own, without
# referring to private$<anything>.
psc_extra_trafo = function(x, children_with_trafos, sets_with_trafos, translation) {
  changed = unlist(lapply(seq_along(children_with_trafos), function(i) {
    set_index = children_with_trafos[[i]]
    child = sets_with_trafos[[i]]
    changing_ids = translation[J(set_index), id, on = "owner_ps_index"]
    trafo = child$extra_trafo
    changing_values_in = x[names(x) %in% changing_ids]
    # the child's trafo expects the child's own (un-prefixed) ids
    names(changing_values_in) = translation[names(changing_values_in), original_id]
    # input of trafo() must not be changed after the call; otherwise the trafo would have to `force()` it in
    # some circumstances.
    # The trafo is called exactly once: with `param_set` if its signature is (x, param_set), with `x` only otherwise.
    if (test_function(trafo, args = c("x", "param_set"))) {
      changing_values = trafo(x = changing_values_in, param_set = child)
    } else {
      changing_values = trafo(changing_values_in)
    }
    # translate the result names back by re-adding the child's prefix (unnamed children have prefix "")
    prefix = names(sets_with_trafos)[[i]]
    if (prefix != "") {
      names(changing_values) = sprintf("%s.%s", prefix, names(changing_values))
    }
    changing_values
  }), recursive = FALSE)
  unchanged_ids = translation[!J(children_with_trafos), id, on = "owner_ps_index"]
  unchanged = x[names(x) %in% unchanged_ids]
  c(unchanged, changed)
}

# constraint function for ParamSetCollection
# Checks the constraints of all children that have one; standalone (outside of the class) for the same reason
# as psc_extra_trafo: it must not refer to private$<anything>.
# Arguments:
# - x: named collection of parameter values, named by the PSC's ids (i.e. the `id` column of `translation`)
# - children_with_constraints: set-indices (index inside `translation`'s `owner_ps_index`) of children with constraint
# - sets_with_constraints: the child sets with constraint, in the same order as `children_with_constraints`
# - translation: PSC's private$.translation
# Returns: FALSE as soon as one child's constraint is violated, TRUE otherwise (also if there are no such children).
psc_constraint = function(x, children_with_constraints, sets_with_constraints, translation) {
  for (i in seq_along(children_with_constraints)) {
    set_index = children_with_constraints[[i]]
    constraining_ids = translation[J(set_index), id, on = "owner_ps_index"]
    constraint = sets_with_constraints[[i]]$constraint
    constraining_values = x[names(x) %in% constraining_ids]
    # the child's constraint expects the child's own (un-prefixed) ids
    names(constraining_values) = translation[names(constraining_values), original_id]
    # pass only the child's own values, not the full (prefixed) values of the whole collection
    if (!constraint(constraining_values)) return(FALSE)
  }
  TRUE
}


33 changes: 32 additions & 1 deletion man/ParamSetCollection.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading