Skip to content

Commit

Permalink
Merge pull request #60 from tidymodels/prefix
Browse files Browse the repository at this point in the history
add prefix argument to orbital
  • Loading branch information
EmilHvitfeldt authored Oct 29, 2024
2 parents 93a1826 + 05f829b commit 055df81
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 12 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

* `augment()` method for orbital object have been added. (#55)

* `orbital()` gained `prefix` argument to allow for renaming of prediction columns. (#59)

# orbital 0.2.0

* Support for `step_dummy()`, `step_impute_mean()`, `step_impute_median()`, `step_impute_mode()`, `step_unknown()`, `step_novel()`, `step_other()`, `step_BoxCox()`, `step_inverse()`, `step_mutate()`, `step_sqrt()`, `step_indicate_na()`, `step_range()`, `step_intercept()`, `step_ratio()`, `step_lag()`, `step_log()`, `step_rename()` has been added. (#17)
Expand Down
6 changes: 5 additions & 1 deletion R/orbital.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
#'
#' @param x A fitted workflow, parsnip, or recipes object.
#' @param ... Not currently used.
#' @param prefix A single string, specifies the prediction naming scheme.
#' If `x` produces a prediction, tidymodels standards dictate that the
#' predictions will start with `.pred`. This is not a valid name for
#' some data bases.
#'
#' @returns An [orbital] object.
#'
Expand Down Expand Up @@ -54,7 +58,7 @@
#' orbital()
#'
#' @export
orbital <- function(x, ...) {
orbital <- function(x, ..., prefix = ".pred") {
UseMethod("orbital")
}

Expand Down
4 changes: 2 additions & 2 deletions R/parsnip.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' @export
orbital.model_fit <- function(x, ...) {
orbital.model_fit <- function(x, ..., prefix = ".pred") {
res <- tryCatch(
tidypredict::tidypredict_fit(x),
error = function(cnd) {
Expand All @@ -17,7 +17,7 @@ orbital.model_fit <- function(x, ...) {
}
)

res <- c(".pred" = deparse1(res))
res <- stats::setNames(deparse1(res), prefix)

new_orbital_class(res)
}
Expand Down
9 changes: 7 additions & 2 deletions R/recipes.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

#' @export
orbital.recipe <- function(x, eqs = NULL, ...) {
orbital.recipe <- function(x, eqs = NULL, ..., prefix = ".pred") {
rlang::check_installed("glue")
if (!recipes::fully_trained(x)) {
cli::cli_abort("recipe must be fully trained.")
Expand All @@ -15,7 +15,12 @@ orbital.recipe <- function(x, eqs = NULL, ...) {

n_steps <- length(x$steps)

out <- c(.pred = unname(eqs))
if (is.null(eqs)) {
out <- c()
} else {
out <- stats::setNames(unname(eqs), prefix)
}

for (step in rev(x$steps)) {
if (step$skip) {
next
Expand Down
8 changes: 4 additions & 4 deletions R/workflows.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' @export
orbital.workflow <- function(x, ...) {
orbital.workflow <- function(x, ..., prefix = ".pred") {
if (!workflows::is_trained_workflow(x)) {
cli::cli_abort("{.arg x} must be a fully trained {.cls workflow}.")
}
Expand All @@ -9,15 +9,15 @@ orbital.workflow <- function(x, ...) {
}

model_fit <- workflows::extract_fit_parsnip(x)
out <- orbital(model_fit)
out <- orbital(model_fit, prefix = prefix)

preprocessor <- workflows::extract_preprocessor(x)

if (inherits(preprocessor, "recipe")) {
recipe_fit <- workflows::extract_recipe(x)

out <- orbital(recipe_fit, out)
out <- orbital(recipe_fit, out, prefix = prefix)
}

new_orbital_class(out)
}
}
7 changes: 6 additions & 1 deletion man/orbital.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/parsnip.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# multiplication works
# normal usage works works

Code
orbital(wf_fit)
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-orbital.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,24 @@ test_that("orbital printing works", {
print(orbital(wf_fit), truncate = FALSE)
)
})

test_that("prefix argument works", {
skip_if_not_installed("recipes")
skip_if_not_installed("parsnip")
skip_if_not_installed("workflows")
skip_if_not_installed("tidypredict")

rec_spec <- recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_normalize(recipes::all_numeric_predictors())

lm_spec <- parsnip::linear_reg()

wf_spec <- workflows::workflow(rec_spec, lm_spec)

wf_fit <- parsnip::fit(wf_spec, mtcars)

orb_obj <- orbital(wf_fit, prefix = "pred")

expect_true("pred" %in% names(orb_obj))
expect_false(".pred" %in% names(orb_obj))
})
16 changes: 15 additions & 1 deletion tests/testthat/test-parsnip.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
test_that("multiplication works", {
test_that("normal usage works works", {
skip_if_not_installed("recipes")
skip_if_not_installed("parsnip")
skip_if_not_installed("workflows")
Expand All @@ -19,3 +19,17 @@ test_that("multiplication works", {
orbital(wf_fit)
)
})

test_that("prefix argument works", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")

lm_spec <- parsnip::linear_reg()

lm_fit <- parsnip::fit(lm_spec, mpg ~ ., mtcars)

orb_obj <- orbital(lm_fit, prefix = "pred")

expect_true("pred" %in% names(orb_obj))
expect_false(".pred" %in% names(orb_obj))
})
14 changes: 14 additions & 0 deletions tests/testthat/test-recipes.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,17 @@ test_that("recipe works with skip argument", {

expect_equal(res, exp)
})

test_that("prefix argument works", {
skip_if_not_installed("recipes")

rec_spec <- recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_normalize(recipes::all_numeric_predictors())

rec_fit <- recipes::prep(rec_spec)

orb_obj <- orbital(rec_fit, prefix = "pred")

expect_false("pred" %in% names(orb_obj))
expect_false(".pred" %in% names(orb_obj))
})

0 comments on commit 055df81

Please sign in to comment.