Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switch to {cli} in check_args() functions #1093

Merged
merged 25 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8ad9cd0
pass calls around for check_args()
EmilHvitfeldt Apr 1, 2024
2bdd04a
switch to {cli} in all check_args() methods
EmilHvitfeldt Apr 1, 2024
7e95fc3
devtools::document()
EmilHvitfeldt Apr 1, 2024
ecafc3f
update snapshots for check_args()
EmilHvitfeldt Apr 5, 2024
e51bea7
Merged origin/main into cli-check_args
EmilHvitfeldt Apr 6, 2024
3ae0eeb
revert changes
EmilHvitfeldt Apr 9, 2024
fd2a8c1
delete unreachable code
EmilHvitfeldt Apr 9, 2024
8a83a42
pass call argument through form_xy()
EmilHvitfeldt Apr 9, 2024
3c361b9
fix typo
EmilHvitfeldt Apr 9, 2024
f67f14b
add all tests for check_args()
EmilHvitfeldt Apr 9, 2024
9130aac
use skip_if_not_installed()
EmilHvitfeldt Apr 9, 2024
eeb2a81
increase package version
EmilHvitfeldt Apr 9, 2024
4191c3c
pass call in check_args.C5_rules()
EmilHvitfeldt Apr 9, 2024
2ac8d17
pass calls in check_args.cubist_rules()
EmilHvitfeldt Apr 10, 2024
d9e9f06
add tests for more models
EmilHvitfeldt Apr 10, 2024
6db51f9
use arg argument in check_* functions
EmilHvitfeldt Apr 10, 2024
dbc7839
Update R/cubist_rules.R
EmilHvitfeldt Apr 10, 2024
3277863
use more check_* functions
EmilHvitfeldt Apr 10, 2024
02be9dc
use check_ functions for penalty
EmilHvitfeldt Apr 10, 2024
61c452d
break up message into multiple lines
EmilHvitfeldt Apr 10, 2024
b1e2f5a
better LiblineaR specific penalty error
EmilHvitfeldt Apr 10, 2024
eceac2d
move data inside test_that()
EmilHvitfeldt Apr 10, 2024
c9a5ba5
adding newline at end of file
EmilHvitfeldt Apr 10, 2024
fa40f0a
be more specific in testing of LiblineaR penalty
EmilHvitfeldt Apr 10, 2024
21c0e91
fix space typo
EmilHvitfeldt Apr 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.2.1.9000
EmilHvitfeldt marked this conversation as resolved.
Show resolved Hide resolved
Version: 1.2.1.9001
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand Down
4 changes: 1 addition & 3 deletions R/bag_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ update.bag_tree <-
# ------------------------------------------------------------------------------

#' @export
check_args.bag_tree <- function(object) {
if (object$engine == "C5.0" && object$mode == "regression")
stop("C5.0 is classification only.", call. = FALSE)
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
check_args.bag_tree <- function(object, call = rlang::caller_env()) {
invisible(object)
}

Expand Down
20 changes: 6 additions & 14 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,15 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
# ------------------------------------------------------------------------------

#' @export
check_args.boost_tree <- function(object) {
check_args.boost_tree <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$trees) && args$trees < 0) {
rlang::abort("`trees` should be >= 1.")
}
if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1)) {
rlang::abort("`sample_size` should be within [0,1].")
}
if (is.numeric(args$tree_depth) && args$tree_depth < 0) {
rlang::abort("`tree_depth` should be >= 1.")
}
if (is.numeric(args$min_n) && args$min_n < 0) {
rlang::abort("`min_n` should be >= 1.")
}

check_number_whole(args$trees, min = 0, allow_null = TRUE, call = call, arg = "trees")
check_number_decimal(args$sample_size, min = 0, max = 1, allow_null = TRUE, call = call, arg = "sample_size")
check_number_whole(args$tree_depth, min = 0, allow_null = TRUE, call = call, arg = "tree_depth")
check_number_whole(args$min_n, min = 0, allow_null = TRUE, call = call, arg = "min_n")

invisible(object)
}

Expand Down
31 changes: 11 additions & 20 deletions R/c5_rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,32 +111,23 @@ update.C5_rules <-
# make work in different places

