diff --git a/NAMESPACE b/NAMESPACE index 4a0ffd9..e58b2f8 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ S3method(augment,orbital_class) S3method(orbital,default) +S3method(orbital,glm) S3method(orbital,last_fit) S3method(orbital,model_fit) S3method(orbital,model_spec) diff --git a/R/model-glm.R b/R/model-glm.R new file mode 100644 index 0000000..a45f0be --- /dev/null +++ b/R/model-glm.R @@ -0,0 +1,17 @@ +#' @export +orbital.glm <- function(x, ..., mode = c("classification", "regression")) { + mode <- rlang::arg_match(mode) + + if (mode == "classification") { + levels <- levels(x$model$Species) + levels <- glue::double_quote(levels) + res <- tidypredict::tidypredict_fit(x) + res <- deparse1(res) + res <- glue::glue("dplyr::case_when({res} < 0.5 ~ {levels[1]}, .default = {levels[2]})") + } + + if (mode == "regression") { + res <- tidypredict::tidypredict_fit(x) + } + res +} diff --git a/R/parsnip.R b/R/parsnip.R index 5784c11..e263058 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -1,23 +1,39 @@ #' @export orbital.model_fit <- function(x, ..., prefix = ".pred") { - res <- tryCatch( - tidypredict::tidypredict_fit(x), - error = function(cnd) { - if (grepl("no applicable method for", cnd$message)) { - cls <- class(x) - cls <- setdiff(cls, "model_fit") - cls <- gsub("^_", "", cls) - - cli::cli_abort( - "A model of class {.cls {cls}} is not supported.", - call = rlang::call2("orbital") - ) + mode <- x$spec$mode + + check_mode(mode) + + res <- try(orbital(x$fit, mode = mode), silent = TRUE) + + if (inherits(res, "try-error")) { + res <- tryCatch( + tidypredict::tidypredict_fit(x), + error = function(cnd) { + if (grepl("no applicable method for", cnd$message)) { + cls <- class(x) + cls <- setdiff(cls, "model_fit") + cls <- gsub("^_", "", cls) + + cli::cli_abort( + "A model of class {.cls {cls}} is not supported.", + call = rlang::call2("orbital") + ) + } + stop(cnd) } - stop(cnd) - } - ) + ) + } + + if (mode == "classification") { + prefix <- paste0(prefix, "_class") + } - res <- stats::setNames(deparse1(res), prefix) + if (is.language(res)) { + res <- deparse1(res) + } + + res <- stats::setNames(res, prefix) new_orbital_class(res) } @@ -26,3 +42,15 @@ orbital.model_fit <- function(x, ..., prefix = ".pred") { orbital.model_spec <- function(x, ...) { cli::cli_abort("{.arg x} must be fitted model.") } + +check_mode <- function(mode, call = rlang::caller_env()) { + supported_modes <- c("regression", "classification") + + if (!(mode %in% supported_modes)) { + cli::cli_abort( + "Only models with modes {.val {supported_modes}} are supported. + Not {.val {mode}}.", + call = call + ) + } +}