@@ -36,9 +36,10 @@ fn child_opts(opts: SortOptions) -> SortOptions {
3636 }
3737}
3838
39- fn compare < A , F > ( l : & A , r : & A , opts : SortOptions , cmp : F ) -> DynComparator
39+ fn compare < A , B , F > ( l : & A , r : & B , opts : SortOptions , cmp : F ) -> DynComparator
4040where
41- A : Array + Clone ,
41+ A : Array + ?Sized ,
42+ B : Array + ?Sized ,
4243 F : Fn ( usize , usize ) -> Ordering + Send + Sync + ' static ,
4344{
4445 let l = l. logical_nulls ( ) . filter ( |x| x. null_count ( ) > 0 ) ;
@@ -368,6 +369,52 @@ fn compare_union(
368369 Ok ( f)
369370}
370371
372+ fn compare_union_to_opaque ( // Builds a comparator between a union array and a non-union ("opaque") array whose data type matches one of the union's fields.
373+ union_array : & dyn Array , // downcast with as_union() below, so this must be a union array
374+ opaque_array : & dyn Array , // its data type is looked up among the union's fields
375+ opts : SortOptions ,
376+ ) -> Result < DynComparator , ArrowError > { // Err(InvalidArgumentError) when no union field has the opaque array's data type
377+ let union_array = union_array. as_union ( ) ;
378+
379+ let DataType :: Union ( union_fields, _) = union_array. data_type ( ) else {
380+ unreachable ! ( ) // the array was just obtained through as_union(), so a non-Union data type is not expected here
381+ } ;
382+
383+ let opaque_type_id = union_fields // type id of the FIRST union field whose data type equals the opaque array's
384+ . iter ( )
385+ . find_map ( |( i, f) | ( f. data_type ( ) == opaque_array. data_type ( ) ) . then_some ( i) )
386+ . ok_or_else ( || {
387+ ArrowError :: InvalidArgumentError ( format ! (
388+ "cannot compare union with {} array: type not found in union fields" ,
389+ opaque_array. data_type( ) ,
390+ ) )
391+ } ) ?;
392+
393+ let c_opts = child_opts ( opts) ; // options handed to the nested child comparator built below
394+
395+ let opaque_field_comparator = { // compares values of the matching union child against the opaque array
396+ let union_child = union_array. child ( opaque_type_id) ;
397+ make_comparator ( union_child. as_ref ( ) , opaque_array, c_opts) ?
398+ } ;
399+
400+ let union_type_ids = union_array. type_ids ( ) . clone ( ) ; // cloned so the `move` closure below owns its data (`compare` requires a 'static closure)
401+ let union_offsets = union_array. offsets ( ) . cloned ( ) ; // optional: when absent, the row index is used as the child index
402+
403+ let f = compare ( union_array, opaque_array, opts, move |i, j| { // i indexes the union array, j indexes the opaque array
404+ let union_type_id = union_type_ids[ i] ;
405+
406+ match union_type_id. cmp ( & opaque_type_id) { // rows holding a different variant are ordered by type id alone
407+ Ordering :: Equal => {
408+ let union_offset = union_offsets. as_ref ( ) . map ( |o| o[ i] as usize ) . unwrap_or ( i) ; // map the union row to its position inside the child array
409+ opaque_field_comparator ( union_offset, j)
410+ }
411+ other => other,
412+ }
413+ } ) ;
414+
415+ Ok ( f) // NOTE(review): the first index of `f` always addresses the union array; the `( _, Union ( _, _) )` dispatch arm passes (right, left), so confirm that arm swaps indices and reverses the ordering
416+ }
417+
371418/// Returns a comparison function that compares two values at two different positions
372419/// between the two arrays.
373420///
@@ -485,6 +532,8 @@ pub fn make_comparator(
485532 } ,
486533 ( Map ( _, _) , Map ( _, _) ) => compare_map( left, right, opts) ,
487534 ( Union ( _, _) , Union ( _, _) ) => compare_union( left, right, opts) ,
535+ ( Union ( _, _) , _) => compare_union_to_opaque( left, right, opts) ,
536+ ( _, Union ( _, _) ) => compare_union_to_opaque( right, left, opts) ,
488537 ( lhs, rhs) => Err ( ArrowError :: InvalidArgumentError ( match lhs == rhs {
489538 true => format!( "The data type type {lhs:?} has no natural order" ) ,
490539 false => "Can't compare arrays of different types" . to_string( ) ,
@@ -1501,4 +1550,151 @@ mod tests {
15011550 "Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
15021551 ) ;
15031552 }
1553+
1554+ #[ test]
1555+ fn test_union_to_opaque_int32 ( ) { // union built with explicit offsets, compared against a plain Int32Array using default sort options
1556+ let int_array = Int32Array :: from ( vec ! [ 1 , 2 , 3 ] ) ;
1557+ let str_array = StringArray :: from ( vec ! [ "a" , "b" ] ) ;
1558+ let type_ids = [ 0 , 1 , 0 , 1 , 0 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
1559+ let offsets = [ 0 , 0 , 1 , 1 , 2 ] . into_iter ( ) . collect :: < ScalarBuffer < i32 > > ( ) ;
1560+ let union_fields = [
1561+ ( 0 , Arc :: new ( Field :: new ( "ints" , DataType :: Int32 , false ) ) ) ,
1562+ ( 1 , Arc :: new ( Field :: new ( "strings" , DataType :: Utf8 , false ) ) ) ,
1563+ ]
1564+ . into_iter ( )
1565+ . collect :: < UnionFields > ( ) ;
1566+ let children = vec ! [ Arc :: new( int_array) as ArrayRef , Arc :: new( str_array) ] ;
1567+
1568+ // union: [1, "a", 2, "b", 3], opaque: [2, 6, 7]
1569+ let union_array =
1570+ UnionArray :: try_new ( union_fields, type_ids, Some ( offsets) , children) . unwrap ( ) ;
1571+ let opaque_array = Int32Array :: from ( vec ! [ 2 , 6 , 7 ] ) ;
1572+ let opts = SortOptions :: default ( ) ;
1573+
1574+ let cmp = make_comparator ( & union_array, & opaque_array, opts) . unwrap ( ) ;
1575+
1576+ // union[0] = 1 < opaque[0] = 2
1577+ assert_eq ! ( cmp( 0 , 0 ) , Ordering :: Less ) ;
1578+ // union[1] is a string: type_id 1 > Int32's type_id 0
1579+ assert_eq ! ( cmp( 1 , 0 ) , Ordering :: Greater ) ;
1580+ // union[2] = 2 == opaque[0] = 2
1581+ assert_eq ! ( cmp( 2 , 0 ) , Ordering :: Equal ) ;
1582+ // union[4] = 3 < opaque[1] = 6
1583+ assert_eq ! ( cmp( 4 , 1 ) , Ordering :: Less ) ;
1584+ }
1585+
1586+ #[ test]
1587+ fn test_union_to_opaque_string ( ) { // union built without offsets, compared against a plain StringArray
1588+ let str_array = StringArray :: from ( vec ! [ Some ( "apple" ) , None , Some ( "pork" ) ] ) ;
1589+ let int_array = Int32Array :: from ( vec ! [ None , Some ( 67 ) , None ] ) ;
1590+ let type_ids = [ 1 , 0 , 1 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
1591+ let union_fields = [
1592+ ( 0 , Arc :: new ( Field :: new ( "ints" , DataType :: Int32 , false ) ) ) ,
1593+ ( 1 , Arc :: new ( Field :: new ( "strings" , DataType :: Utf8 , false ) ) ) ,
1594+ ]
1595+ . into_iter ( )
1596+ . collect :: < UnionFields > ( ) ;
1597+ let children = vec ! [ Arc :: new( int_array) as ArrayRef , Arc :: new( str_array) ] ;
1598+
1599+ // sparse union: ["apple", 67, "pork"], opaque: ["howdy", "john", "pork"]
1600+ let union_array = UnionArray :: try_new ( union_fields, type_ids, None , children) . unwrap ( ) ;
1601+ let opaque_array = StringArray :: from ( vec ! [ "howdy" , "john" , "pork" ] ) ;
1602+ let opts = SortOptions :: default ( ) ;
1603+
1604+ let cmp = make_comparator ( & union_array, & opaque_array, opts) . unwrap ( ) ;
1605+
1606+ // "apple" < "howdy"
1607+ assert_eq ! ( cmp( 0 , 0 ) , Ordering :: Less ) ;
1608+ // union[1] is an int: type_id 0 < Utf8's type_id 1
1609+ assert_eq ! ( cmp( 1 , 0 ) , Ordering :: Less ) ;
1610+ // "pork" == "pork"
1611+ assert_eq ! ( cmp( 2 , 2 ) , Ordering :: Equal ) ;
1612+ }
1613+
1614+ #[ test]
1615+ fn test_union_to_opaque_with_nulls ( ) { // null handling on both sides with nulls_first = true, ascending
1616+ let int_array = Int32Array :: from ( vec ! [ Some ( 1 ) , None , Some ( 3 ) ] ) ;
1617+ let str_array = StringArray :: from ( vec ! [ Some ( "a" ) , Some ( "b" ) ] ) ;
1618+ let type_ids = [ 0 , 1 , 0 , 1 , 0 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
1619+ let offsets = [ 0 , 0 , 1 , 1 , 2 ] . into_iter ( ) . collect :: < ScalarBuffer < i32 > > ( ) ;
1620+ let union_fields = [
1621+ ( 0 , Arc :: new ( Field :: new ( "ints" , DataType :: Int32 , false ) ) ) ,
1622+ ( 1 , Arc :: new ( Field :: new ( "strings" , DataType :: Utf8 , false ) ) ) ,
1623+ ]
1624+ . into_iter ( )
1625+ . collect :: < UnionFields > ( ) ;
1626+ let children = vec ! [ Arc :: new( int_array) as ArrayRef , Arc :: new( str_array) ] ;
1627+
1628+ // union: [1, "a", null, "b", 3], opaque: [2, null, 1]
1629+ let union_array =
1630+ UnionArray :: try_new ( union_fields, type_ids, Some ( offsets) , children) . unwrap ( ) ;
1631+ let opaque_array = Int32Array :: from ( vec ! [ Some ( 2 ) , None , Some ( 1 ) ] ) ;
1632+ let opts = SortOptions {
1633+ descending : false ,
1634+ nulls_first : true ,
1635+ } ;
1636+ let cmp = make_comparator ( & union_array, & opaque_array, opts) . unwrap ( ) ;
1637+
1638+ // union[0] = 1 > opaque[1] = null (nulls sort first)
1639+ assert_eq ! ( cmp( 0 , 1 ) , Ordering :: Greater ) ;
1640+ // union[2] = null < opaque[0] = 2
1641+ assert_eq ! ( cmp( 2 , 0 ) , Ordering :: Less ) ;
1642+ // union[2] = null == opaque[1] = null
1643+ assert_eq ! ( cmp( 2 , 1 ) , Ordering :: Equal ) ;
1644+ }
1645+
1646+ #[ test]
1647+ fn test_union_to_opaque_descending ( ) { // checks that descending = true flips the ordering of same-type values
1648+ let int_array = Int32Array :: from ( vec ! [ 1 , 2 , 3 ] ) ; // the trailing 3 is not referenced by the offsets below
1649+ let str_array = StringArray :: from ( vec ! [ "a" , "b" ] ) ;
1650+ let type_ids = [ 0 , 1 , 0 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
1651+ let offsets = [ 0 , 0 , 1 ] . into_iter ( ) . collect :: < ScalarBuffer < i32 > > ( ) ;
1652+ let union_fields = [
1653+ ( 0 , Arc :: new ( Field :: new ( "ints" , DataType :: Int32 , false ) ) ) ,
1654+ ( 1 , Arc :: new ( Field :: new ( "strings" , DataType :: Utf8 , false ) ) ) ,
1655+ ]
1656+ . into_iter ( )
1657+ . collect :: < UnionFields > ( ) ;
1658+ let children = vec ! [ Arc :: new( int_array) as ArrayRef , Arc :: new( str_array) ] ;
1659+
1660+ // union: [1, "a", 2], opaque: [2, 1]
1661+ let union_array =
1662+ UnionArray :: try_new ( union_fields, type_ids, Some ( offsets) , children) . unwrap ( ) ;
1663+ let opaque_array = Int32Array :: from ( vec ! [ 2 , 1 ] ) ;
1664+ let opts = SortOptions {
1665+ descending : true ,
1666+ nulls_first : false ,
1667+ } ;
1668+ let cmp = make_comparator ( & union_array, & opaque_array, opts) . unwrap ( ) ;
1669+ // union[0] = 1 vs opaque[0] = 2: Less when ascending, so Greater when descending
1670+ assert_eq ! ( cmp( 0 , 0 ) , Ordering :: Greater ) ;
1671+ // union[2] = 2 == opaque[0] = 2
1672+ assert_eq ! ( cmp( 2 , 0 ) , Ordering :: Equal ) ;
1673+ // union[0] = 1 == opaque[1] = 1
1674+ assert_eq ! ( cmp( 0 , 1 ) , Ordering :: Equal ) ;
1675+ }
1676+
1677+ #[ test]
1678+ fn test_union_to_opaque_incompatible_type ( ) { // comparing against an array whose type is not a union field must return an error
1679+ let int_array = Int32Array :: from ( vec ! [ 1 , 2 ] ) ;
1680+ let str_array = StringArray :: from ( vec ! [ "a" , "b" ] ) ;
1681+ let type_ids = [ 0 , 1 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
1682+ let offsets = [ 0 , 0 ] . into_iter ( ) . collect :: < ScalarBuffer < i32 > > ( ) ;
1683+ let union_fields = [
1684+ ( 0 , Arc :: new ( Field :: new ( "ints" , DataType :: Int32 , false ) ) ) ,
1685+ ( 1 , Arc :: new ( Field :: new ( "strings" , DataType :: Utf8 , false ) ) ) ,
1686+ ]
1687+ . into_iter ( )
1688+ . collect :: < UnionFields > ( ) ;
1689+ let children = vec ! [ Arc :: new( int_array) as ArrayRef , Arc :: new( str_array) ] ;
1690+ let union_array =
1691+ UnionArray :: try_new ( union_fields, type_ids, Some ( offsets) , children) . unwrap ( ) ;
1692+ let opaque_array = Float64Array :: from ( vec ! [ 1.0 , 2.0 ] ) ; // Float64 is neither of the union's field types (Int32, Utf8)
1693+ let opts = SortOptions :: default ( ) ;
1694+ let Err ( err) = make_comparator ( & union_array, & opaque_array, opts) else {
1695+ panic ! ( "expected err" ) ;
1696+ } ;
1697+
1698+ assert ! ( err. to_string( ) . contains( "cannot compare union with" ) ) ; // only the message prefix is checked, not the error variant
1699+ }
15041700}
0 commit comments