#' @export
check_args.C5_rules <- function(object) {
check_args.C5_rules <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$trees)) {
if (length(args$trees) > 1) {
rlang::abort("Only a single value of `trees` is used.")
}
msg <- "The number of trees should be >= 1 and <= 100. Truncating the value."
if (args$trees > 100) {
object$args$trees <-
rlang::new_quosure(100L, env = rlang::empty_env())
rlang::warn(msg)
}
if (args$trees < 1) {
object$args$trees <-
rlang::new_quosure(1L, env = rlang::empty_env())
rlang::warn(msg)
}
check_number_whole(args$min_n, allow_null = TRUE, call = call, arg = "min_n")
check_number_whole(args$tree, allow_null = TRUE, call = call, arg = "tree")

msg <- "The number of trees should be {.code >= 1} and {.code <= 100}"
if (!(is.null(args$trees)) && args$trees > 100) {
object$args$trees <- rlang::new_quosure(100L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 100."))
}
if (is.numeric(args$min_n)) {
if (length(args$min_n) > 1) {
rlang::abort("Only a single `min_n`` value is used.")
}
if (!(is.null(args$trees)) && args$trees < 1) {
object$args$trees <- rlang::new_quosure(1L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 1."))
}
EmilHvitfeldt marked this conversation as resolved.
Show resolved Hide resolved

invisible(object)
}

Expand Down
54 changes: 23 additions & 31 deletions R/cubist_rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,44 +135,36 @@ update.cubist_rules <-
# make work in different places

#' @export
check_args.cubist_rules <- function(object) {
check_args.cubist_rules <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$committees)) {
if (length(args$committees) > 1) {
rlang::abort("Only a single committee member is used.")
}
msg <- "The number of committees should be >= 1 and <= 100. Truncating the value."
if (args$committees > 100) {
object$args$committees <-
rlang::new_quosure(100L, env = rlang::empty_env())
rlang::warn(msg)
}
if (args$committees < 1) {
object$args$committees <-
rlang::new_quosure(1L, env = rlang::empty_env())
rlang::warn(msg)
}
check_number_whole(args$committees, allow_null = TRUE, call = call, arg = "committees")

}
if (is.numeric(args$neighbors)) {
if (length(args$neighbors) > 1) {
rlang::abort("Only a single neighbors value is used.")
}
msg <- "The number of neighbors should be >= 0 and <= 9. Truncating the value."
if (args$neighbors > 9) {
object$args$neighbors <-
rlang::new_quosure(9L, env = rlang::empty_env())
rlang::warn(msg)
}
if (args$neighbors < 0) {
object$args$neighbors <-
rlang::new_quosure(0L, env = rlang::empty_env())
rlang::warn(msg)
msg <- "The number of committees should be {.code >= 1} and {.code <= 100}."
if (!(is.null(args$committees)) && args$committees > 100) {
object$args$committees <-
rlang::new_quosure(100L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 100."))
}
if (!(is.null(args$committees)) && args$committees < 1) {
object$args$committees <-
rlang::new_quosure(1L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 1."))
}

check_number_whole(args$neighbors, allow_null = TRUE, call = call, arg = "neighbors")

msg <- "The number of neighbors should be {.code >= 0} and {.code <= 9}."
if (!(is.null(args$neighbors)) && args$neighbors > 9) {
object$args$neighbors <- rlang::new_quosure(9L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 9."))
}
if (!(is.null(args$neighbors)) && args$neighbors < 0) {
object$args$neighbors <- rlang::new_quosure(0L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 0."))
}

invisible(object)
}

Expand Down
4 changes: 1 addition & 3 deletions R/decision_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {
# ------------------------------------------------------------------------------

#' @export
check_args.decision_tree <- function(object) {
if (object$engine == "C5.0" && object$mode == "regression")
rlang::abort("C5.0 is classification only.")
EmilHvitfeldt marked this conversation as resolved.
Show resolved Hide resolved
check_args.decision_tree <- function(object, call = rlang::caller_env()) {
invisible(object)
}

Expand Down
17 changes: 5 additions & 12 deletions R/discrim_flexible.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,14 @@ update.discrim_flexible <-
# ------------------------------------------------------------------------------

#' @export
check_args.discrim_flexible <- function(object) {
check_args.discrim_flexible <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$prod_degree) && args$prod_degree < 0)
stop("`prod_degree` should be >= 1", call. = FALSE)

if (is.numeric(args$num_terms) && args$num_terms < 0)
stop("`num_terms` should be >= 1", call. = FALSE)

if (!is.character(args$prune_method) &&
!is.null(args$prune_method) &&
!is.character(args$prune_method))
stop("`prune_method` should be a single string value", call. = FALSE)

check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree")
check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms")
check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method")

invisible(object)
}

Expand Down
6 changes: 2 additions & 4 deletions R/discrim_linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,11 @@ update.discrim_linear <-
# ------------------------------------------------------------------------------

#' @export
check_args.discrim_linear <- function(object) {
check_args.discrim_linear <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) {
stop("The amount of regularization should be >= 0", call. = FALSE)
}
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")

invisible(object)
}
Expand Down
13 changes: 4 additions & 9 deletions R/discrim_regularized.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,13 @@ update.discrim_regularized <-
# ------------------------------------------------------------------------------

#' @export
check_args.discrim_regularized <- function(object) {
check_args.discrim_regularized <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$frac_common_cov) &&
(args$frac_common_cov < 0 | args$frac_common_cov > 1)) {
stop("The common covariance fraction should be between zero and one", call. = FALSE)
}
if (is.numeric(args$frac_identity) &&
(args$frac_identity < 0 | args$frac_identity > 1)) {
stop("The identity matrix fraction should be between zero and one", call. = FALSE)
}
check_number_decimal(args$frac_common_cov, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_common_cov")
check_number_decimal(args$frac_identity, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_identity")

invisible(object)
}

Expand Down
18 changes: 12 additions & 6 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# data to formula/data objects and so on.

form_form <-
function(object, control, env, ...) {
function(object, control, env, ..., call = rlang::caller_env()) {

if (inherits(env$data, "data.frame")) {
check_outcome(eval_tidy(rlang::f_lhs(env$formula), env$data), object)
Expand Down Expand Up @@ -32,7 +32,7 @@ form_form <-
}

# evaluate quoted args once here to check them
object <- check_args(object)
object <- check_args(object, call = call)

# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)
Expand Down Expand Up @@ -60,7 +60,12 @@ form_form <-
res
}

xy_xy <- function(object, env, control, target = "none", ...) {
xy_xy <- function(object,
env,
control,
target = "none",
...,
call = rlang::caller_env()) {

if (inherits(env$x, "tbl_spark") | inherits(env$y, "tbl_spark"))
rlang::abort("spark objects can only be used with the formula interface to `fit()`")
Expand All @@ -83,7 +88,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
}

# evaluate quoted args once here to check them
object <- check_args(object)
object <- check_args(object, call = call)

# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)
Expand Down Expand Up @@ -114,7 +119,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
}

form_xy <- function(object, control, env,
target = "none", ...) {
target = "none", ..., call = rlang::caller_env()) {

encoding_info <-
get_encoding(class(object)[1]) %>%
Expand All @@ -138,7 +143,8 @@ form_xy <- function(object, control, env,
object = object,
env = env, #weights!
control = control,
target = target
target = target,
call = call
)
data_obj$y_var <- all.vars(rlang::f_lhs(env$formula))
data_obj$x <- NULL
Expand Down
10 changes: 3 additions & 7 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,12 @@ update.linear_reg <-
# ------------------------------------------------------------------------------

#' @export
check_args.linear_reg <- function(object) {
check_args.linear_reg <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (all(is.numeric(args$penalty)) && any(args$penalty < 0))
rlang::abort("The amount of regularization should be >= 0.")
if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1))
rlang::abort("The mixture proportion should be within [0,1].")
if (is.numeric(args$mixture) && length(args$mixture) > 1)
rlang::abort("Only one value of `mixture` is allowed.")
check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture")
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")

invisible(object)
}
35 changes: 21 additions & 14 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,25 +135,32 @@ update.logistic_reg <-
# ------------------------------------------------------------------------------

#' @export
check_args.logistic_reg <- function(object) {
check_args.logistic_reg <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (all(is.numeric(args$penalty)) && any(args$penalty < 0))
rlang::abort("The amount of regularization should be >= 0.")
if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1))
rlang::abort("The mixture proportion should be within [0,1].")
if (is.numeric(args$mixture) && length(args$mixture) > 1)
rlang::abort("Only one value of `mixture` is allowed.")
check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture")
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")

if (object$engine == "LiblineaR") {
if(is.numeric(args$mixture) && !args$mixture %in% 0:1)
rlang::abort(c("For the LiblineaR engine, mixture must be 0 or 1.",
"Choose a pure ridge model with `mixture = 0`.",
"Choose a pure lasso model with `mixture = 1`.",
"The Liblinear engine does not support other values."))
if(all(is.numeric(args$penalty)) && !all(args$penalty > 0))
rlang::abort("For the LiblineaR engine, penalty must be > 0.")
if (is.numeric(args$mixture) && !args$mixture %in% 0:1) {
cli::cli_abort(
EmilHvitfeldt marked this conversation as resolved.
Show resolved Hide resolved
c("x" = "For the {.pkg LiblineaR} engine, mixture must be 0 or 1, \\
not {args$mixture}.",
"i" = "Choose a pure ridge model with {.code mixture = 0} or \\
a pure lasso model with {.code mixture = 1}.",
"!" = "The {.pkg Liblinear} engine does not support other values."),
call = call
)
}

if ((!is.null(args$penalty)) && args$penalty == 0) {
cli::cli_abort(
"For the {.pkg LiblineaR} engine, {.arg penalty} must be {.code > 0}, \\
not 0.",
call = call
)
}
EmilHvitfeldt marked this conversation as resolved.
Show resolved Hide resolved
}

invisible(object)
Expand Down
15 changes: 4 additions & 11 deletions R/mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,13 @@ translate.mars <- function(x, engine = x$engine, ...) {
# ------------------------------------------------------------------------------

#' @export
check_args.mars <- function(object) {
check_args.mars <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$prod_degree) && args$prod_degree < 0)
rlang::abort("`prod_degree` should be >= 1.")

if (is.numeric(args$num_terms) && args$num_terms < 0)
rlang::abort("`num_terms` should be >= 1.")

if (!is_varying(args$prune_method) &&
!is.null(args$prune_method) &&
!is.character(args$prune_method))
rlang::abort("`prune_method` should be a single string value.")
check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree")
check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms")
check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method")

invisible(object)
}
Expand Down
Loading
Loading