diff --git a/datafusion/spark/src/function/datetime/next_day.rs b/datafusion/spark/src/function/datetime/next_day.rs index 72a0c830ffb25..2acd295f8f142 100644 --- a/datafusion/spark/src/function/datetime/next_day.rs +++ b/datafusion/spark/src/function/datetime/next_day.rs @@ -19,11 +19,12 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType, new_null_array}; -use arrow::datatypes::{DataType, Date32Type}; +use arrow::datatypes::{DataType, Date32Type, Field, FieldRef}; use chrono::{Datelike, Duration, Weekday}; -use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; /// @@ -63,7 +64,13 @@ impl ScalarUDFImpl for SparkNextDay { } fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { - Ok(DataType::Date32) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> { + // Spark marks next_day as always nullable because invalid day_of_week values + // can yield NULL even when inputs are non-null. 
+ Ok(Arc::new(Field::new(self.name(), DataType::Date32, true))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { @@ -245,3 +252,40 @@ fn spark_next_day(days: i32, day_of_week: &str) -> Option<i32> { None } } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_expr::ReturnFieldArgs; + + #[test] + fn return_type_is_not_used() { + let func = SparkNextDay::new(); + let err = func + .return_type(&[DataType::Date32, DataType::Utf8]) + .unwrap_err(); + assert!( + err.to_string() + .contains("return_field_from_args should be used instead") + ); + } + + #[test] + fn next_day_is_always_nullable() { + let func = SparkNextDay::new(); + let date_field: FieldRef = + Arc::new(Field::new("start_date", DataType::Date32, false)); + let day_field: FieldRef = + Arc::new(Field::new("day_of_week", DataType::Utf8, false)); + + let field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&date_field), Arc::clone(&day_field)], + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert_eq!(field.data_type(), &DataType::Date32); + assert!(field.is_nullable()); + } +}