From 2e45dcee6b63fddf0ecd2d992fd3e62f6b27e3a3 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Thu, 20 Nov 2025 15:13:17 -0500 Subject: [PATCH] compare union to opaque --- arrow-ord/src/ord.rs | 200 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 198 insertions(+), 2 deletions(-) diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index b12a06732d42..26878cf6ab10 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -36,9 +36,10 @@ fn child_opts(opts: SortOptions) -> SortOptions { } } -fn compare(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator +fn compare(l: &A, r: &B, opts: SortOptions, cmp: F) -> DynComparator where - A: Array + Clone, + A: Array + ?Sized, + B: Array + ?Sized, F: Fn(usize, usize) -> Ordering + Send + Sync + 'static, { let l = l.logical_nulls().filter(|x| x.null_count() > 0); @@ -368,6 +369,52 @@ fn compare_union( Ok(f) } +fn compare_union_to_opaque( + union_array: &dyn Array, + opaque_array: &dyn Array, + opts: SortOptions, +) -> Result { + let union_array = union_array.as_union(); + + let DataType::Union(union_fields, _) = union_array.data_type() else { + unreachable!() + }; + + let opaque_type_id = union_fields + .iter() + .find_map(|(i, f)| (f.data_type() == opaque_array.data_type()).then_some(i)) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "cannot compare union with {} array: type not found in union fields", + opaque_array.data_type(), + )) + })?; + + let c_opts = child_opts(opts); + + let opaque_field_comparator = { + let union_child = union_array.child(opaque_type_id); + make_comparator(union_child.as_ref(), opaque_array, c_opts)? + }; + + let union_type_ids = union_array.type_ids().clone(); + let union_offsets = union_array.offsets().cloned(); + + let f = compare(union_array, opaque_array, opts, move |i, j| { + let union_type_id = union_type_ids[i]; + + match union_type_id.cmp(&opaque_type_id) { + Ordering::Equal => { + let union_offset = union_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i); + opaque_field_comparator(union_offset, j) + } + other => other, + } + }); + + Ok(f) +} + /// Returns a comparison function that compares two values at two different positions /// between the two arrays. /// @@ -485,6 +532,8 @@ pub fn make_comparator( }, (Map(_, _), Map(_, _)) => compare_map(left, right, opts), (Union(_, _), Union(_, _)) => compare_union(left, right, opts), + (Union(_, _), _) => compare_union_to_opaque(left, right, opts), + (_, Union(_, _)) => compare_union_to_opaque(right, left, opts), (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs { true => format!("The data type type {lhs:?} has no natural order"), false => "Can't compare arrays of different types".to_string(), @@ -1501,4 +1550,151 @@ mod tests { "Cannot compare UnionArrays with different modes: left=Dense, right=Sparse" ); } + + #[test] + fn test_union_to_opaque_int32() { + let int_array = Int32Array::from(vec![1, 2, 3]); + let str_array = StringArray::from(vec!["a", "b"]); + let type_ids = [0, 1, 0, 1, 0].into_iter().collect::>(); + let offsets = [0, 0, 1, 1, 2].into_iter().collect::>(); + let union_fields = [ + (0, Arc::new(Field::new("ints", DataType::Int32, false))), + (1, Arc::new(Field::new("strings", DataType::Utf8, false))), + ] + .into_iter() + .collect::(); + let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)]; + + // union: [1, "a", 2, "b", 3], opaque: [2, 6, 7] + let union_array = + UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap(); + let opaque_array = Int32Array::from(vec![2, 6, 7]); + let opts = SortOptions::default(); + + let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap(); + + // 1 < 2 + assert_eq!(cmp(0, 0), Ordering::Less); + // type_id 1 > 0 + assert_eq!(cmp(1, 0), Ordering::Greater); + // 2 == 2 + assert_eq!(cmp(2, 0), Ordering::Equal); + // 3 > 6 + assert_eq!(cmp(4, 1), Ordering::Less); + } + + #[test] + fn test_union_to_opaque_string() { + let str_array = StringArray::from(vec![Some("apple"), None, Some("pork")]); + let int_array = Int32Array::from(vec![None, Some(67), None]); + let type_ids = [1, 0, 1].into_iter().collect::>(); + let union_fields = [ + (0, Arc::new(Field::new("ints", DataType::Int32, false))), + (1, Arc::new(Field::new("strings", DataType::Utf8, false))), + ] + .into_iter() + .collect::(); + let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)]; + + // sparse union: ["apple", 67, "pork"], opaque: ["howdy", "john", "pork"] + let union_array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); + let opaque_array = StringArray::from(vec!["howdy", "john", "pork"]); + let opts = SortOptions::default(); + + let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap(); + + // apple < howdy + assert_eq!(cmp(0, 0), Ordering::Less); + // type_id < 1 + assert_eq!(cmp(1, 0), Ordering::Less); + // "pork" == "pork" + assert_eq!(cmp(2, 2), Ordering::Equal); + } + + #[test] + fn test_union_to_opaque_with_nulls() { + let int_array = Int32Array::from(vec![Some(1), None, Some(3)]); + let str_array = StringArray::from(vec![Some("a"), Some("b")]); + let type_ids = [0, 1, 0, 1, 0].into_iter().collect::>(); + let offsets = [0, 0, 1, 1, 2].into_iter().collect::>(); + let union_fields = [ + (0, Arc::new(Field::new("ints", DataType::Int32, false))), + (1, Arc::new(Field::new("strings", DataType::Utf8, false))), + ] + .into_iter() + .collect::(); + let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)]; + + // union: [1, "a", null, "b", 3], opaque: [2, null, 1] + let union_array = + UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap(); + let opaque_array = Int32Array::from(vec![Some(2), None, Some(1)]); + let opts = SortOptions { + descending: false, + nulls_first: true, + }; + let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap(); + + // 1 > null + assert_eq!(cmp(0, 1), Ordering::Greater); + // null < 2 + assert_eq!(cmp(2, 0), Ordering::Less); + // null == null + assert_eq!(cmp(2, 1), Ordering::Equal); + } + + #[test] + fn test_union_to_opaque_descending() { + let int_array = Int32Array::from(vec![1, 2, 3]); + let str_array = StringArray::from(vec!["a", "b"]); + let type_ids = [0, 1, 0].into_iter().collect::>(); + let offsets = [0, 0, 1].into_iter().collect::>(); + let union_fields = [ + (0, Arc::new(Field::new("ints", DataType::Int32, false))), + (1, Arc::new(Field::new("strings", DataType::Utf8, false))), + ] + .into_iter() + .collect::(); + let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)]; + + // union: [1, "a", 2], opaque: [2, 1] + let union_array = + UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap(); + let opaque_array = Int32Array::from(vec![2, 1]); + let opts = SortOptions { + descending: true, + nulls_first: false, + }; + let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap(); + // 1 > 2 (descending) + assert_eq!(cmp(0, 0), Ordering::Greater); + // 2 == 2 + assert_eq!(cmp(2, 0), Ordering::Equal); + // 1 == 1 + assert_eq!(cmp(0, 1), Ordering::Equal); + } + + #[test] + fn test_union_to_opaque_incompatible_type() { + let int_array = Int32Array::from(vec![1, 2]); + let str_array = StringArray::from(vec!["a", "b"]); + let type_ids = [0, 1].into_iter().collect::>(); + let offsets = [0, 0].into_iter().collect::>(); + let union_fields = [ + (0, Arc::new(Field::new("ints", DataType::Int32, false))), + (1, Arc::new(Field::new("strings", DataType::Utf8, false))), + ] + .into_iter() + .collect::(); + let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)]; + let union_array = + UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap(); + let opaque_array = Float64Array::from(vec![1.0, 2.0]); + let opts = SortOptions::default(); + let Err(err) = make_comparator(&union_array, &opaque_array, opts) else { + panic!("expected err"); + }; + + assert!(err.to_string().contains("cannot compare union with")); + } }