Skip to content

Commit

Permalink
Add predict parameters as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
fouodo committed Nov 5, 2024
1 parent 7040c0d commit 2e56455
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions R/Lrner.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ Lrner <- R6Class("Lrner",

#' @param package (`character(1)`) \cr
#' Package that implements the learn function. If NULL, the
#' learn function is called from the current environment.
#' @param lrn_fct (`character(1)`) \cr
#' Learn function name.
#' @param param (`ParamLrner(1)`) \cr
#' learn function is called from the current environment.
#' @param param_train_list \cr
#' List of parameter for training.
#' @param param_pred_list \cr
#' List of parameter for testing.
#' Learn parameters.
#' @param train_layer (`TrainLayer(1)`) \cr
#' Layer on which the learner is stored.
Expand All @@ -31,13 +33,15 @@ Lrner <- R6Class("Lrner",
initialize = function (id,
package = NULL,
lrn_fct,
param,
param_train_list,
param_pred_list = list(),
train_layer,
na_rm = TRUE) {
private$id = id
private$package = package
private$lrn_fct = lrn_fct
private$param = param
private$param_train = param_train_list
private$param_pred = param_pred_list
if (!any(c("TrainLayer", "TrainMetaLayer") %in% class(train_layer))) {
stop("A Lrner can only belong to a TrainLayer or a TrainMetaLayer object.")
}
Expand Down Expand Up @@ -67,7 +71,6 @@ Lrner <- R6Class("Lrner",
cat(sprintf("TrainLayer : %s\n", private$train_layer$getId()))
cat(sprintf("Package : %s\n", private$package))
cat(sprintf("Learn function : %s\n", private$lrn_fct))
cat(sprintf("Param id : %s\n", private$param$id))
},
#' @description
#' Printer
Expand All @@ -78,7 +81,12 @@ Lrner <- R6Class("Lrner",
cat(sprintf(" TrainLayer : %s\n", private$train_layer$getId()))
cat(sprintf(" Package : %s\n", private$package))
cat(sprintf(" Learn function : %s\n", private$lrn_fct))
cat(sprintf(" Param id : %s\n", private$param$id))
cat("Predicting parameter\n")
print(expand.grid(private$param_train))
if (!length(private$param_pred)) {
cat("Predicting parameter\n")
print(expand.grid(private$param_pred))
}
},
#' @description
#' Tains the current learner (from class [Lrner]) on the current training data (from class [TrainData]).
Expand Down Expand Up @@ -109,8 +117,7 @@ Lrner <- R6Class("Lrner",
} else {
lrn = sprintf('%s::%s', private$package, private$lrn_fct)
}
lrn_param = private$param$getParamLrner()[1L, ]
lrn_param = as.list(lrn_param)
lrn_param = private$param_train
# Prepare training dataset: extract individual subset
if (!is.null(ind_subset)) {
train_data = train_data$getIndSubset(
Expand Down Expand Up @@ -208,6 +215,15 @@ Lrner <- R6Class("Lrner",
#'
getVarSubset = function () {
return(private$ind_subset)
},
#' @description
#' Getter predicting parameter list.
#'
#' @return
#' The list of predicting parameters.
#'
getParamPred = function () {
return(private$param_pred)
}
),
private = list(
Expand All @@ -217,8 +233,10 @@ Lrner <- R6Class("Lrner",
package = NULL,
# Learn function name (like \code{ranger}).
lrn_fct = NULL,
# Parameters (from class [Param]) of the learn function.
param = NULL,
# Parameters of the learn function.
param_train = list(0L),
# Parameters of the predict function.
param_pred = list(0L),
na_rm = NULL,
# Training layer (from class [TainLayer] or [TrainMetaLayer]) of the current learner.
train_layer = NULL,
Expand Down

0 comments on commit 2e56455

Please sign in to comment.