Skip to content

Commit

Permalink
Add weighted option
Browse files Browse the repository at this point in the history
  • Loading branch information
fouodo committed Nov 19, 2024
1 parent 0a5e1be commit c30dd68
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 4 deletions.
2 changes: 2 additions & 0 deletions R/BestSpecificLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ bestSpecificLearner = function (x, y, perf = NULL) {
perf_values = unlist(perf_values)
}
} else {
# nocov start
if (is.function(perf)) {
arg_names <- names(formals(perf))
if (arg_names %in% c("observed", "predicted")) {
Expand All @@ -62,6 +63,7 @@ bestSpecificLearner = function (x, y, perf = NULL) {
stop("Arguments of the perf function must be 'observed' and 'predicted'.")
}
}
# nocov end
weights_values = (1L / perf_values) / sum((1L / perf_values))
max_index = which.max(weights_values)
weights_values = rep(0L, length(weights_values))
Expand Down
26 changes: 26 additions & 0 deletions R/Target.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,37 @@ Target <- R6Class("Target",
data_frame,
training) {
if (!any(c("Training") %in% class(training))) {
# nocov start
stop("A Target can belong only to a Training object.\n")
# nocov end
}
ind_col = training$getIndCol()
target = training$getTarget()
if (!all(c(ind_col, target) %in% colnames(data_frame))) {
# nocov start
stop("Individual column ID or target variable not found in the provided data.frame.\n")
# nocov end
}
if (training$checkTargetExist()) {
# Remove TrainData if already existing
# nocov start
key_class = train_layer$getKeyClass()
key = key_class[key_class$class == "Target", "key"]
training$removeFromHashTable(key = key)
# nocov end
}
private$training = training
missing_target = is.na(data_frame[ , target])
if (any(missing_target)) {
# nocov start
data_frame = data_frame[!missing_target, ]
# nocov end
}
missing_id = is.na(data_frame[ , ind_col])
if (any(missing_id)) {
# nocov start
data_frame = data_frame[!missing_id, ]
# nocov end
}
super$initialize(id = id,
ind_col = training$getIndCol(),
Expand All @@ -55,36 +65,44 @@ Target <- R6Class("Target",
value = self,
.class = "Target")
if (any(missing_target)) {
# nocov start
warning(sprintf("%s individual(s) with missing target value(s) recognized and removed.\n",
sum(missing_target)))
# nocov end
}
if (any(missing_id)) {
# nocov start
warning(sprintf("%s individual(s) with missing ID value(s) recognized and removed.\n",
sum(missing_id)))
# nocov end
}
},
#' @description
#' Printer
#' @param ... (any) \cr
#'
# nocov start
print = function (...) {
cat(sprintf("Training : %s\n", private$training$getId()))
cat(sprintf("ind. id. : %s\n", private$ind_col))
cat(sprintf("target : %s\n", private$target))
cat(sprintf("n : %s\n", nrow(private$data_frame)))
cat(sprintf("Missing : %s\n", sum(!complete.cases(private$data_frame))))
},
# nocov end
#' @description
#' Summary
#' @param ... (any) \cr
#'
# nocov start
summary = function (...) {
cat(sprintf(" Layer : %s\n", private$training$getId()))
cat(sprintf(" Ind. id. : %s\n", private$ind_col))
cat(sprintf(" Target : %s\n", private$target))
cat(sprintf(" n : %s\n", nrow(private$data_frame)))
cat(sprintf(" Missing : %s\n", sum(!complete.cases(private$data_frame))))
},
# nocov end
#' @description
#' Getter of the current \code{data.frame} wihtout individual
#' ID nor target variables.
Expand All @@ -93,26 +111,32 @@ Target <- R6Class("Target",
#' The \code{data.frame} without individual ID nor target variables is returned.
#' @export
#'
# nocov start
getData = function () {
return(private$data_frame)
},
# # nocov end
#' @description
#' Getter of target values stored on the current training layer.
#'
#' @return
#' The observed target values stored on the current training layer are returned.
#' @export
#'
# nocov start
getTargetValues = function () {
return(private$data_frame[[private$target]])
},
# nocov end
#' @description
#' Getter of the target variable name.
#'
#' @export
#'
getTargetName = function () {
# nocov start
return(private$target)
# nocov end
},
#' @description
#' Getter of the current training object.
Expand All @@ -123,7 +147,9 @@ Target <- R6Class("Target",
#' @export
#'
getTraining = function () {
# nocov start
return(private$training)
# nocov end
}
),
private = list(
Expand Down
2 changes: 1 addition & 1 deletion R/predict.bestSpecificLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ predict.bestSpecificLearner = function (object, data, na.rm = TRUE) {
})
return(list(predictions = unlist(pred)))
} else {
stop("Names of weights do not match with name columns in data")
stop("Names of weights do not match with name columns in data.")
}
}
10 changes: 8 additions & 2 deletions R/weightedMeanLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#' \code{data.frame} of predictors.
#' @param y `vector(1)` \cr
#' Target observations. Either binary or two level factor variable.
#' @param weighted \cr
#' If TRUE, the weighted sum is computed.
#'
#' @return
#' A model object of class \code{weightedMeanLeaner}.
Expand All @@ -19,7 +21,7 @@
#' y = sample(x = 0L:1L, size = 50L, replace = TRUE)
#' my_model = weightedMeanLearner(x = x, y = y)
#'
weightedMeanLearner = function (x, y) {
weightedMeanLearner = function (x, y, weighted = TRUE) {
# y must be binomial. If dichotomy, first category (case) = 1 and
# second (control) = 0
if ((length(unique(y)) > 2) | is.character(y)) {
Expand All @@ -39,7 +41,11 @@ weightedMeanLearner = function (x, y) {
})
brier_values = unlist(brier_values)
# weights_values = (1 - brier_values) / sum((1 - brier_values))
weights_values = (1 / brier_values) / sum((1 / brier_values))
if (weighted) {
weights_values = (1 / brier_values) / sum((1 / brier_values))
} else {
weights_values <- 1 / length(brier_values)
}
names(weights_values) = names(x)
class(weights_values) = "weightedMeanLearner"
return(weights_values)
Expand Down
5 changes: 4 additions & 1 deletion man/weightedMeanLearner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c30dd68

Please sign in to comment.