From 539e5a7430028a9e3225aada281df61d069252a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Wed, 25 Sep 2024 17:52:32 -0400 Subject: [PATCH] quantile -> quantile_levels for #1203 --- NEWS.md | 3 +++ R/predict.R | 2 +- R/predict_quantile.R | 21 ++++++++++++++++----- man/other_predict.Rd | 5 ++--- man/set_args.Rd | 2 +- tests/testthat/test-linear_reg_quantreg.R | 5 +++++ 6 files changed, 28 insertions(+), 10 deletions(-) diff --git a/NEWS.md b/NEWS.md index e2a63b619..004da1ef5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,9 @@ * New `extract_fit_time()` method has been added that returns the time it took to train the model (#853). +## Breaking Change + +* For quantile prediction, the `predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`. # parsnip 1.2.1 diff --git a/R/predict.R b/R/predict.R index 3a2681048..397b92112 100644 --- a/R/predict.R +++ b/R/predict.R @@ -344,7 +344,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) # ---------------------------------------------------------------------------- - other_args <- c("interval", "level", "std_error", "quantile", + other_args <- c("interval", "level", "std_error", "quantile_levels", "time", "eval_time", "increasing") is_pred_arg <- names(the_dots) %in% other_args if (any(!is_pred_arg)) { diff --git a/R/predict_quantile.R b/R/predict_quantile.R index fc2d91b15..a3950a75c 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -1,14 +1,13 @@ #' @keywords internal #' @rdname other_predict -#' @param quantile A vector of numbers between 0 and 1 for the quantile being -#' predicted. +#' @param quantile_levels A vector of values between zero and one. #' @inheritParams predict.model_fit #' @method predict_quantile model_fit #' @export predict_quantile.model_fit #' @export predict_quantile.model_fit <- function(object, new_data, - quantile = (1:9)/10, + quantile_levels = NULL, interval = "none", level = 0.95, ...) { @@ -20,6 +19,20 @@ predict_quantile.model_fit <- function(object, return(NULL) } + if (object$spec$mode != "quantile regression") { + if (is.null(quantile_levels)) { + quantile_levels <- (1:9)/10 + } + hardhat::check_quantile_levels(quantile_levels) + # Pass some extra arguments to be used in post-processor + object$quantile_levels <- quantile_levels + } else { + if (!is.null(quantile_levels)) { + cli::cli_abort("{.arg quantile_levels} are specified by {.fn set_mode} + when the mode is {.val quantile regression}.") + } + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -27,8 +40,6 @@ predict_quantile.model_fit <- function(object, new_data <- object$spec$method$pred$quantile$pre(new_data, object) } - # Pass some extra arguments to be used in post-processor - object$spec$method$pred$quantile$args$p <- quantile pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) diff --git a/man/other_predict.Rd b/man/other_predict.Rd index 6c997e28d..d1342d87f 100644 --- a/man/other_predict.Rd +++ b/man/other_predict.Rd @@ -49,7 +49,7 @@ predict_numeric(object, ...) \method{predict_quantile}{model_fit}( object, new_data, - quantile = (1:9)/10, + quantile_levels = NULL, interval = "none", level = 0.95, ... @@ -103,8 +103,7 @@ interval estimates.} \item{std_error}{A single logical for whether the standard error should be returned (assuming that the model can compute it).} -\item{quantile}{A vector of numbers between 0 and 1 for the quantile being -predicted.} +\item{quantile_levels}{A vector of values between zero and one.} } \description{ These are internal functions not meant to be directly called by the user. diff --git a/man/set_args.Rd b/man/set_args.Rd index 6d3b60f3d..b31e4ad4c 100644 --- a/man/set_args.Rd +++ b/man/set_args.Rd @@ -21,7 +21,7 @@ set_mode(object, mode, ...) "regression")} \item{quantile_levels}{A vector of values between zero and one (only for the -\verb{quantile regression } mode); otherwise, it is \code{NULL}. The model uses these +\code{"quantile regression"} mode); otherwise, it is \code{NULL}. The model uses these values to appropriately train quantile regression models to make predictions for these values (e.g., \code{quantile_levels = 0.5} is the median).} } diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 7edc7c3a5..2cce1f6ae 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -83,6 +83,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row")) expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10) + expect_snapshot( + ten_quant_pred <- predict(ten_quant, new_data = sac_test), + error = TRUE + ) + ### ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,])