Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 218 additions & 4 deletions datafusion/functions-nested/src/array_has.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ impl<'a> ArrayWrapper<'a> {
}
}

fn array_has_dispatch_for_array<'a>(
haystack: ArrayWrapper<'a>,
fn array_has_dispatch_for_array(
haystack: ArrayWrapper,
needle: &ArrayRef,
) -> Result<ArrayRef> {
let mut boolean_builder = BooleanArray::builder(haystack.len());
Expand All @@ -332,7 +332,17 @@ fn array_has_dispatch_for_array<'a>(
let is_nested = arr.data_type().is_nested();
let needle_row = Scalar::new(needle.slice(i, 1));
let eq_array = compare_with_eq(&arr, &needle_row, is_nested)?;
boolean_builder.append_value(eq_array.true_count() > 0);

let value_to_append = if eq_array.true_count() > 0 {
Some(true)
} else if eq_array.null_count() > 0 {
// If there are nulls in the comparison result and no true matches,
// the result is null
None
} else {
Some(false)
};
boolean_builder.append_option(value_to_append);
}

Ok(Arc::new(boolean_builder.finish()))
Expand Down Expand Up @@ -378,7 +388,14 @@ fn array_has_dispatch_for_scalar(
final_contained[i] = Some(false); // empty array -> false
} else {
let sliced_array = eq_array.slice(start, length);
final_contained[i] = Some(sliced_array.true_count() > 0);
let result_value = if sliced_array.true_count() > 0 {
Some(true)
} else if sliced_array.null_count() > 0 {
None
} else {
Some(false)
};
final_contained[i] = result_value;
}
}

Expand Down Expand Up @@ -859,4 +876,201 @@ mod tests {

Ok(())
}

/// Checks the three-valued (true / false / null) result of `array_has` when
/// both the haystack and the needle are passed as arrays (one needle per row).
///
/// Each test case is `(haystack_row, needle, expected)`, where a haystack row
/// of `None` denotes a null list and `Some(vec![])` denotes an empty list.
#[test]
fn test_array_has_with_nulls() -> Result<(), DataFusionError> {
    use arrow::array::{Int32Array, ListArray};
    use arrow::buffer::OffsetBuffer;
    use std::sync::Arc;

    // Test data matching the Scala test cases
    let test_cases = [
        // Row(Seq(1, 2, 3), 2) - element found, should return true
        (Some(vec![Some(1), Some(2), Some(3)]), Some(2), Some(true)),
        // Row(Seq(1, null, 3), 2) - should return null
        (Some(vec![Some(1), None, Some(3)]), Some(2), None),
        // Row(Seq(1, null, 3), null) - should return null
        (Some(vec![Some(1), None, Some(3)]), None, None),
        // Row(Seq(), 1) - empty array, should return false
        (Some(vec![]), Some(1), Some(false)),
        // Row(null, 1) - null array, should return null
        (None, Some(1), None),
        // Row(Seq(null, null), null) - should return null
        (Some(vec![None, None]), None, None),
    ];

    let haystack_field = Arc::new(Field::new_list(
        "haystack",
        Field::new("item", DataType::Int32, true),
        true,
    ));
    let needle_field = Arc::new(Field::new("needle", DataType::Int32, true));
    let return_field = Arc::new(Field::new("return", DataType::Boolean, true));

    // Build the haystack array. Row nullness comes straight from the test
    // data (`None` row), so cases can be reordered or added without touching
    // this loop.
    let mut offsets = vec![0i32];
    let mut values: Vec<Option<i32>> = Vec::new();
    let mut array_validity = Vec::with_capacity(test_cases.len());

    for (row, _, _) in &test_cases {
        if let Some(elements) = row {
            values.extend(elements.iter().copied());
        }
        array_validity.push(row.is_some());
        // A null row contributes no values, so its offset repeats the previous one.
        offsets.push(
            i32::try_from(values.len()).expect("test data fits in i32 list offsets"),
        );
    }

    let values_array = Arc::new(Int32Array::from(values)) as ArrayRef;
    let haystack = ListArray::new(
        Field::new("item", DataType::Int32, true).into(),
        OffsetBuffer::new(offsets.into()),
        values_array,
        Some(array_validity.into()),
    );

    // Build the needle array
    let needle_values: Vec<Option<i32>> =
        test_cases.iter().map(|(_, needle, _)| *needle).collect();
    let needle = Arc::new(Int32Array::from(needle_values)) as ArrayRef;

    // Execute array_has
    let haystack_columnar = ColumnarValue::Array(Arc::new(haystack));
    let needle_columnar = ColumnarValue::Array(needle);

    let result = ArrayHas::new().invoke_with_args(ScalarFunctionArgs {
        args: vec![haystack_columnar, needle_columnar],
        arg_fields: vec![haystack_field, needle_field],
        number_rows: test_cases.len(),
        return_field,
        config_options: Arc::new(ConfigOptions::default()),
    })?;

    let output = result.into_array(test_cases.len())?;
    let output = output.as_boolean();

    // Verify results
    assert_eq!(output.len(), test_cases.len());
    for (i, (_, _, expected)) in test_cases.iter().enumerate() {
        match expected {
            Some(expected_val) => {
                // Check nullness first: `value(i)` is only meaningful for non-null slots.
                assert!(!output.is_null(i), "Expected non-null at index {i}");
                assert_eq!(
                    output.value(i),
                    *expected_val,
                    "Mismatch at index {}: expected {:?}, got {:?}",
                    i,
                    expected_val,
                    output.value(i)
                );
            }
            None => {
                assert!(output.is_null(i), "Expected null at index {i}");
            }
        }
    }

    Ok(())
}

