multi_predict() doesn't support type = "raw" predictions for {lightgbm} classification models #45

@jameslamb

There is some code in {bonsai} that looks like it was intended to support multi_predict(..., type = "raw") for {lightgbm} classification models.

bonsai/R/lightgbm_data.R

Lines 146 to 158 in 6c090e1

parsnip::set_pred(
  model = "boost_tree",
  eng = "lightgbm",
  mode = "classification",
  type = "raw",
  value = parsnip::pred_value_template(
    pre = NULL,
    post = NULL,
    func = c(pkg = "bonsai", fun = "predict_lightgbm_classification_raw"),
    object = quote(object),
    new_data = quote(new_data)
  )
)

However, I don't believe {bonsai} actually respects type = "raw" for multi_predict().

Reproducible Example

See the following code for evidence of this claim. I saw this behavior both with {lightgbm} v3.3.2 installed from CRAN and with the latest development version (microsoft/LightGBM@c7102e5).

sessionInfo()
R version 4.1.0 (2021-05-18)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS 12.2.1

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] modeldata_1.0.0   lightgbm_3.3.2    R6_2.5.1          dplyr_1.0.9       bonsai_0.1.0.9000
[6] parsnip_1.0.0    

loaded via a namespace (and not attached):
 [1] rstudioapi_0.13   magrittr_2.0.3    tidyselect_1.1.2  munsell_0.5.0     lattice_0.20-45  
 [6] colorspace_2.0-3  rlang_1.0.4       fansi_1.0.3       tools_4.1.0       hardhat_1.2.0    
[11] grid_4.1.0        data.table_1.14.2 gtable_0.3.0      utf8_1.2.2        cli_3.3.0        
[16] withr_2.5.0       ellipsis_0.3.2    tibble_3.1.7      lifecycle_1.0.1   crayon_1.5.1     
[21] Matrix_1.4-0      purrr_0.3.4       ggplot2_3.3.6     tidyr_1.2.0       vctrs_0.4.1      
[26] glue_1.6.2        compiler_4.1.0    pillar_1.8.0      dials_1.0.0       generics_0.1.3   
[31] scales_1.2.0      jsonlite_1.8.0    DiceDesign_1.9    pkgconfig_2.0.3

library(bonsai)
library(dplyr)
library(lightgbm)
library(modeldata)
library(parsnip)

data("penguins", package = "modeldata")
penguins <- penguins[complete.cases(penguins),]

penguins_subset <- penguins[1:10,]
penguins_subset_numeric <-
    penguins_subset %>%
    mutate(across(where(is.character), ~as.factor(.x))) %>%
    mutate(across(where(is.factor), ~as.integer(.x) - 1))

clf_multiclass_fit <-
    boost_tree(trees = 5) %>%
    set_engine("lightgbm") %>%
    set_mode("classification") %>%
    fit(species ~ ., data = penguins)

new_data <-
    penguins_subset_numeric %>%
    select(-species) %>%
    as.matrix()

preds_bonsai_raw <-
    multi_predict(
        clf_multiclass_fit
        , new_data = new_data[1, , drop = FALSE]
        , trees = seq_len(4)
        , type = "raw"
    )

preds_lgb_raw <-
    t(sapply(
        X = seq_len(4)
        , FUN = function(booster, new_data, num_iteration) {
            booster$predict(new_data, num_iteration = num_iteration, rawscore = TRUE)
        }
        , booster = clf_multiclass_fit$fit
        , new_data = new_data[1, , drop = FALSE]
    ))

preds_bonsai_prob <-
    multi_predict(
        clf_multiclass_fit
        , new_data = new_data[1, , drop = FALSE]
        , trees = seq_len(4)
        , type = "prob"
    )

The predictions from multi_predict(..., type = "raw") look like probabilities (between 0 and 1, sum to 1) and don't match {lightgbm}'s output for raw predictions.

preds_bonsai_raw[[".pred"]][[1]]
# A tibble: 4 × 4
#  trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
#  <int>        <dbl>           <dbl>        <dbl>
#      1        0.500           0.184        0.316
#      2        0.556           0.165        0.279
#      3        0.607           0.147        0.246
#      4        0.652           0.131        0.217

preds_lgb_raw
#            [,1]      [,2]      [,3]
# [1,] -0.6724811 -1.672408 -1.132757
# [2,] -0.5392134 -1.754103 -1.230182
# [3,] -0.4193116 -1.834036 -1.322633
# [4,] -0.3093926 -1.912255 -1.411070

The type = "prob" predictions look correct: they are valid probabilities, and they are identical to the type = "raw" output above.

preds_bonsai_prob[[".pred"]][[1]]
# A tibble: 4 × 4
#   trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
#  <int>        <dbl>           <dbl>        <dbl>
# 1     1        0.500           0.184        0.316
# 2     2        0.556           0.165        0.279
# 3     3        0.607           0.147        0.246
# 4     4        0.652           0.131        0.217
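
As a quick cross-check, the sketch below applies a row-wise softmax (the transformation {lightgbm}'s multiclass objective uses to turn raw scores into probabilities) to the raw scores printed above and compares the result with what multi_predict(..., type = "raw") returned. The numbers are copied from the output in this issue; the names raw_scores, bonsai_raw and softmax_rows are illustrative only and are not part of {bonsai} or {lightgbm}.

```r
# Raw scores from booster$predict(..., rawscore = TRUE), copied from this issue
raw_scores <- matrix(
  c(-0.6724811, -1.672408, -1.132757,
    -0.5392134, -1.754103, -1.230182,
    -0.4193116, -1.834036, -1.322633,
    -0.3093926, -1.912255, -1.411070),
  ncol = 3, byrow = TRUE
)

# Values returned by multi_predict(..., type = "raw"), copied from this issue
bonsai_raw <- matrix(
  c(0.500, 0.184, 0.316,
    0.556, 0.165, 0.279,
    0.607, 0.147, 0.246,
    0.652, 0.131, 0.217),
  ncol = 3, byrow = TRUE
)

# Row-wise softmax; subtracting each row's maximum keeps exp() numerically stable
softmax_rows <- function(x) {
  shifted <- x - apply(x, 1, max)
  exp(shifted) / rowSums(exp(shifted))
}

# Largest absolute difference between softmax(raw scores) and bonsai's output
max_abs_diff <- max(abs(softmax_rows(raw_scores) - bonsai_raw))
print(max_abs_diff < 1e-3)
```

If the comparison holds to within the rounding of the printed tibble, the type = "raw" result is the probability output under a different label, not the untransformed scores.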

I observed the same thing for binary classification models. This doesn't matter for regression models, because "raw" predictions are the default for {lightgbm} regression models using built-in objectives.

Notes for Maintainers

I believe the issue is that this block does not contain an if (type == "raw") condition:

bonsai/R/lightgbm.R

Lines 366 to 375 in 6c090e1

} else {
  if (type == "class") {
    pred <- predict_lightgbm_classification_class(object, new_data, num_iteration = tree)
    pred <- tibble::tibble(.pred_class = factor(pred, levels = object$lvl))
  } else {
    pred <- predict_lightgbm_classification_prob(object, new_data, num_iteration = tree)
    names(pred) <- paste0(".pred_", names(pred))
  }

Is it expected that {bonsai} supports multi_predict(..., type = "raw") for {lightgbm} classification models? If so, would you be open to me putting up a pull request to add this support?
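
To make the proposal concrete, here is a minimal, self-contained sketch of the branching with an explicit "raw" case. It is not {bonsai}'s code: predict_one_iteration is a hypothetical helper, the three pred_* arguments stand in for the predict_lightgbm_classification_*() functions, and the stubs exist only so the snippet runs without a fitted model. Whether predict_lightgbm_classification_raw() accepts num_iteration is an assumption that would need checking against the package source. A base data.frame replaces tibble::tibble() to avoid extra dependencies.

```r
# Hypothetical helper mirroring the excerpt above, with an explicit "raw" branch.
# pred_class, pred_prob and pred_raw stand in for bonsai's
# predict_lightgbm_classification_*() helpers.
predict_one_iteration <- function(object, new_data, tree, type,
                                  pred_class, pred_prob, pred_raw) {
  if (type == "class") {
    pred <- pred_class(object, new_data, num_iteration = tree)
    pred <- data.frame(.pred_class = factor(pred, levels = object$lvl))
  } else if (type == "raw") {
    # Assumption: the raw helper accepts num_iteration like the other two
    pred <- pred_raw(object, new_data, num_iteration = tree)
  } else {
    pred <- pred_prob(object, new_data, num_iteration = tree)
    names(pred) <- paste0(".pred_", names(pred))
  }
  pred
}

# Stub model object and stub predictors, for illustration only
stub_object <- list(lvl = c("a", "b"))
stub_class <- function(object, new_data, num_iteration) "a"
stub_prob <- function(object, new_data, num_iteration) data.frame(a = 0.6, b = 0.4)
stub_raw <- function(object, new_data, num_iteration) matrix(c(0.2, -0.2), nrow = 1)

# Call the helper once per prediction type
call_with_type <- function(type) {
  predict_one_iteration(
    stub_object, new_data = NULL, tree = 2, type = type,
    pred_class = stub_class, pred_prob = stub_prob, pred_raw = stub_raw
  )
}

raw_pred <- call_with_type("raw")
prob_pred <- call_with_type("prob")
class_pred <- call_with_type("class")
```

With this shape, type = "raw" no longer falls through to the probability branch, which is the behavior the reproducible example above shows.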

Thanks for your time and consideration.
