-
Notifications
You must be signed in to change notification settings - Fork 21
Open
Description
The problem
I'm a bit puzzled: I cannot reproduce a {tidyclust} hierarchical-clustering outcome using {stats} directly.
It seems that `predict()` and `augment()` give a different answer from base {stats} on the same data, even though there is no preprocessing at all and the RNG seed, distance function and linkage are all the same.
It's not an issue of label switching; the clusters are actually different.
- ✅ `extract_cluster_assignment()` matches the base approach.
- ❌ `predict()` and `augment()` agree with each other, but their answer differs from base.
- ❌ This also affects predictions from a fitted workflow.
Reproducible example
library(tidyclust)

# Shared settings. Both the base {stats} route and the {tidyclust} route
# read these, so any difference in results cannot come from the inputs.
n_clusters <- 3
linkage_method <- "average"
distance_method <- "euclidean"

# A single distance function, reused by both routes so the metric is
# exactly the same.
dist_common <- function(x) {
  dist(x, method = distance_method)
}

# tidyclust model specification. `dist_fun` is supplied through set_args()
# so the engine is handed `dist_common` as its distance function.
hclust_spec <- tidyclust::set_args(
  hier_clust(
    num_clusters = n_clusters,
    mode = "partition",
    engine = "stats",
    linkage_method = linkage_method
  ),
  dist_fun = dist_common
)
# Base {stats} route: build the distance object, run hclust(), then cut the
# tree into `n_clusters` groups. Labels become character so they can be
# compared directly with the tidyclust labels.
set.seed(123)
h_stats <- hclust(as.dist(dist_common(mtcars)), method = linkage_method)
cl_stats <- as.character(cutree(h_stats, k = n_clusters))
# Via {parsnip}-like interface
# Fit the tidyclust spec on mtcars, then predict() back onto that same
# training data. `prefix = ""` yields the labels "1", "2", "3"; they are
# converted to character so they compare directly with the cutree() labels.
# The set.seed() calls mirror the base route. NOTE(review): hclust() is
# presumably deterministic, so the seeds only serve to rule out RNG.
set.seed(123)
hclust_fit <- hclust_spec |> fit(~., data = mtcars)
set.seed(123)
cl_predict <- hclust_fit |>
predict(new_data = mtcars, prefix = "") |>
dplyr::pull(.pred_cluster) |>
as.character()
# predict() does NOT match base: the recorded output shows 2 of 32 rows
# (rows 6 and 29) receive a different cluster.
# NOTE(review): likely explanation, to confirm against the tidyclust source.
# predict() seems to assign each row to the cluster it is closest to under
# the linkage (for "average", the mean distance to that cluster's training
# members). That rule need not reproduce cutree()'s labels, even on the
# training data. Also confirm whether predict() uses the custom `dist_fun`.
testthat::expect_equal(cl_stats, cl_predict)
#> Error: `cl_stats` not equal to `cl_predict`.
#> 2/32 mismatches
#> x[6]: "2"
#> y[6]: "1"
#>
#> x[29]: "2"
#> y[29]: "3"
# Cross-tabulate base labels (rows) against predict() labels (columns).
# A pure label permutation would leave exactly one non-zero cell per row
# and column. Instead, base cluster 2 is spread over all three predicted
# clusters, so the partitions genuinely differ.
table(cl_stats, cl_predict)
#>         cl_predict
#> cl_stats  1  2  3
#>        1 16  0  0
#>        2  1 13  1
#>        3  0  0  1
# Via {tidyclust} tools
# Take the cluster assignments from the fitted object itself with
# extract_cluster_assignment() (no predict() step). The same `prefix = ""`
# and character conversion keep the labels comparable.
cl_tidyclust <- hclust_fit |>
tidyclust::extract_cluster_assignment(prefix = "") |>
dplyr::pull(.cluster) |>
as.character()
# extract_cluster_assignment() matches base (no error is recorded below)
testthat::expect_equal(cl_stats, cl_tidyclust)
# predict() does NOT match extract_cluster_assignment() either. The same
# rows (6 and 29) differ, so the two tidyclust accessors disagree with each
# other on the training data.
testthat::expect_equal(cl_predict, cl_tidyclust)
#> Error: `cl_predict` not equal to `cl_tidyclust`.
#> 2/32 mismatches
#> x[6]: "1"
#> y[6]: "2"
#>
#> x[29]: "3"
#> y[29]: "2"
# Via {workflows}
# The same spec wrapped in a workflow with a pass-through recipe (no
# steps), to check whether the predict() discrepancy also appears there.
wf_spec <- workflows::workflow() |>
  workflows::add_recipe(recipes::recipe(~., data = mtcars)) |>
  workflows::add_model(hclust_spec)
wf_fit <- fit(wf_spec, data = mtcars)
# class: cluster_fit
h_fit <- wf_fit |> workflows::extract_fit_parsnip()
cl_wf_predict <- predict(wf_fit, new_data = mtcars, prefix = "") |>
  dplyr::pull(.pred_cluster) |>
  as.character()
# matches predict(), does NOT match base
testthat::expect_equal(cl_wf_predict, cl_predict)
# compare data: identical
# The training data stored on the fit equals mtcars as a plain matrix
# without row names. This suggests the mismatch does not come from the data
# that predict() works from.
testthat::expect_equal(
  attr(h_fit$fit, "training_data"),
  mtcars |> tibble::remove_rownames() |> as.matrix()
)
# Created on 2025-06-13 with reprex v2.1.1
Metadata
Metadata
Assignees
Labels
No labels