multi_predict() doesn't support type = "raw" predictions for {lightgbm} classification models #45

Open
jameslamb opened this issue Aug 7, 2022 · 1 comment
Labels
bug an unexpected problem or unintended behavior

Comments

@jameslamb
Contributor

There is some code in {bonsai} that looks like it was intended to support multi_predict(..., type = "raw") for {lightgbm} classification models.

bonsai/R/lightgbm_data.R

Lines 146 to 158 in 6c090e1

parsnip::set_pred(
  model = "boost_tree",
  eng = "lightgbm",
  mode = "classification",
  type = "raw",
  value = parsnip::pred_value_template(
    pre = NULL,
    post = NULL,
    func = c(pkg = "bonsai", fun = "predict_lightgbm_classification_raw"),
    object = quote(object),
    new_data = quote(new_data)
  )
)

However, I don't believe {bonsai} actually respects type = "raw" for multi_predict().

Reproducible Example

See the following code for evidence of this claim. I saw this behavior with both {lightgbm} v3.3.2 installed from CRAN and with the latest development version (microsoft/LightGBM@c7102e5).

sessionInfo()
R version 4.1.0 (2021-05-18)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS 12.2.1

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] modeldata_1.0.0   lightgbm_3.3.2    R6_2.5.1          dplyr_1.0.9       bonsai_0.1.0.9000
[6] parsnip_1.0.0    

loaded via a namespace (and not attached):
 [1] rstudioapi_0.13   magrittr_2.0.3    tidyselect_1.1.2  munsell_0.5.0     lattice_0.20-45  
 [6] colorspace_2.0-3  rlang_1.0.4       fansi_1.0.3       tools_4.1.0       hardhat_1.2.0    
[11] grid_4.1.0        data.table_1.14.2 gtable_0.3.0      utf8_1.2.2        cli_3.3.0        
[16] withr_2.5.0       ellipsis_0.3.2    tibble_3.1.7      lifecycle_1.0.1   crayon_1.5.1     
[21] Matrix_1.4-0      purrr_0.3.4       ggplot2_3.3.6     tidyr_1.2.0       vctrs_0.4.1      
[26] glue_1.6.2        compiler_4.1.0    pillar_1.8.0      dials_1.0.0       generics_0.1.3   
 [31] scales_1.2.0      jsonlite_1.8.0    DiceDesign_1.9    pkgconfig_2.0.3

library(bonsai)
library(dplyr)
library(lightgbm)
library(modeldata)
library(parsnip)

data("penguins", package = "modeldata")
penguins <- penguins[complete.cases(penguins),]

penguins_subset <- penguins[1:10,]
penguins_subset_numeric <-
    penguins_subset %>%
    mutate(across(where(is.character), ~as.factor(.x))) %>%
    mutate(across(where(is.factor), ~as.integer(.x) - 1))

clf_multiclass_fit <-
    boost_tree(trees = 5) %>%
    set_engine("lightgbm") %>%
    set_mode("classification") %>%
    fit(species ~ ., data = penguins)

new_data <-
    penguins_subset_numeric %>%
    select(-species) %>%
    as.matrix()

preds_bonsai_raw <-
    multi_predict(
        clf_multiclass_fit
        , new_data = new_data[1, , drop = FALSE]
        , trees = seq_len(4)
        , type = "raw"
    )

preds_lgb_raw <-
    t(sapply(
        X = seq_len(4)
        , FUN = function(booster, new_data, num_iteration) {
            booster$predict(new_data, num_iteration = num_iteration, rawscore = TRUE)
        }
        , booster = clf_multiclass_fit$fit
        , new_data = new_data[1, , drop = FALSE]
    ))

preds_bonsai_prob <-
    multi_predict(
        clf_multiclass_fit
        , new_data = new_data[1, , drop = FALSE]
        , trees = seq_len(4)
        , type = "prob"
    )

The predictions from multi_predict(..., type = "raw") look like probabilities (between 0 and 1, sum to 1) and don't match {lightgbm}'s output for raw predictions.

preds_bonsai_raw[[".pred"]][[1]]
# A tibble: 4 × 4
#   trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
#   <int>        <dbl>           <dbl>        <dbl>
# 1     1        0.500           0.184        0.316
# 2     2        0.556           0.165        0.279
# 3     3        0.607           0.147        0.246
# 4     4        0.652           0.131        0.217

preds_lgb_raw
#            [,1]      [,2]      [,3]
# [1,] -0.6724811 -1.672408 -1.132757
# [2,] -0.5392134 -1.754103 -1.230182
# [3,] -0.4193116 -1.834036 -1.322633
# [4,] -0.3093926 -1.912255 -1.411070

type = "prob" predictions look correct, and like probabilities.

preds_bonsai_prob[[".pred"]][[1]]
# A tibble: 4 × 4
#   trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
#   <int>        <dbl>           <dbl>        <dbl>
# 1     1        0.500           0.184        0.316
# 2     2        0.556           0.165        0.279
# 3     3        0.607           0.147        0.246
# 4     4        0.652           0.131        0.217
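
As a sanity check on the outputs above, applying a row-wise softmax to {lightgbm}'s raw scores appears to reproduce the probabilities. That is consistent with type = "raw" being routed through the probability path. The following is a minimal base-R sketch; the numbers are copied from the preds_lgb_raw output above, and softmax() is a helper defined here for illustration, not a function from {bonsai} or {lightgbm}.

```r
# Raw scores for the first observation, copied from preds_lgb_raw above
# (rows = number of trees used, columns = classes).
raw_scores <- matrix(
    c(
        -0.6724811, -1.672408, -1.132757,
        -0.5392134, -1.754103, -1.230182,
        -0.4193116, -1.834036, -1.322633,
        -0.3093926, -1.912255, -1.411070
    )
    , nrow = 4
    , byrow = TRUE
    , dimnames = list(NULL, c("Adelie", "Chinstrap", "Gentoo"))
)

# Softmax of one row of scores (illustrative helper, assumed for this sketch).
# Subtracting the maximum first keeps exp() numerically stable.
softmax <- function(x) {
    e <- exp(x - max(x))
    e / sum(e)
}

# Apply the softmax to each row; apply() returns classes in rows, so transpose.
probs <- t(apply(raw_scores, 1, softmax))
round(probs, 3)
```

Rounded to three decimals, each row of probs should line up with the .pred_Adelie, .pred_Chinstrap, and .pred_Gentoo columns shown for both type = "raw" and type = "prob".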

I observed the same thing for binary classification models. This doesn't matter for regression models, because "raw" predictions are the default for {lightgbm} regression models using built-in objectives.

Notes for Maintainers

I believe the issue is that this block does not contain an if (type == "raw") condition:

bonsai/R/lightgbm.R

Lines 366 to 375 in 6c090e1

  } else {
    if (type == "class") {
      pred <- predict_lightgbm_classification_class(object, new_data, num_iteration = tree)
      pred <- tibble::tibble(.pred_class = factor(pred, levels = object$lvl))
    } else {
      pred <- predict_lightgbm_classification_prob(object, new_data, num_iteration = tree)
      names(pred) <- paste0(".pred_", names(pred))
    }

Is it expected that {bonsai} supports multi_predict(..., type = "raw") for {lightgbm} classification models? If so, would you be open to me putting up a pull request to add this support?
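
To make the suggestion concrete, here is a rough sketch of the kind of dispatch I have in mind. It is not {bonsai}'s actual code: the three `*_stub` functions and `predict_one_iteration()` are hypothetical stand-ins so the example runs on its own, and whether the real predict_lightgbm_classification_raw() accepts num_iteration is an assumption I have not verified.

```r
# Hypothetical stand-ins for the per-type prediction helpers. They only
# return a label so the dispatch logic can be run without a fitted model.
predict_class_stub <- function(object, new_data, num_iteration) "class"
predict_prob_stub  <- function(object, new_data, num_iteration) "prob"
predict_raw_stub   <- function(object, new_data, num_iteration) "raw"

# Dispatch on `type` with an explicit "raw" branch, so that "raw" no longer
# falls through to the probability path. Unknown types raise an error.
predict_one_iteration <- function(object, new_data, type, tree) {
    if (type == "class") {
        predict_class_stub(object, new_data, num_iteration = tree)
    } else if (type == "raw") {
        predict_raw_stub(object, new_data, num_iteration = tree)
    } else if (type == "prob") {
        predict_prob_stub(object, new_data, num_iteration = tree)
    } else {
        stop("unsupported prediction type: ", type)
    }
}

# Each requested type reaches its own helper.
results <- vapply(
    c("class", "prob", "raw")
    , function(type) predict_one_iteration(NULL, NULL, type = type, tree = 1L)
    , character(1)
)
results
```

The explicit final `else` is a design choice: an unrecognized type fails loudly instead of silently returning probabilities.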

Thanks for your time and consideration.

@simonpcouch
Contributor

Just wanted to drop a note here and let you know this hasn't fallen off my radar!

I'm hoping to spend some time with our multi_predict methods and put together some more unified machinery for dispatch and testing, and will return to this issue after that. 👍

@simonpcouch simonpcouch added the bug an unexpected problem or unintended behavior label Oct 31, 2023