From 132abc965910626ad3cfe6b602a1451cf7b944a4 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 5 Apr 2024 10:26:04 -0500 Subject: [PATCH 1/3] test tunable output --- tests/testthat/_snaps/tunable.md | 410 +++++++++++++++++++++++++++++++ tests/testthat/test-tunable.R | 105 ++++++++ 2 files changed, 515 insertions(+) create mode 100644 tests/testthat/_snaps/tunable.md create mode 100644 tests/testthat/test-tunable.R diff --git a/tests/testthat/_snaps/tunable.md b/tests/testthat/_snaps/tunable.md new file mode 100644 index 000000000..e75281f08 --- /dev/null +++ b/tests/testthat/_snaps/tunable.md @@ -0,0 +1,410 @@ +# tunable.linear_reg() + + Code + tunable(spec) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("lm")) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("glmnet")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec linear_reg main + 2 mixture model_spec linear_reg main + +--- + + Code + tunable(spec %>% set_engine("brulee")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec linear_reg main + 2 mixture model_spec linear_reg main + +# tunable.logistic_reg() + + Code + tunable(spec) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("glm")) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("glmnet")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec logistic_reg main + 2 mixture model_spec logistic_reg main + +--- + + Code + tunable(spec %>% set_engine("brulee")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec logistic_reg main + 2 mixture model_spec logistic_reg main + +# tunable.multinom_reg() + + Code + tunable(spec) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("glmnet")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + 2 mixture model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("spark")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + 2 mixture model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("keras")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("nnet")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + +--- + + Code + tunable(spec %>% set_engine("brulee")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + 2 mixture model_spec multinom_reg main + +# tunable.boost_tree() + + Code + tunable(spec) + Output + # A tibble: 8 x 5 + name call_info source component component_id + + 1 tree_depth model_spec boost_tree main + 2 trees model_spec boost_tree main + 3 learn_rate model_spec boost_tree main + 4 mtry model_spec boost_tree main + 5 min_n model_spec boost_tree main + 6 loss_reduction model_spec boost_tree main + 7 sample_size model_spec boost_tree main + 8 stop_iter model_spec boost_tree main + +--- + + Code + tunable(spec %>% set_engine("xgboost")) + Output + # A tibble: 8 x 5 + name call_info source component component_id + + 1 tree_depth model_spec boost_tree main + 2 trees model_spec boost_tree main + 3 learn_rate model_spec boost_tree main + 4 mtry model_spec boost_tree main + 5 min_n model_spec boost_tree main + 6 loss_reduction model_spec boost_tree main + 7 sample_size model_spec boost_tree main + 8 stop_iter model_spec boost_tree main + +--- + + Code + tunable(spec %>% set_engine("C5.0")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 trees model_spec boost_tree main + 2 min_n model_spec boost_tree main + 3 sample_size model_spec boost_tree main + +--- + + Code + tunable(spec %>% set_engine("spark")) + Output + # A tibble: 7 x 5 + name call_info source component component_id + + 1 tree_depth model_spec boost_tree main + 2 trees model_spec boost_tree main + 3 learn_rate model_spec boost_tree main + 4 mtry model_spec boost_tree main + 5 min_n model_spec boost_tree main + 6 loss_reduction model_spec boost_tree main + 7 sample_size model_spec boost_tree main + +# tunable.rand_forest() + + Code + tunable(spec) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + +--- + + Code + tunable(spec %>% set_engine("ranger")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + +--- + + Code + tunable(spec %>% set_engine("randomForest")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + +--- + + Code + tunable(spec %>% set_engine("spark")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + +# tunable.mars() + + Code + tunable(spec) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 num_terms model_spec mars main + 2 prod_degree model_spec mars main + 3 prune_method model_spec mars main + +--- + + Code + tunable(spec %>% set_engine("earth")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 num_terms model_spec mars main + 2 prod_degree model_spec mars main + 3 prune_method model_spec mars main + +# tunable.decision_tree() + + Code + tunable(spec) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 tree_depth model_spec decision_tree main + 2 min_n model_spec decision_tree main + 3 cost_complexity model_spec decision_tree main + +--- + + Code + tunable(spec %>% set_engine("rpart")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 tree_depth model_spec decision_tree main + 2 min_n model_spec decision_tree main + 3 cost_complexity model_spec decision_tree main + +--- + + Code + tunable(spec %>% set_engine("C5.0")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 min_n model_spec decision_tree main + +--- + + Code + tunable(spec %>% set_engine("spark")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 tree_depth model_spec decision_tree main + 2 min_n model_spec decision_tree main + +# tunable.svm_poly() + + Code + tunable(spec) + Output + # A tibble: 4 x 5 + name call_info source component component_id + + 1 cost model_spec svm_poly main + 2 degree model_spec svm_poly main + 3 scale_factor model_spec svm_poly main + 4 margin model_spec svm_poly main + +--- + + Code + tunable(spec %>% set_engine("kernlab")) + Output + # A tibble: 4 x 5 + name call_info source component component_id + + 1 cost model_spec svm_poly main + 2 degree model_spec svm_poly main + 3 scale_factor model_spec svm_poly main + 4 margin model_spec svm_poly main + +# tunable.mlp() + + Code + tunable(spec) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 epochs model_spec mlp main + +--- + + Code + tunable(spec %>% set_engine("keras")) + Output + # A tibble: 5 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 dropout model_spec mlp main + 4 epochs model_spec mlp main + 5 activation model_spec mlp main + +--- + + Code + tunable(spec %>% set_engine("nnet")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 epochs model_spec mlp main + +--- + + Code + tunable(spec %>% set_engine("brulee")) + Output + # A tibble: 6 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 epochs model_spec mlp main + 4 dropout model_spec mlp main + 5 learn_rate model_spec mlp main + 6 activation model_spec mlp main + +# tunable.survival_reg() + + Code + tunable(spec) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + diff --git a/tests/testthat/test-tunable.R b/tests/testthat/test-tunable.R new file mode 100644 index 000000000..8f799594c --- /dev/null +++ b/tests/testthat/test-tunable.R @@ -0,0 +1,105 @@ +# general pattern, for each tunable method: +# define `spec`, run `show_engines()` with only parsnip loaded, +# snapshot test `tunable()` output for each unique engine. +# +# note that, as implemented, parsnip can return `tunable()` information +# for engines that it cannot fit without first loading an extension package. +# +# the specific contents of call_info are just hard-coded tibbles in the +# source, so snapshot testing only for their presence rather than contents. + +test_that("tunable.linear_reg()", { + spec <- linear_reg() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("lm"))) + expect_snapshot(tunable(spec %>% set_engine("glmnet"))) + expect_snapshot(tunable(spec %>% set_engine("brulee"))) + + # tests for call_info in tidymodels/extratests +}) + +test_that("tunable.logistic_reg()", { + spec <- logistic_reg() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("glm"))) + expect_snapshot(tunable(spec %>% set_engine("glmnet"))) + expect_snapshot(tunable(spec %>% set_engine("brulee"))) + + # tests for call_info and additional engines in tidymodels/extratests +}) + +test_that("tunable.multinom_reg()", { + spec <- multinom_reg() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("glmnet"))) + expect_snapshot(tunable(spec %>% set_engine("spark"))) + expect_snapshot(tunable(spec %>% set_engine("keras"))) + expect_snapshot(tunable(spec %>% set_engine("nnet"))) + expect_snapshot(tunable(spec %>% set_engine("brulee"))) + + # tests for call_info and additional engines in tidymodels/extratests +}) + +test_that("tunable.boost_tree()", { + spec <- boost_tree() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("xgboost"))) + expect_snapshot(tunable(spec %>% set_engine("C5.0"))) + expect_snapshot(tunable(spec %>% set_engine("spark"))) + + # tests for call_info and additional engines in tidymodels/extratests +}) + +test_that("tunable.rand_forest()", { + spec <- rand_forest() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("ranger"))) + expect_snapshot(tunable(spec %>% set_engine("randomForest"))) + expect_snapshot(tunable(spec %>% set_engine("spark"))) + + # tests for call_info and additional engines in tidymodels/extratests +}) + +test_that("tunable.mars()", { + spec <- mars() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("earth"))) + + # tests for call_info and additional engines in tidymodels/extratests +}) + +test_that("tunable.decision_tree()", { + spec <- decision_tree() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("rpart"))) + expect_snapshot(tunable(spec %>% set_engine("C5.0"))) + expect_snapshot(tunable(spec %>% set_engine("spark"))) + + # tests for call_info and additional engines in tidymodels/extratests +}) + +test_that("tunable.svm_poly()", { + spec <- svm_poly() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("kernlab"))) + + # tests for call_info and additional engines in tidymodels/extratests +}) + +test_that("tunable.mlp()", { + spec <- mlp() + expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("keras"))) + expect_snapshot(tunable(spec %>% set_engine("nnet"))) + expect_snapshot(tunable(spec %>% set_engine("brulee"))) + + # tests for call_info and additional engines in tidymodels/extratests +}) + + +test_that("tunable.survival_reg()", { + spec <- survival_reg() + expect_snapshot(tunable(spec)) + + # tests for call_info and additional engines in tidymodels/extratests +}) From 6e8106ec9443c95c06e727037d4d4bbc14745611 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 5 Apr 2024 10:42:08 -0500 Subject: [PATCH 2/3] demonstrate 1104 --- tests/testthat/_snaps/tunable.md | 141 +++++++++++++++++++++++++++++++ tests/testthat/test-tunable.R | 31 ++++--- 2 files changed, 162 insertions(+), 10 deletions(-) diff --git a/tests/testthat/_snaps/tunable.md b/tests/testthat/_snaps/tunable.md index e75281f08..e2ec2c208 100644 --- a/tests/testthat/_snaps/tunable.md +++ b/tests/testthat/_snaps/tunable.md @@ -38,6 +38,18 @@ 1 penalty model_spec linear_reg main 2 mixture model_spec linear_reg main +--- + + Code + tunable(spec %>% set_engine("glmnet", dfmax = tune())) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 penalty model_spec linear_reg main + 2 mixture model_spec linear_reg main + 3 dfmax model_spec linear_reg engine + # tunable.logistic_reg() Code @@ -78,6 +90,18 @@ 1 penalty model_spec logistic_reg main 2 mixture model_spec logistic_reg main +--- + + Code + tunable(spec %>% set_engine("glmnet", dfmax = tune())) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 penalty model_spec logistic_reg main + 2 mixture model_spec logistic_reg main + 3 dfmax model_spec logistic_reg engine + # tunable.multinom_reg() Code @@ -141,6 +165,18 @@ 1 penalty model_spec multinom_reg main 2 mixture model_spec multinom_reg main +--- + + Code + tunable(spec %>% set_engine("glmnet", dfmax = tune())) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + 2 mixture model_spec multinom_reg main + 3 dfmax model_spec multinom_reg engine + # tunable.boost_tree() Code @@ -203,6 +239,24 @@ 6 loss_reduction model_spec boost_tree main 7 sample_size model_spec boost_tree main +--- + + Code + tunable(spec %>% set_engine("xgboost", feval = tune())) + Output + # A tibble: 9 x 5 + name call_info source component component_id + + 1 tree_depth model_spec boost_tree main + 2 trees model_spec boost_tree main + 3 learn_rate model_spec boost_tree main + 4 mtry model_spec boost_tree main + 5 min_n model_spec boost_tree main + 6 loss_reduction model_spec boost_tree main + 7 sample_size model_spec boost_tree main + 8 stop_iter model_spec boost_tree main + 9 feval model_spec boost_tree engine + # tunable.rand_forest() Code @@ -251,6 +305,19 @@ 2 trees model_spec rand_forest main 3 min_n model_spec rand_forest main +--- + + Code + tunable(spec %>% set_engine("ranger", min.bucket = tune())) + Output + # A tibble: 4 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main + 4 min.bucket model_spec rand_forest engine + # tunable.mars() Code @@ -275,6 +342,19 @@ 2 prod_degree model_spec mars main 3 prune_method model_spec mars main +--- + + Code + tunable(spec %>% set_engine("earth", minspan = tune())) + Output + # A tibble: 4 x 5 + name call_info source component component_id + + 1 num_terms model_spec mars main + 2 prod_degree model_spec mars main + 3 prune_method model_spec mars main + 4 minspan model_spec mars engine + # tunable.decision_tree() Code @@ -320,6 +400,19 @@ 1 tree_depth model_spec decision_tree main 2 min_n model_spec decision_tree main +--- + + Code + tunable(spec %>% set_engine("rpart", parms = tune())) + Output + # A tibble: 4 x 5 + name call_info source component component_id + + 1 tree_depth model_spec decision_tree main + 2 min_n model_spec decision_tree main + 3 cost_complexity model_spec decision_tree main + 4 parms model_spec decision_tree engine + # tunable.svm_poly() Code @@ -346,6 +439,20 @@ 3 scale_factor model_spec svm_poly main 4 margin model_spec svm_poly main +--- + + Code + tunable(spec %>% set_engine("kernlab", tol = tune())) + Output + # A tibble: 5 x 5 + name call_info source component component_id + + 1 cost model_spec svm_poly main + 2 degree model_spec svm_poly main + 3 scale_factor model_spec svm_poly main + 4 margin model_spec svm_poly main + 5 tol model_spec svm_poly engine + # tunable.mlp() Code @@ -399,6 +506,21 @@ 5 learn_rate model_spec mlp main 6 activation model_spec mlp main +--- + + Code + tunable(spec %>% set_engine("keras", ragged = tune())) + Output + # A tibble: 6 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 dropout model_spec mlp main + 4 epochs model_spec mlp main + 5 activation model_spec mlp main + 6 ragged model_spec mlp engine + # tunable.survival_reg() Code @@ -408,3 +530,22 @@ # i 5 variables: name , call_info , source , component , # component_id +--- + + Code + tunable(spec %>% set_engine("survival")) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(spec %>% set_engine("survival", parms = tune())) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 parms model_spec survival_reg engine + diff --git a/tests/testthat/test-tunable.R b/tests/testthat/test-tunable.R index 8f799594c..8e8dd6ab7 100644 --- a/tests/testthat/test-tunable.R +++ b/tests/testthat/test-tunable.R @@ -15,7 +15,8 @@ test_that("tunable.linear_reg()", { expect_snapshot(tunable(spec %>% set_engine("glmnet"))) expect_snapshot(tunable(spec %>% set_engine("brulee"))) - # tests for call_info in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("glmnet", dfmax = tune()))) }) test_that("tunable.logistic_reg()", { @@ -25,7 +26,8 @@ test_that("tunable.logistic_reg()", { expect_snapshot(tunable(spec %>% set_engine("glmnet"))) expect_snapshot(tunable(spec %>% set_engine("brulee"))) - # tests for call_info and additional engines in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("glmnet", dfmax = tune()))) }) test_that("tunable.multinom_reg()", { @@ -37,7 +39,8 @@ test_that("tunable.multinom_reg()", { expect_snapshot(tunable(spec %>% set_engine("nnet"))) expect_snapshot(tunable(spec %>% set_engine("brulee"))) - # tests for call_info and additional engines in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("glmnet", dfmax = tune()))) }) test_that("tunable.boost_tree()", { @@ -47,7 +50,8 @@ test_that("tunable.boost_tree()", { expect_snapshot(tunable(spec %>% set_engine("C5.0"))) expect_snapshot(tunable(spec %>% set_engine("spark"))) - # tests for call_info and additional engines in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("xgboost", feval = tune()))) }) test_that("tunable.rand_forest()", { @@ -57,7 +61,8 @@ test_that("tunable.rand_forest()", { expect_snapshot(tunable(spec %>% set_engine("randomForest"))) expect_snapshot(tunable(spec %>% set_engine("spark"))) - # tests for call_info and additional engines in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("ranger", min.bucket = tune()))) }) test_that("tunable.mars()", { @@ -65,7 +70,8 @@ test_that("tunable.mars()", { expect_snapshot(tunable(spec)) expect_snapshot(tunable(spec %>% set_engine("earth"))) - # tests for call_info and additional engines in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("earth", minspan = tune()))) }) test_that("tunable.decision_tree()", { @@ -75,7 +81,8 @@ test_that("tunable.decision_tree()", { expect_snapshot(tunable(spec %>% set_engine("C5.0"))) expect_snapshot(tunable(spec %>% set_engine("spark"))) - # tests for call_info and additional engines in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("rpart", parms = tune()))) }) test_that("tunable.svm_poly()", { @@ -83,7 +90,8 @@ test_that("tunable.svm_poly()", { expect_snapshot(tunable(spec)) expect_snapshot(tunable(spec %>% set_engine("kernlab"))) - # tests for call_info and additional engines in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("kernlab", tol = tune()))) }) test_that("tunable.mlp()", { @@ -93,13 +101,16 @@ test_that("tunable.mlp()", { expect_snapshot(tunable(spec %>% set_engine("nnet"))) expect_snapshot(tunable(spec %>% set_engine("brulee"))) - # tests for call_info and additional engines in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("keras", ragged = tune()))) }) test_that("tunable.survival_reg()", { spec <- survival_reg() expect_snapshot(tunable(spec)) + expect_snapshot(tunable(spec %>% set_engine("survival"))) - # tests for call_info and additional engines in tidymodels/extratests + # don't include rows for non-tunable args marked with tune() (#1104) + expect_snapshot(tunable(spec %>% set_engine("survival", parms = tune()))) }) From 21a56c85da2fff272d3f594eaa6e39acabe6e99f Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 5 Apr 2024 10:48:16 -0500 Subject: [PATCH 3/3] exclude non-tunable engine arguments in `tunable()` closes #1104 --- R/tunable.R | 20 ++++++++-------- tests/testthat/_snaps/tunable.md | 41 +++++++++++++------------------- 2 files changed, 26 insertions(+), 35 deletions(-) diff --git a/R/tunable.R b/R/tunable.R index 85c8bff29..271c470e2 100644 --- a/R/tunable.R +++ b/R/tunable.R @@ -248,7 +248,7 @@ tunable.linear_reg <- function(x, ...) { } else if (x$engine == "brulee") { res <- add_engine_parameters(res, brulee_linear_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -260,7 +260,7 @@ tunable.logistic_reg <- function(x, ...) { } else if (x$engine == "brulee") { res <- add_engine_parameters(res, brulee_logistic_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -272,7 +272,7 @@ tunable.multinomial_reg <- function(x, ...) { } else if (x$engine == "brulee") { res <- add_engine_parameters(res, brulee_multinomial_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -295,7 +295,7 @@ tunable.boost_tree <- function(x, ...) { res$call_info[res$name == "sample_size"] <- list(list(pkg = "dials", fun = "sample_prop")) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -310,7 +310,7 @@ tunable.rand_forest <- function(x, ...) { } else if (x$engine == "aorsf") { res <- add_engine_parameters(res, aorsf_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -319,7 +319,7 @@ tunable.mars <- function(x, ...) { if (x$engine == "earth") { res <- add_engine_parameters(res, earth_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -333,7 +333,7 @@ tunable.decision_tree <- function(x, ...) { partykit_engine_args %>% dplyr::mutate(component = "decision_tree")) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -343,7 +343,7 @@ tunable.svm_poly <- function(x, ...) { res$call_info[res$name == "degree"] <- list(list(pkg = "dials", fun = "prod_degree", range = c(1L, 3L))) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } @@ -357,7 +357,7 @@ tunable.mlp <- function(x, ...) { res$call_info[res$name == "epochs"] <- list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L))) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -366,7 +366,7 @@ tunable.survival_reg <- function(x, ...) { if (x$engine == "flexsurvspline") { res <- add_engine_parameters(res, flexsurvspline_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } # nocov end diff --git a/tests/testthat/_snaps/tunable.md b/tests/testthat/_snaps/tunable.md index e2ec2c208..9f8b6ba3b 100644 --- a/tests/testthat/_snaps/tunable.md +++ b/tests/testthat/_snaps/tunable.md @@ -43,12 +43,11 @@ Code tunable(spec %>% set_engine("glmnet", dfmax = tune())) Output - # A tibble: 3 x 5 + # A tibble: 2 x 5 name call_info source component component_id 1 penalty model_spec linear_reg main 2 mixture model_spec linear_reg main - 3 dfmax model_spec linear_reg engine # tunable.logistic_reg() @@ -95,12 +94,11 @@ Code tunable(spec %>% set_engine("glmnet", dfmax = tune())) Output - # A tibble: 3 x 5 + # A tibble: 2 x 5 name call_info source component component_id 1 penalty model_spec logistic_reg main 2 mixture model_spec logistic_reg main - 3 dfmax model_spec logistic_reg engine # tunable.multinom_reg() @@ -244,7 +242,7 @@ Code tunable(spec %>% set_engine("xgboost", feval = tune())) Output - # A tibble: 9 x 5 + # A tibble: 8 x 5 name call_info source component component_id 1 tree_depth model_spec boost_tree main @@ -255,7 +253,6 @@ 6 loss_reduction model_spec boost_tree main 7 sample_size model_spec boost_tree main 8 stop_iter model_spec boost_tree main - 9 feval model_spec boost_tree engine # tunable.rand_forest() @@ -310,13 +307,12 @@ Code tunable(spec %>% set_engine("ranger", min.bucket = tune())) Output - # A tibble: 4 x 5 - name call_info source component component_id - - 1 mtry model_spec rand_forest main - 2 trees model_spec rand_forest main - 3 min_n model_spec rand_forest main - 4 min.bucket model_spec rand_forest engine + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main # tunable.mars() @@ -347,13 +343,12 @@ Code tunable(spec %>% set_engine("earth", minspan = tune())) Output - # A tibble: 4 x 5 + # A tibble: 3 x 5 name call_info source component component_id 1 num_terms model_spec mars main 2 prod_degree model_spec mars main 3 prune_method model_spec mars main - 4 minspan model_spec mars engine # tunable.decision_tree() @@ -405,13 +400,12 @@ Code tunable(spec %>% set_engine("rpart", parms = tune())) Output - # A tibble: 4 x 5 + # A tibble: 3 x 5 name call_info source component component_id 1 tree_depth model_spec decision_tree main 2 min_n model_spec decision_tree main 3 cost_complexity model_spec decision_tree main - 4 parms model_spec decision_tree engine # tunable.svm_poly() @@ -444,14 +438,13 @@ Code tunable(spec %>% set_engine("kernlab", tol = tune())) Output - # A tibble: 5 x 5 + # A tibble: 4 x 5 name call_info source component component_id 1 cost model_spec svm_poly main 2 degree model_spec svm_poly main 3 scale_factor model_spec svm_poly main 4 margin model_spec svm_poly main - 5 tol model_spec svm_poly engine # tunable.mlp() @@ -511,7 +504,7 @@ Code tunable(spec %>% set_engine("keras", ragged = tune())) Output - # A tibble: 6 x 5 + # A tibble: 5 x 5 name call_info source component component_id 1 hidden_units model_spec mlp main @@ -519,7 +512,6 @@ 3 dropout model_spec mlp main 4 epochs model_spec mlp main 5 activation model_spec mlp main - 6 ragged model_spec mlp engine # tunable.survival_reg() @@ -544,8 +536,7 @@ Code tunable(spec %>% set_engine("survival", parms = tune())) Output - # A tibble: 1 x 5 - name call_info source component component_id - - 1 parms model_spec survival_reg engine + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id