From 05f829b3a9b023975ac2c0b637fb4ccee13af481 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 28 Oct 2024 23:26:20 -0700 Subject: [PATCH] add prefix argument to orbital --- NEWS.md | 2 ++ R/orbital.R | 6 +++++- R/parsnip.R | 4 ++-- R/recipes.R | 9 +++++++-- R/workflows.R | 8 ++++---- man/orbital.Rd | 7 ++++++- tests/testthat/_snaps/parsnip.md | 2 +- tests/testthat/test-orbital.R | 21 +++++++++++++++++++++ tests/testthat/test-parsnip.R | 16 +++++++++++++++- tests/testthat/test-recipes.R | 14 ++++++++++++++ 10 files changed, 77 insertions(+), 12 deletions(-) diff --git a/NEWS.md b/NEWS.md index 47f73d3..db72f35 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ * `augment()` method for orbital object have been added. (#55) +* `orbital()` gained `prefix` argument to allow for renaming of prediction columns. (#59) + # orbital 0.2.0 * Support for `step_dummy()`, `step_impute_mean()`, `step_impute_median()`, `step_impute_mode()`, `step_unknown()`, `step_novel()`, `step_other()`, `step_BoxCox()`, `step_inverse()`, `step_mutate()`, `step_sqrt()`, `step_indicate_na()`, `step_range()`, `step_intercept()`, `step_ratio()`, `step_lag()`, `step_log()`, `step_rename()` has been added. (#17) diff --git a/R/orbital.R b/R/orbital.R index 50696be..2ed5970 100644 --- a/R/orbital.R +++ b/R/orbital.R @@ -6,6 +6,10 @@ #' #' @param x A fitted workflow, parsnip, or recipes object. #' @param ... Not currently used. +#' @param prefix A single string, specifies the prediction naming scheme. +#' If `x` produces a prediction, tidymodels standards dictate that the +#' predictions will start with `.pred`. This is not a valid name for +#' some data bases. #' #' @returns An [orbital] object. #' @@ -54,7 +58,7 @@ #' orbital() #' #' @export -orbital <- function(x, ...) { +orbital <- function(x, ..., prefix = ".pred") { UseMethod("orbital") } diff --git a/R/parsnip.R b/R/parsnip.R index 921737b..5784c11 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -1,5 +1,5 @@ #' @export -orbital.model_fit <- function(x, ...) { +orbital.model_fit <- function(x, ..., prefix = ".pred") { res <- tryCatch( tidypredict::tidypredict_fit(x), error = function(cnd) { @@ -17,7 +17,7 @@ orbital.model_fit <- function(x, ...) { } ) - res <- c(".pred" = deparse1(res)) + res <- stats::setNames(deparse1(res), prefix) new_orbital_class(res) } diff --git a/R/recipes.R b/R/recipes.R index 74de526..0027b44 100644 --- a/R/recipes.R +++ b/R/recipes.R @@ -1,6 +1,6 @@ #' @export -orbital.recipe <- function(x, eqs = NULL, ...) { +orbital.recipe <- function(x, eqs = NULL, ..., prefix = ".pred") { rlang::check_installed("glue") if (!recipes::fully_trained(x)) { cli::cli_abort("recipe must be fully trained.") @@ -15,7 +15,12 @@ orbital.recipe <- function(x, eqs = NULL, ...) { n_steps <- length(x$steps) - out <- c(.pred = unname(eqs)) + if (is.null(eqs)) { + out <- c() + } else { + out <- stats::setNames(unname(eqs), prefix) + } + for (step in rev(x$steps)) { if (step$skip) { next diff --git a/R/workflows.R b/R/workflows.R index a7f94fa..862d79f 100644 --- a/R/workflows.R +++ b/R/workflows.R @@ -1,5 +1,5 @@ #' @export -orbital.workflow <- function(x, ...) { +orbital.workflow <- function(x, ..., prefix = ".pred") { if (!workflows::is_trained_workflow(x)) { cli::cli_abort("{.arg x} must be a fully trained {.cls workflow}.") } @@ -9,15 +9,15 @@ orbital.workflow <- function(x, ...) { } model_fit <- workflows::extract_fit_parsnip(x) - out <- orbital(model_fit) + out <- orbital(model_fit, prefix = prefix) preprocessor <- workflows::extract_preprocessor(x) if (inherits(preprocessor, "recipe")) { recipe_fit <- workflows::extract_recipe(x) - out <- orbital(recipe_fit, out) + out <- orbital(recipe_fit, out, prefix = prefix) } new_orbital_class(out) -} \ No newline at end of file +} diff --git a/man/orbital.Rd b/man/orbital.Rd index 973bb45..e25bc7a 100644 --- a/man/orbital.Rd +++ b/man/orbital.Rd @@ -4,12 +4,17 @@ \alias{orbital} \title{Turn tidymodels objects into orbital objects} \usage{ -orbital(x, ...) +orbital(x, ..., prefix = ".pred") } \arguments{ \item{x}{A fitted workflow, parsnip, or recipes object.} \item{...}{Not currently used.} + +\item{prefix}{A single string, specifies the prediction naming scheme. +If \code{x} produces a prediction, tidymodels standards dictate that the +predictions will start with \code{.pred}. This is not a valid name for +some data bases.} } \value{ An \link{orbital} object. diff --git a/tests/testthat/_snaps/parsnip.md b/tests/testthat/_snaps/parsnip.md index 7d54614..9e63dc2 100644 --- a/tests/testthat/_snaps/parsnip.md +++ b/tests/testthat/_snaps/parsnip.md @@ -1,4 +1,4 @@ -# multiplication works +# normal usage works works Code orbital(wf_fit) diff --git a/tests/testthat/test-orbital.R b/tests/testthat/test-orbital.R index cf45df0..4246a19 100644 --- a/tests/testthat/test-orbital.R +++ b/tests/testthat/test-orbital.R @@ -175,3 +175,24 @@ test_that("orbital printing works", { print(orbital(wf_fit), truncate = FALSE) ) }) + +test_that("prefix argument works", { + skip_if_not_installed("recipes") + skip_if_not_installed("parsnip") + skip_if_not_installed("workflows") + skip_if_not_installed("tidypredict") + + rec_spec <- recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_normalize(recipes::all_numeric_predictors()) + + lm_spec <- parsnip::linear_reg() + + wf_spec <- workflows::workflow(rec_spec, lm_spec) + + wf_fit <- parsnip::fit(wf_spec, mtcars) + + orb_obj <- orbital(wf_fit, prefix = "pred") + + expect_true("pred" %in% names(orb_obj)) + expect_false(".pred" %in% names(orb_obj)) +}) diff --git a/tests/testthat/test-parsnip.R b/tests/testthat/test-parsnip.R index bff9988..b4b96b1 100644 --- a/tests/testthat/test-parsnip.R +++ b/tests/testthat/test-parsnip.R @@ -1,4 +1,4 @@ -test_that("multiplication works", { +test_that("normal usage works works", { skip_if_not_installed("recipes") skip_if_not_installed("parsnip") skip_if_not_installed("workflows") @@ -19,3 +19,17 @@ test_that("multiplication works", { orbital(wf_fit) ) }) + +test_that("prefix argument works", { + skip_if_not_installed("parsnip") + skip_if_not_installed("tidypredict") + + lm_spec <- parsnip::linear_reg() + + lm_fit <- parsnip::fit(lm_spec, mpg ~ ., mtcars) + + orb_obj <- orbital(lm_fit, prefix = "pred") + + expect_true("pred" %in% names(orb_obj)) + expect_false(".pred" %in% names(orb_obj)) +}) diff --git a/tests/testthat/test-recipes.R b/tests/testthat/test-recipes.R index 993170b..a63b4cc 100644 --- a/tests/testthat/test-recipes.R +++ b/tests/testthat/test-recipes.R @@ -39,3 +39,17 @@ test_that("recipe works with skip argument", { expect_equal(res, exp) }) + +test_that("prefix argument works", { + skip_if_not_installed("recipes") + + rec_spec <- recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_normalize(recipes::all_numeric_predictors()) + + rec_fit <- recipes::prep(rec_spec) + + orb_obj <- orbital(rec_fit, prefix = "pred") + + expect_false("pred" %in% names(orb_obj)) + expect_false(".pred" %in% names(orb_obj)) +})