diff --git a/DESCRIPTION b/DESCRIPTION index f8d2bd7f..85f415ba 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -47,6 +47,7 @@ Imports: ggplot2 Suggests: testthat (>= 3.1.0), + CVXR, patchwork, rpart, ranger, diff --git a/R/learner_fairml_classif_fairzlrm.R b/R/learner_fairml_classif_fairzlrm.R index cf132447..d92b910c 100644 --- a/R/learner_fairml_classif_fairzlrm.R +++ b/R/learner_fairml_classif_fairzlrm.R @@ -1,10 +1,10 @@ #' @title Classification Fair Logistic Regression With Covariance Constraints Learner #' @author pfistfl -#' @details +#' @details #' Generalized fair regression model from Zafar et al., 2019 implemented via package `fairml`. #' The 'unfairness' parameter is set to 0.05 as a default. #' The optimized fairness metric is statistical parity. -#' +#' #' @name mlr_learners_classif.fairzlrm #' #' @template class_learner @@ -13,7 +13,7 @@ #' #' @references #' `r format_bib("zafar19a")` -#' +#' #' @template seealso_learner #' @template example #' @export @@ -31,7 +31,7 @@ LearnerClassifFairzlrm = R6Class("LearnerClassifFairzlrm", ps$values = list(unfairness = 0.05, intersect = FALSE) super$initialize( id = "classif.fairzlrm", - packages = "fairml", + packages = c("fairml", "CVXR"), feature_types = c("integer", "numeric", "factor", "ordered"), predict_types = c("response", "prob"), properties = "twoclass", diff --git a/R/learner_fairml_regr_fairzlm.R b/R/learner_fairml_regr_fairzlm.R index 9ed5ec0a..bb84eafe 100644 --- a/R/learner_fairml_regr_fairzlm.R +++ b/R/learner_fairml_regr_fairzlm.R @@ -1,10 +1,10 @@ #' @title Regression Fair Regression With Covariance Constraints Learner #' @author pfistfl -#' @details +#' @details #' Fair regression model from Zafar et al., 2019 implemented via package `fairml`. #' The 'unfairness' parameter is set to 0.05 as a default. #' The optimized fairness metric is statistical parity. -#' +#' #' @name mlr_learners_regr.fairzlm #' #' @template class_learner @@ -31,7 +31,7 @@ LearnerRegrFairzlm = R6Class("LearnerRegrFairzlm", ps$values = list(unfairness = 0.05, intersect = FALSE) super$initialize( id = "regr.fairzlm", - packages = "fairml", + packages = c("fairml", "CVXR"), feature_types = c("integer", "numeric", "factor", "ordered"), predict_types = c("response"), param_set = ps, diff --git a/tests/testthat/test_learners_fairml.R b/tests/testthat/test_learners_fairml.R index 4f811fac..f6c166d7 100644 --- a/tests/testthat/test_learners_fairml.R +++ b/tests/testthat/test_learners_fairml.R @@ -24,9 +24,10 @@ test_that("regr.fairfrrm", { test_that("regr.fairzlm", { skip_on_cran() skip_if_not_installed("fairml") + skip_if_not_installed("CVXR") learner = lrn("regr.fairzlm", unfairness = 0.5) out = expect_learner(learner) - + task = TaskRegr$new("long", fairml::national.longitudinal.survey, target = "income06") task$col_roles$pta = "gender" simple_autotest(learner, task) @@ -38,6 +39,7 @@ test_that("regr.fairzlm", { test_that("classif.fairzlrm", { skip_on_cran() skip_if_not_installed("fairml") + skip_if_not_installed("CVXR") learner = lrn("classif.fairzlrm", unfairness = 0.2) out = expect_learner(learner)