Skip to content

Commit 9b30e9a

Browse files
compare union to opaque
1 parent d92d584 commit 9b30e9a

File tree

1 file changed

+198
-2
lines changed

1 file changed

+198
-2
lines changed

arrow-ord/src/ord.rs

Lines changed: 198 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ fn child_opts(opts: SortOptions) -> SortOptions {
3636
}
3737
}
3838

39-
fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
39+
fn compare<A, B, F>(l: &A, r: &B, opts: SortOptions, cmp: F) -> DynComparator
4040
where
41-
A: Array + Clone,
41+
A: Array + ?Sized,
42+
B: Array + ?Sized,
4243
F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
4344
{
4445
let l = l.logical_nulls().filter(|x| x.null_count() > 0);
@@ -368,6 +369,52 @@ fn compare_union(
368369
Ok(f)
369370
}
370371

372+
fn compare_union_to_opaque(
373+
union_array: &dyn Array,
374+
opaque_array: &dyn Array,
375+
opts: SortOptions,
376+
) -> Result<DynComparator, ArrowError> {
377+
let union_array = union_array.as_union();
378+
379+
let DataType::Union(union_fields, _) = union_array.data_type() else {
380+
unreachable!()
381+
};
382+
383+
let opaque_type_id = union_fields
384+
.iter()
385+
.find_map(|(i, f)| (f.data_type() == opaque_array.data_type()).then_some(i))
386+
.ok_or_else(|| {
387+
ArrowError::InvalidArgumentError(format!(
388+
"cannot compare union with {} array: type not found in union fields",
389+
opaque_array.data_type(),
390+
))
391+
})?;
392+
393+
let c_opts = child_opts(opts);
394+
395+
let opaque_field_comparator = {
396+
let union_child = union_array.child(opaque_type_id);
397+
make_comparator(union_child.as_ref(), opaque_array, c_opts)?
398+
};
399+
400+
let union_type_ids = union_array.type_ids().clone();
401+
let union_offsets = union_array.offsets().cloned();
402+
403+
let f = compare(union_array, opaque_array, opts, move |i, j| {
404+
let union_type_id = union_type_ids[i];
405+
406+
match union_type_id.cmp(&opaque_type_id) {
407+
Ordering::Equal => {
408+
let union_offset = union_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
409+
opaque_field_comparator(union_offset, j)
410+
}
411+
other => other,
412+
}
413+
});
414+
415+
Ok(f)
416+
}
417+
371418
/// Returns a comparison function that compares two values at two different positions
372419
/// between the two arrays.
373420
///
@@ -485,6 +532,8 @@ pub fn make_comparator(
485532
},
486533
(Map(_, _), Map(_, _)) => compare_map(left, right, opts),
487534
(Union(_, _), Union(_, _)) => compare_union(left, right, opts),
535+
(Union(_, _), _) => compare_union_to_opaque(left, right, opts),
536+
(_, Union(_, _)) => compare_union_to_opaque(right, left, opts),
488537
(lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
489538
true => format!("The data type type {lhs:?} has no natural order"),
490539
false => "Can't compare arrays of different types".to_string(),
@@ -1501,4 +1550,151 @@ mod tests {
15011550
"Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
15021551
);
15031552
}
1553+
1554+
#[test]
1555+
fn test_union_to_opaque_int32() {
1556+
let int_array = Int32Array::from(vec![1, 2, 3]);
1557+
let str_array = StringArray::from(vec!["a", "b"]);
1558+
let type_ids = [0, 1, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
1559+
let offsets = [0, 0, 1, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();
1560+
let union_fields = [
1561+
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
1562+
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
1563+
]
1564+
.into_iter()
1565+
.collect::<UnionFields>();
1566+
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
1567+
1568+
// union: [1, "a", 2, "b", 3], opaque: [2, 6, 7]
1569+
let union_array =
1570+
UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
1571+
let opaque_array = Int32Array::from(vec![2, 6, 7]);
1572+
let opts = SortOptions::default();
1573+
1574+
let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();
1575+
1576+
// 1 < 2
1577+
assert_eq!(cmp(0, 0), Ordering::Less);
1578+
// type_id 1 > 0
1579+
assert_eq!(cmp(1, 0), Ordering::Greater);
1580+
// 2 == 2
1581+
assert_eq!(cmp(2, 0), Ordering::Equal);
1582+
// 3 < 6
1583+
assert_eq!(cmp(4, 1), Ordering::Less);
1584+
}
1585+
1586+
#[test]
1587+
fn test_union_to_opaque_string() {
1588+
let str_array = StringArray::from(vec![Some("apple"), None, Some("pork")]);
1589+
let int_array = Int32Array::from(vec![None, Some(67), None]);
1590+
let type_ids = [1, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1591+
let union_fields = [
1592+
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
1593+
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
1594+
]
1595+
.into_iter()
1596+
.collect::<UnionFields>();
1597+
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
1598+
1599+
// sparse union: ["apple", 67, "pork"], opaque: ["howdy", "john", "pork"]
1600+
let union_array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
1601+
let opaque_array = StringArray::from(vec!["howdy", "john", "pork"]);
1602+
let opts = SortOptions::default();
1603+
1604+
let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();
1605+
1606+
// apple < howdy
1607+
assert_eq!(cmp(0, 0), Ordering::Less);
1608+
// type_id 0 < 1
1609+
assert_eq!(cmp(1, 0), Ordering::Less);
1610+
// "pork" == "pork"
1611+
assert_eq!(cmp(2, 2), Ordering::Equal);
1612+
}
1613+
1614+
#[test]
1615+
fn test_union_to_opaque_with_nulls() {
1616+
let int_array = Int32Array::from(vec![Some(1), None, Some(3)]);
1617+
let str_array = StringArray::from(vec![Some("a"), Some("b")]);
1618+
let type_ids = [0, 1, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
1619+
let offsets = [0, 0, 1, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();
1620+
let union_fields = [
1621+
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
1622+
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
1623+
]
1624+
.into_iter()
1625+
.collect::<UnionFields>();
1626+
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
1627+
1628+
// union: [1, "a", null, "b", 3], opaque: [2, null, 1]
1629+
let union_array =
1630+
UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
1631+
let opaque_array = Int32Array::from(vec![Some(2), None, Some(1)]);
1632+
let opts = SortOptions {
1633+
descending: false,
1634+
nulls_first: true,
1635+
};
1636+
let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();
1637+
1638+
// 1 > null
1639+
assert_eq!(cmp(0, 1), Ordering::Greater);
1640+
// null < 2
1641+
assert_eq!(cmp(2, 0), Ordering::Less);
1642+
// null == null
1643+
assert_eq!(cmp(2, 1), Ordering::Equal);
1644+
}
1645+
1646+
#[test]
1647+
fn test_union_to_opaque_descending() {
1648+
let int_array = Int32Array::from(vec![1, 2, 3]);
1649+
let str_array = StringArray::from(vec!["a", "b"]);
1650+
let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
1651+
let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();
1652+
let union_fields = [
1653+
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
1654+
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
1655+
]
1656+
.into_iter()
1657+
.collect::<UnionFields>();
1658+
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
1659+
1660+
// union: [1, "a", 2], opaque: [2, 1]
1661+
let union_array =
1662+
UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
1663+
let opaque_array = Int32Array::from(vec![2, 1]);
1664+
let opts = SortOptions {
1665+
descending: true,
1666+
nulls_first: false,
1667+
};
1668+
let cmp = make_comparator(&union_array, &opaque_array, opts).unwrap();
1669+
// 1 > 2 (descending)
1670+
assert_eq!(cmp(0, 0), Ordering::Greater);
1671+
// 2 == 2
1672+
assert_eq!(cmp(2, 0), Ordering::Equal);
1673+
// 1 == 1
1674+
assert_eq!(cmp(0, 1), Ordering::Equal);
1675+
}
1676+
1677+
#[test]
1678+
fn test_union_to_opaque_incompatible_type() {
1679+
let int_array = Int32Array::from(vec![1, 2]);
1680+
let str_array = StringArray::from(vec!["a", "b"]);
1681+
let type_ids = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1682+
let offsets = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
1683+
let union_fields = [
1684+
(0, Arc::new(Field::new("ints", DataType::Int32, false))),
1685+
(1, Arc::new(Field::new("strings", DataType::Utf8, false))),
1686+
]
1687+
.into_iter()
1688+
.collect::<UnionFields>();
1689+
let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
1690+
let union_array =
1691+
UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
1692+
let opaque_array = Float64Array::from(vec![1.0, 2.0]);
1693+
let opts = SortOptions::default();
1694+
let Err(err) = make_comparator(&union_array, &opaque_array, opts) else {
1695+
panic!("expected err");
1696+
};
1697+
1698+
assert!(err.to_string().contains("cannot compare union with"));
1699+
}
15041700
}

0 commit comments

Comments
 (0)