From 8fc9f5f79db805d0f8bc1dc4afec51d297d12814 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 16 Apr 2024 15:11:26 +0200 Subject: [PATCH] add aggr function to ParamSet, tests --- R/Domain.R | 2 +- R/ParamSet.R | 22 ++++++++++++++++++++++ R/to_tune.R | 7 +++---- man/Domain.Rd | 2 +- man/ParamSet.Rd | 23 +++++++++++++++++++++++ man/ParamSetCollection.Rd | 1 + man/in_tune.Rd | 4 ++-- tests/testthat/test_ParamSet.R | 19 +++++++++++++++++++ tests/testthat/test_domain.R | 13 +++++++++++++ 9 files changed, 85 insertions(+), 8 deletions(-) diff --git a/R/Domain.R b/R/Domain.R index e9224a63..07feae10 100644 --- a/R/Domain.R +++ b/R/Domain.R @@ -65,7 +65,7 @@ #' value upon construction. #' @param aggr (`function`)\cr #' Function with one argument, which is a list of parameter values. -#' The function specifies how this list of parameter values is aggregated to form one parameter value. +#' The function specifies how a list of parameter values is aggregated to form one parameter value. #' This is used in the context of inner tuning. The default is to aggregate the values. #' #' @return A `Domain` object. diff --git a/R/ParamSet.R b/R/ParamSet.R index 3c289f48..588091c9 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -259,6 +259,28 @@ ParamSet = R6Class("ParamSet", x }, + #' @description + #' + #' Aggregate parameter values according to the aggregation rules. + #' + #' @param x (named `list()` of `list()`s)\cr + #' The value(s) to be aggregated. Names are parameter values. + #' The aggregation function is selected accordingly for each parameter. + #' @return (named `list()`) + aggr = function(x) { + assert_list(x, types = "list") + assert_permutation(names(x), private$.aggrs$id) + if (!(length(unique(lengths(x))) == 1L)) { + stopf("The same number of values are required for each parameter") + } + if (nrow(private$.aggrs) && !length(x[[1L]])) { + stopf("More than one value is required to aggregate them") + } + imap(x, function(value, .id) { + aggr = private$.aggrs[list(.id), "aggr", on = "id"][[1L]][[1L]](value) + }) + }, + #' @description #' \pkg{checkmate}-like test-function. Takes a named list. #' Return `FALSE` if the given `$constraint` is not satisfied, `TRUE` otherwise. diff --git a/R/to_tune.R b/R/to_tune.R index a069d0fb..4eb3955a 100644 --- a/R/to_tune.R +++ b/R/to_tune.R @@ -189,8 +189,8 @@ to_tune = function(...) { #' See [`mlr3::Learner`] for more information. #' @inheritParams to_tune #' @param aggr (`function`)\cr -#' The aggregator function that determines how to aggregate a list of parameter values into one value. -#' a single parameter value. The default is to average them. +#' The aggregator function that determines how to aggregate a list of parameter values into a single parameter value. +#' The default is to average the values and round them up. #' @export in_tune = function(..., aggr = NULL) { if (is.null(aggr)) { @@ -241,8 +241,7 @@ tunetoken_to_ps = function(tt, param) { tunetoken_to_ps.InnerTuneToken = function(tt, params) { ps = NextMethod() - browser() - ps$tags = map(ps$tags, function(tags) union(tags, "inner_tune")) + ps$tags = map(ps$tags, function(tags) union(tags, "inner_tuning")) return(ps) } diff --git a/man/Domain.Rd b/man/Domain.Rd index 553d4cd1..1af85d2d 100644 --- a/man/Domain.Rd +++ b/man/Domain.Rd @@ -146,7 +146,7 @@ value upon construction.} \item{aggr}{(\code{function})\cr Function with one argument, which is a list of parameter values. -The function specifies how this list of parameter values is aggregated to form one parameter value. +The function specifies how a list of parameter values is aggregated to form one parameter value. This is used in the context of inner tuning. The default is to aggregate the values.} \item{levels}{(\code{character} | \code{atomic} | \code{list})\cr diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index b74daebc..165905cd 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -172,6 +172,7 @@ Named with param IDs.} \item \href{#method-ParamSet-get_values}{\code{ParamSet$get_values()}} \item \href{#method-ParamSet-set_values}{\code{ParamSet$set_values()}} \item \href{#method-ParamSet-trafo}{\code{ParamSet$trafo()}} +\item \href{#method-ParamSet-aggr}{\code{ParamSet$aggr()}} \item \href{#method-ParamSet-test_constraint}{\code{ParamSet$test_constraint()}} \item \href{#method-ParamSet-test_constraint_dt}{\code{ParamSet$test_constraint_dt()}} \item \href{#method-ParamSet-check}{\code{ParamSet$check()}} @@ -340,6 +341,28 @@ In almost all cases, the default \code{param_set = self} should be used.} } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ParamSet-aggr}{}}} +\subsection{Method \code{aggr()}}{ +Aggregate parameter values according to the aggregation rules. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ParamSet$aggr(x)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{x}}{(named \code{list()} of \code{list()}s)\cr +The value(s) to be aggregated. Names are parameter values. +The aggregation function is selected accordingly for each parameter.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +(named \code{list()}) +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ParamSet-test_constraint}{}}} \subsection{Method \code{test_constraint()}}{ diff --git a/man/ParamSetCollection.Rd b/man/ParamSetCollection.Rd index 93e4c93e..3a863f06 100644 --- a/man/ParamSetCollection.Rd +++ b/man/ParamSetCollection.Rd @@ -79,6 +79,7 @@ This field provides direct references to the \code{\link{ParamSet}} objects.}
Inherited methods