-
Notifications
You must be signed in to change notification settings - Fork 7
Open
Description
Hi all -
I am tuning the catboost model (separate from the `mtry` discussion) and receive an error relating to `multi_predict()`.
################################################################
# cat cross-val
################################################################
# Reprex: tuning a bonsai/catboost boosted-tree workflow with tune_grid().
# Lines starting with `#>` are the console output captured by reprex.
library(catboost)
library(parsnip)
library(bonsai)
library(rsample)
library(recipes)
#> Loading required package: dplyr
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
#>
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#>
#> step
# filter() and select() used below are dplyr verbs. dplyr was only attached
# as a side effect of library(recipes) (see the output above), so load it
# explicitly. Because it is already attached, this call prints nothing.
library(dplyr)
library(stacks)
#> Registered S3 method overwritten by 'butcher':
#> method from
#> as.character.dev_topic generics
library(yardstick)
library(workflows)
library(tune)
# Data, resampling and baseline recipe, adapted from the stacks "basics"
# article: https://stacks.tidymodels.org/articles/basics.html
data("tree_frogs")

# Keep only rows with an observed `latency` and drop `clutch` and `hatched`.
tree_frogs <- tree_frogs |>
  filter(!is.na(latency)) |>
  select(-c(clutch, hatched))

# Reproducible train/test split ...
set.seed(1)
tree_frogs_split <- initial_split(tree_frogs)
tree_frogs_train <- training(tree_frogs_split)
tree_frogs_test <- testing(tree_frogs_split)

# ... and 5-fold cross-validation on the training portion.
set.seed(1)
folds <- vfold_cv(tree_frogs_train, v = 5)

# Baseline recipe: model `latency` from every other column.
tree_frogs_rec <- recipe(latency ~ ., data = tree_frogs_train)

# Metric set and stacks-compatible control objects for tuning.
# NOTE(review): ctrl_res is not referenced anywhere else in this reprex.
metric <- metric_set(rmse)
ctrl_grid <- control_stack_grid()
ctrl_res <- control_stack_resamples()
### ~~
# Model: replace the article's KNN spec with a bonsai catboost boosted tree,
# tuning the learning rate and the number of trees (mtry is NOT tuned here).
cat_spec <-
  boost_tree(
    mode = "regression",
    learn_rate = tune(),
    trees = tune()
  ) |>
  set_engine("catboost")

# Extend the baseline `tree_frogs_rec` recipe defined earlier (it is reused
# as-is rather than rebuilt): dummy-encode nominal predictors, drop
# zero-variance columns, mean-impute numeric predictors, then normalize them.
cat_rec <-
  tree_frogs_rec |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors()) |>
  step_impute_mean(all_numeric_predictors()) |>
  step_normalize(all_numeric_predictors())

# Bundle the model spec and preprocessing into one workflow for tune_grid().
cat_wflow <-
  workflow() |>
  add_model(cat_spec) |>
  add_recipe(cat_rec)
# Hand-built grid: every combination of 3 learn_rate values and 4 trees
# values (expand.grid -> 12 candidates).
tuning_grid <- expand.grid(
learn_rate = c(.99, .01, .004),
trees = c(3, 4, 100, 200)
)
# Tune learn_rate and trees over the grid on the 5-fold CV resamples.
#
# NOTE(review): the error captured below shows tune calling multi_predict(),
# for which no method exists for the catboost fit class
# `_catboost.Model/model_fit`. Presumably `trees` is treated as a sub-model
# parameter. Each learn_rate has several trees values, so tune would fit
# once per learn_rate and try to predict the other trees values from that
# single fit via multi_predict(). The final failure count (15 = 5 folds x 3
# learn_rate values) is consistent with one failed fit per learn_rate per
# fold. Likely workaround until such a method exists, to be verified: tune
# learn_rate only with trees fixed, or give each learn_rate exactly one
# trees value.
set.seed(2020)
cat_res <-
tune_grid(
cat_wflow,
resamples = folds,
metrics = metric,
grid = tuning_grid,
control = ctrl_grid
)
#> → A | error: No `multi_predict()` method exists for objects with classes
#> <_catboost.Model/model_fit>.
#> There were issues with some computations A: x1
#> There were issues with some computations A: x15
#>
#> Warning: All models failed. Run `show_notes(.Last.tune.result)` for more
#> information.
# Every candidate failed above, so show_best() has nothing to rank and errors.
# (The "No value of `metric`" warning is only because `metric` is not passed.)
tune::show_best(cat_res)
#> Warning in tune::show_best(cat_res): No value of `metric` was given; "rmse"
#> will be used.
#> Error in `estimate_tune_results()`:
#> ! All models failed. Run `show_notes(.Last.tune.result)` for more information.
# Created on 2025-08-04 with reprex v2.1.1
zecojls
Metadata
Metadata
Assignees
Labels
No labels