
hier_clust(): predict() does not match base approach or extract_cluster_assignment() #208

@lgaborini

Description

The problem

I'm a bit puzzled, as I cannot reproduce the predict() clustering outcome using {stats} alone.
With the same data (no preprocessing at all), the same RNG seed, the same distance function and the same linkage, predict() and augment() give a different answer from stats::hclust() followed by cutree().

It's not an issue of label switching: the clusters actually contain different observations.

  • extract_cluster_assignment() matches the base approach.
  • predict() and augment() agree with each other, but differ from base.
  • ❌ This also affects predictions from a fitted workflow.
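To check the "not label switching" claim mechanically, a small base-R helper can compare two label vectors up to a renaming of clusters. `same_partition()` is a hypothetical helper written for this issue, not part of {tidyclust} or {stats}. Two labelings describe the same partition exactly when their contingency table has a single nonzero cell in every row and every column.

```r
# Hypothetical helper (not part of {tidyclust}): TRUE when two label vectors
# describe the same partition, i.e. they agree up to a renaming of labels.
same_partition <- function(a, b) {
  stopifnot(length(a) == length(b))
  tab <- table(a, b)
  # every cluster in `a` must map to exactly one cluster in `b`, and vice versa
  all(rowSums(tab > 0) == 1) && all(colSums(tab > 0) == 1)
}

same_partition(c(1, 1, 2, 3), c("b", "b", "a", "c"))  # relabeled only: TRUE
same_partition(c(1, 1, 2, 3), c("b", "a", "a", "c"))  # regrouped: FALSE
```

Applied to `cl_stats` and `cl_predict` from the reprex, it returns FALSE, because cluster 2 of the base result is split across three predicted clusters.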

Reproducible example

library(tidyclust)

n_clusters <- 3
linkage_method <- "average"
distance_method <- "euclidean"

# make sure the distance function is exactly the same

dist_common <- \(x) dist(x, method = distance_method)

hclust_spec <- hier_clust(
  num_clusters = n_clusters,
  mode = "partition",
  engine = "stats",
  linkage_method = linkage_method
) |>
  tidyclust::set_args(
    dist_fun = dist_common
  )

# Base

set.seed(123)

h_stats <- mtcars |>
  dist_common() |>
  as.dist() |>
  hclust(method = linkage_method)

cl_stats <- h_stats |>
  cutree(k = n_clusters) |>
  as.character()

# Via {parsnip}-like interface

set.seed(123)

hclust_fit <- hclust_spec |> fit(~., data = mtcars)

set.seed(123)

cl_predict <- hclust_fit |>
  predict(new_data = mtcars, prefix = "") |>
  dplyr::pull(.pred_cluster) |>
  as.character()

# predict() does NOT match base
testthat::expect_equal(cl_stats, cl_predict)
#> Error: `cl_stats` not equal to `cl_predict`.
#> 2/32 mismatches
#> x[6]: "2"
#> y[6]: "1"
#> 
#> x[29]: "2"
#> y[29]: "3"

table(cl_stats, cl_predict)
#>         cl_predict
#> cl_stats  1  2  3
#>        1 16  0  0
#>        2  1 13  1
#>        3  0  0  1

# Via {tidyclust} tools

cl_tidyclust <- hclust_fit |>
  tidyclust::extract_cluster_assignment(prefix = "") |>
  dplyr::pull(.cluster) |>
  as.character()

# extract_cluster_assignment() matches base
testthat::expect_equal(cl_stats, cl_tidyclust)
# predict() does NOT match extract_cluster_assignment() either
testthat::expect_equal(cl_predict, cl_tidyclust)
#> Error: `cl_predict` not equal to `cl_tidyclust`.
#> 2/32 mismatches
#> x[6]: "1"
#> y[6]: "2"
#> 
#> x[29]: "3"
#> y[29]: "2"

# Via {workflows}

wf_spec <- workflows::workflow() |>
  workflows::add_recipe(recipes::recipe(~., data = mtcars)) |>
  workflows::add_model(hclust_spec)

wf_fit <- fit(wf_spec, data = mtcars)

# class: cluster_fit
h_fit <- wf_fit |> workflows::extract_fit_parsnip()

cl_wf_predict <- predict(wf_fit, new_data = mtcars, prefix = "") |>
  dplyr::pull(.pred_cluster) |>
  as.character()

# matches predict(), does NOT match base
testthat::expect_equal(cl_wf_predict, cl_predict)

# the training data stored in the fit is identical to mtcars
testthat::expect_equal(
  attr(h_fit$fit, "training_data"),
  mtcars |> tibble::remove_rownames() |> as.matrix()
)

Created on 2025-06-13 with reprex v2.1.1
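One possible explanation, offered as an assumption rather than a confirmed reading of the {tidyclust} source: predict() may not reuse the cutree() labels at all. It could instead reassign each row of `new_data` to the cluster whose training members are closest on average, which is a natural way to score new points under average linkage. Agglomerative merging is greedy, so nothing guarantees that every training point ends up in the cluster it is closest to on average. Under this assumption, predicting on the training data itself can move a few rows. The base-R sketch below computes that reassignment so it can be compared with the `table(cl_stats, cl_predict)` output. All object names are illustrative.

```r
# Hypothesis sketch (assumption, not confirmed tidyclust behaviour):
# reassign every training row to the cutree() cluster whose members are
# closest to it on average, then compare with the original cutree() labels.
d <- as.matrix(dist(mtcars, method = "euclidean"))
h <- hclust(as.dist(d), method = "average")
cl_cut <- cutree(h, k = 3)

# 32 x 3 matrix: mean distance from each row to the members of cluster k
# (for the row's own cluster this includes its zero self-distance)
mean_dist <- sapply(
  sort(unique(cl_cut)),
  function(k) rowMeans(d[, cl_cut == k, drop = FALSE])
)

# index of the cluster with the smallest mean distance (ties -> first)
cl_nearest <- max.col(-mean_dist, ties.method = "first")

# off-diagonal counts are rows that move under nearest-cluster reassignment
table(cl_cut, cl_nearest)
```

If this table shows the same off-diagonal counts as `table(cl_stats, cl_predict)`, that would support the hypothesis. It would suggest that predict() is a nearest-cluster rule rather than a lookup of the fitted assignment.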