/// Exercises `array_has` with an array haystack and a *scalar* needle
/// (the constant `2`), checking the true / false / null outcome per row.
///
/// Each test case is `(haystack_row, expected)`; the row at index 4 is
/// turned into a null list when the haystack array is assembled.
#[test]
fn test_array_has_scalar_needle_with_nulls() -> Result<(), DataFusionError> {
    use arrow::array::{Int32Array, ListArray};
    use arrow::buffer::OffsetBuffer;
    use std::sync::Arc;

    // Test cases with scalar needle (not array needle)
    // All tests search for the scalar value 2
    let test_cases = [
        // Row(Seq(1, 2, 3), 2) - should return true (element found)
        (vec![Some(1), Some(2), Some(3)], Some(true)),
        // Row(Seq(1, null, 3), 2) - should return null (has nulls, element not found)
        (vec![Some(1), None, Some(3)], None),
        // Row(Seq(4, 5, 6), 2) - should return false (no nulls, element not found)
        (vec![Some(4), Some(5), Some(6)], Some(false)),
        // Row(Seq(), 2) - empty array, should return false
        (vec![], Some(false)),
        // Row(null, 2) - null array, should return null
        (vec![], None), // Marked as a null list below
        // Row(Seq(null, null), 2) - should return null (only nulls, element unknown)
        (vec![None, None], None),
        // Row(Seq(2, null), 2) - should return true (element found even with nulls)
        (vec![Some(2), None], Some(true)),
    ];
    let row_count = test_cases.len();

    // Field metadata handed to the function invocation.
    let item_field = Field::new("item", DataType::Int32, true);
    let haystack_field = Arc::new(Field::new_list("haystack", item_field.clone(), true));
    let needle_field = Arc::new(Field::new("needle", DataType::Int32, true));
    let return_field = Arc::new(Field::new("return", DataType::Boolean, true));

    // Assemble the list array pieces: flattened values, offsets, row validity.
    let null_row_index = 4; // Case 5: null array
    let mut flat_values = Vec::new();
    let mut list_offsets = vec![0i32];
    let mut row_validity = Vec::new();

    for (i, (elements, _)) in test_cases.iter().enumerate() {
        let is_valid = i != null_row_index;
        if is_valid {
            flat_values.extend(elements.iter().copied());
        }
        row_validity.push(is_valid);
        // The running value count is exactly the end offset of this row;
        // a null row adds nothing and therefore repeats the previous offset.
        list_offsets.push(flat_values.len() as i32);
    }

    let haystack = ListArray::new(
        Arc::new(item_field),
        OffsetBuffer::new(list_offsets.into()),
        Arc::new(Int32Array::from(flat_values)) as ArrayRef,
        Some(row_validity.into()),
    );

    // Execute array_has with a SCALAR needle (the key difference from the
    // array-needle test).
    let args = vec![
        ColumnarValue::Array(Arc::new(haystack)),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
    ];
    let result = ArrayHas::new().invoke_with_args(ScalarFunctionArgs {
        args,
        arg_fields: vec![haystack_field, needle_field],
        number_rows: row_count,
        return_field,
        config_options: Arc::new(ConfigOptions::default()),
    })?;

    let output = result.into_array(row_count)?;
    let output = output.as_boolean();

    // Verify results
    assert_eq!(output.len(), row_count);
    for (i, (_, expected)) in test_cases.iter().enumerate() {
        if let Some(expected_val) = expected {
            assert!(!output.is_null(i), "Expected non-null at index {i}");
            assert_eq!(
                output.value(i),
                *expected_val,
                "Mismatch at index {}: expected {:?}, got {:?}",
                i,
                expected_val,
                output.value(i)
            );
        } else {
            assert!(output.is_null(i), "Expected null at index {i}");
        }
    }

    Ok(())
}
}
Loading