diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs
index 49657e2cb8ce..f8a30938bd19 100644
--- a/datafusion/spark/src/function/math/modulus.rs
+++ b/datafusion/spark/src/function/math/modulus.rs
@@ -17,12 +17,14 @@
 
 use arrow::compute::kernels::numeric::add;
 use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip};
-use arrow::datatypes::DataType;
-use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err};
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err, internal_err};
 use datafusion_expr::{
-    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
 };
 use std::any::Any;
+use std::sync::Arc;
 
 /// Spark-compatible `mod` function
 /// This function directly uses Arrow's arithmetic_op function for modulo operations
@@ -82,16 +84,13 @@ impl ScalarUDFImpl for SparkMod {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        assert_eq_or_internal_err!(
-            arg_types.len(),
-            2,
-            "mod expects exactly two arguments"
-        );
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!("return_field_from_args should be used instead")
+    }
 
-        // Return the same type as the first argument for simplicity
-        // Arrow's rem function handles type promotion internally
-        Ok(arg_types[0].clone())
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let data_type = args.arg_fields[0].data_type().clone();
+        Ok(Arc::new(Field::new(self.name(), data_type, true)))
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -99,7 +98,7 @@ impl ScalarUDFImpl for SparkMod {
     }
 }
 
-/// SparkMod implements the Spark-compatible modulo function
+/// SparkPmod implements the Spark-compatible `pmod` function
 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct SparkPmod {
     signature: Signature,
@@ -132,16 +131,13 @@ impl ScalarUDFImpl for SparkPmod {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        assert_eq_or_internal_err!(
-            arg_types.len(),
-            2,
-            "pmod expects exactly two arguments"
-        );
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!("return_field_from_args should be used instead")
+    }
 
-        // Return the same type as the first argument for simplicity
-        // Arrow's rem function handles type promotion internally
-        Ok(arg_types[0].clone())
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let data_type = args.arg_fields[0].data_type().clone();
+        Ok(Arc::new(Field::new(self.name(), data_type, true)))
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -606,4 +602,88 @@ mod test {
             panic!("Expected array result");
         }
     }
+
+    #[test]
+    fn test_mod_return_type_error() {
+        let mod_func = SparkMod::new();
+        let result = mod_func.return_type(&[DataType::Int32, DataType::Int32]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_mod_return_field_nullability() {
+        let mod_func = SparkMod::new();
+
+        let args = ReturnFieldArgs {
+            arg_fields: &[
+                Arc::new(Field::new("a", DataType::Int32, false)),
+                Arc::new(Field::new("b", DataType::Int32, false)),
+            ],
+            scalar_arguments: &[],
+        };
+        let field = mod_func.return_field_from_args(args).unwrap();
+        assert!(field.is_nullable());
+
+        let args = ReturnFieldArgs {
+            arg_fields: &[
+                Arc::new(Field::new("a", DataType::Int32, false)),
+                Arc::new(Field::new("b", DataType::Int32, true)),
+            ],
+            scalar_arguments: &[],
+        };
+        let field = mod_func.return_field_from_args(args).unwrap();
+        assert!(field.is_nullable());
+
+        let args = ReturnFieldArgs {
+            arg_fields: &[
+                Arc::new(Field::new("a", DataType::Int32, true)),
+                Arc::new(Field::new("b", DataType::Int32, true)),
+            ],
+            scalar_arguments: &[],
+        };
+        let field = mod_func.return_field_from_args(args).unwrap();
+        assert!(field.is_nullable());
+    }
+
+    #[test]
+    fn test_pmod_return_type_error() {
+        let pmod_func = SparkPmod::new();
+        let result = pmod_func.return_type(&[DataType::Int32, DataType::Int32]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_pmod_return_field_nullability() {
+        let pmod_func = SparkPmod::new();
+
+        let args = ReturnFieldArgs {
+            arg_fields: &[
+                Arc::new(Field::new("a", DataType::Int32, false)),
+                Arc::new(Field::new("b", DataType::Int32, false)),
+            ],
+            scalar_arguments: &[],
+        };
+        let field = pmod_func.return_field_from_args(args).unwrap();
+        assert!(field.is_nullable());
+
+        let args = ReturnFieldArgs {
+            arg_fields: &[
+                Arc::new(Field::new("a", DataType::Int32, false)),
+                Arc::new(Field::new("b", DataType::Int32, true)),
+            ],
+            scalar_arguments: &[],
+        };
+        let field = pmod_func.return_field_from_args(args).unwrap();
+        assert!(field.is_nullable());
+
+        let args = ReturnFieldArgs {
+            arg_fields: &[
+                Arc::new(Field::new("a", DataType::Int32, true)),
+                Arc::new(Field::new("b", DataType::Int32, true)),
+            ],
+            scalar_arguments: &[],
+        };
+        let field = pmod_func.return_field_from_args(args).unwrap();
+        assert!(field.is_nullable());
+    }
 }