From 2e45dcee6b63fddf0ecd2d992fd3e62f6b27e3a3 Mon Sep 17 00:00:00 2001
From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com>
Date: Thu, 20 Nov 2025 15:13:17 -0500
Subject: [PATCH] compare union to opaque
---
arrow-ord/src/ord.rs | 200 ++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 198 insertions(+), 2 deletions(-)
diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs
index b12a06732d42..26878cf6ab10 100644
--- a/arrow-ord/src/ord.rs
+++ b/arrow-ord/src/ord.rs
@@ -36,9 +36,10 @@ fn child_opts(opts: SortOptions) -> SortOptions {
}
}
-fn compare(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
+fn compare(l: &A, r: &B, opts: SortOptions, cmp: F) -> DynComparator
where
- A: Array + Clone,
+ A: Array + ?Sized,
+ B: Array + ?Sized,
F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
{
let l = l.logical_nulls().filter(|x| x.null_count() > 0);
@@ -368,6 +369,52 @@ fn compare_union(
Ok(f)
}
+fn compare_union_to_opaque(
+ union_array: &dyn Array,
+ opaque_array: &dyn Array,
+ opts: SortOptions,
+) -> Result {
+ let union_array = union_array.as_union();
+
+ let DataType::Union(union_fields, _) = union_array.data_type() else {
+ unreachable!()
+ };
+
+ let opaque_type_id = union_fields
+ .iter()
+ .find_map(|(i, f)| (f.data_type() == opaque_array.data_type()).then_some(i))
+ .ok_or_else(|| {
+ ArrowError::InvalidArgumentError(format!(
+ "cannot compare union with {} array: type not found in union fields",
+ opaque_array.data_type(),
+ ))
+ })?;
+
+ let c_opts = child_opts(opts);
+
+ let opaque_field_comparator = {
+ let union_child = union_array.child(opaque_type_id);
+ make_comparator(union_child.as_ref(), opaque_array, c_opts)?
+ };
+
+ let union_type_ids = union_array.type_ids().clone();
+ let union_offsets = union_array.offsets().cloned();
+
+ let f = compare(union_array, opaque_array, opts, move |i, j| {
+ let union_type_id = union_type_ids[i];
+
+ match union_type_id.cmp(&opaque_type_id) {
+ Ordering::Equal => {
+ let union_offset = union_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
+ opaque_field_comparator(union_offset, j)
+ }
+ other => other,
+ }
+ });
+
+ Ok(f)
+}
+
/// Returns a comparison function that compares two values at two different positions
/// between the two arrays.
///
@@ -485,6 +532,8 @@ pub fn make_comparator(
},
(Map(_, _), Map(_, _)) => compare_map(left, right, opts),
(Union(_, _), Union(_, _)) => compare_union(left, right, opts),
+ (Union(_, _), _) => compare_union_to_opaque(left, right, opts),
+ (_, Union(_, _)) => compare_union_to_opaque(right, left, opts),
(lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
true => format!("The data type type {lhs:?} has no natural order"),
false => "Can't compare arrays of different types".to_string(),
@@ -1501,4 +1550,151 @@ mod tests {
"Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
);
}
+
+ #[test]
+ fn test_union_to_opaque_int32() {
+ let int_array = Int32Array::from(vec![1, 2, 3]);
+ let str_array = StringArray::from(vec!["a", "b"]);
+ let type_ids = [0, 1, 0, 1, 0].into_iter().collect::>();
+ let offsets = [0, 0, 1, 1, 2].into_iter().collect::>();
+ let union_fields = [
+ (0, Arc::new(Field::new("ints", DataType::Int32, false))),
+ (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
+ ]
+ .into_iter()
+ .collect::();
+ let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
+
+ // union: [1, "a", 2, "b", 3], opaque: [2, 6, 7]
+ let union_array =
+ UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
+ let opaque_array = Int32Array::from(vec![2, 6, 7]);
+ let opts = SortOptions::default();
+
+ let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();
+
+ // 1 < 2
+ assert_eq!(cmp(0, 0), Ordering::Less);
+ // type_id 1 > 0
+ assert_eq!(cmp(1, 0), Ordering::Greater);
+ // 2 == 2
+ assert_eq!(cmp(2, 0), Ordering::Equal);
+ // 3 > 6
+ assert_eq!(cmp(4, 1), Ordering::Less);
+ }
+
+ #[test]
+ fn test_union_to_opaque_string() {
+ let str_array = StringArray::from(vec![Some("apple"), None, Some("pork")]);
+ let int_array = Int32Array::from(vec![None, Some(67), None]);
+ let type_ids = [1, 0, 1].into_iter().collect::>();
+ let union_fields = [
+ (0, Arc::new(Field::new("ints", DataType::Int32, false))),
+ (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
+ ]
+ .into_iter()
+ .collect::();
+ let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
+
+ // sparse union: ["apple", 67, "pork"], opaque: ["howdy", "john", "pork"]
+ let union_array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
+ let opaque_array = StringArray::from(vec!["howdy", "john", "pork"]);
+ let opts = SortOptions::default();
+
+ let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();
+
+ // apple < howdy
+ assert_eq!(cmp(0, 0), Ordering::Less);
+ // type_id < 1
+ assert_eq!(cmp(1, 0), Ordering::Less);
+ // "pork" == "pork"
+ assert_eq!(cmp(2, 2), Ordering::Equal);
+ }
+
+ #[test]
+ fn test_union_to_opaque_with_nulls() {
+ let int_array = Int32Array::from(vec![Some(1), None, Some(3)]);
+ let str_array = StringArray::from(vec![Some("a"), Some("b")]);
+ let type_ids = [0, 1, 0, 1, 0].into_iter().collect::>();
+ let offsets = [0, 0, 1, 1, 2].into_iter().collect::>();
+ let union_fields = [
+ (0, Arc::new(Field::new("ints", DataType::Int32, false))),
+ (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
+ ]
+ .into_iter()
+ .collect::();
+ let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
+
+ // union: [1, "a", null, "b", 3], opaque: [2, null, 1]
+ let union_array =
+ UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
+ let opaque_array = Int32Array::from(vec![Some(2), None, Some(1)]);
+ let opts = SortOptions {
+ descending: false,
+ nulls_first: true,
+ };
+ let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();
+
+ // 1 > null
+ assert_eq!(cmp(0, 1), Ordering::Greater);
+ // null < 2
+ assert_eq!(cmp(2, 0), Ordering::Less);
+ // null == null
+ assert_eq!(cmp(2, 1), Ordering::Equal);
+ }
+
+ #[test]
+ fn test_union_to_opaque_descending() {
+ let int_array = Int32Array::from(vec![1, 2, 3]);
+ let str_array = StringArray::from(vec!["a", "b"]);
+ let type_ids = [0, 1, 0].into_iter().collect::>();
+ let offsets = [0, 0, 1].into_iter().collect::>();
+ let union_fields = [
+ (0, Arc::new(Field::new("ints", DataType::Int32, false))),
+ (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
+ ]
+ .into_iter()
+ .collect::();
+ let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
+
+ // union: [1, "a", 2], opaque: [2, 1]
+ let union_array =
+ UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
+ let opaque_array = Int32Array::from(vec![2, 1]);
+ let opts = SortOptions {
+ descending: true,
+ nulls_first: false,
+ };
+ let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();
+ // 1 > 2 (descending)
+ assert_eq!(cmp(0, 0), Ordering::Greater);
+ // 2 == 2
+ assert_eq!(cmp(2, 0), Ordering::Equal);
+ // 1 == 1
+ assert_eq!(cmp(0, 1), Ordering::Equal);
+ }
+
+ #[test]
+ fn test_union_to_opaque_incompatible_type() {
+ let int_array = Int32Array::from(vec![1, 2]);
+ let str_array = StringArray::from(vec!["a", "b"]);
+ let type_ids = [0, 1].into_iter().collect::>();
+ let offsets = [0, 0].into_iter().collect::>();
+ let union_fields = [
+ (0, Arc::new(Field::new("ints", DataType::Int32, false))),
+ (1, Arc::new(Field::new("strings", DataType::Utf8, false))),
+ ]
+ .into_iter()
+ .collect::();
+ let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
+ let union_array =
+ UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
+ let opaque_array = Float64Array::from(vec![1.0, 2.0]);
+ let opts = SortOptions::default();
+ let Err(err) = make_comparator(&union_array, &opaque_array, opts) else {
+ panic!("expected err");
+ };
+
+ assert!(err.to_string().contains("cannot compare union with"));
+ }
}