From 037132daafb196303ab04458f7517494775dbb83 Mon Sep 17 00:00:00 2001 From: Justin O'Dwyer Date: Wed, 10 Dec 2025 15:54:36 -0500 Subject: [PATCH 1/4] Add return_field_from_args. --- datafusion/spark/src/function/math/modulus.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs index 60d45baa7f38..08b1cc238911 100644 --- a/datafusion/spark/src/function/math/modulus.rs +++ b/datafusion/spark/src/function/math/modulus.rs @@ -83,15 +83,13 @@ impl ScalarUDFImpl for SparkMod { } fn return_type(&self, arg_types: &[DataType]) -> Result { - assert_eq_or_internal_err!( - arg_types.len(), - 2, - "mod expects exactly two arguments" - ); + internal_err("return_field_from_args should be used instead") + } - // Return the same type as the first argument for simplicity - // Arrow's rem function handles type promotion internally - Ok(arg_types[0].clone()) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let any_nullable = args.input_fields.iter().any(|f| f.is_nullable()); + let data_type = args.input_fields[0].data_type().clone(); + Ok(Arc::new(Field::new(args.name, data_type, any_nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { From e43e7447fa807c6455283eca3926abd26d96f249 Mon Sep 17 00:00:00 2001 From: Justin O'Dwyer Date: Fri, 12 Dec 2025 13:54:42 -0500 Subject: [PATCH 2/4] Adding tests, imports, comments. --- datafusion/spark/src/function/math/modulus.rs | 106 +++++++++++++++--- 1 file changed, 89 insertions(+), 17 deletions(-) diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs index 08b1cc238911..521d54b33b6a 100644 --- a/datafusion/spark/src/function/math/modulus.rs +++ b/datafusion/spark/src/function/math/modulus.rs @@ -17,12 +17,14 @@ use arrow::compute::kernels::numeric::add; use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip}; -use arrow::datatypes::DataType; -use datafusion_common::{assert_eq_or_internal_err, Result, ScalarValue}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{assert_eq_or_internal_err, internal_err, Result, ScalarValue}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use std::any::Any; +use std::sync::Arc; /// Spark-compatible `mod` function /// This function directly uses Arrow's arithmetic_op function for modulo operations @@ -82,14 +84,16 @@ impl ScalarUDFImpl for SparkMod { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - internal_err("return_field_from_args should be used instead") + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - let any_nullable = args.input_fields.iter().any(|f| f.is_nullable()); - let data_type = args.input_fields[0].data_type().clone(); - Ok(Arc::new(Field::new(args.name, data_type, any_nullable))) + // The mod function output is nullable only in the case that the input is nullable + // (notably, a mod 0 returns an error, not null). Thus this check is sufficient. + let any_nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let data_type = args.arg_fields[0].data_type().clone(); + Ok(Arc::new(Field::new(self.name(), data_type, any_nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -130,16 +134,16 @@ impl ScalarUDFImpl for SparkPmod { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - assert_eq_or_internal_err!( - arg_types.len(), - 2, - "pmod expects exactly two arguments" - ); + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } - // Return the same type as the first argument for simplicity - // Arrow's rem function handles type promotion internally - Ok(arg_types[0].clone()) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // The mod function output is nullable only in the case that the input is nullable + // (notably, a mod 0 returns an error, not null). Thus this check is sufficient. + let any_nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let data_type = args.arg_fields[0].data_type().clone(); + Ok(Arc::new(Field::new(self.name(), data_type, any_nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -604,4 +608,72 @@ mod test { panic!("Expected array result"); } } + + #[test] + fn test_mod_return_type_error() { + let mod_func = SparkMod::new(); + let result = mod_func.return_type(&[DataType::Int32, DataType::Int32]); + assert!(result.is_err()); + } + + #[test] + fn test_mod_return_field_nullability() { + let mod_func = SparkMod::new(); + + // Non-nullable inputs -> non-nullable output. + let args = ReturnFieldArgs { + arg_fields: &[ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + ], + scalar_arguments: &[], + }; + let field = mod_func.return_field_from_args(args).unwrap(); + assert!(!field.is_nullable()); + + // Nullable input -> nullable output. + let args = ReturnFieldArgs { + arg_fields: &[ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, false)), + ], + scalar_arguments: &[], + }; + let field = mod_func.return_field_from_args(args).unwrap(); + assert!(field.is_nullable()); + } + + #[test] + fn test_pmod_return_type_error() { + let pmod_func = SparkPmod::new(); + let result = pmod_func.return_type(&[DataType::Int32, DataType::Int32]); + assert!(result.is_err()); + } + + #[test] + fn test_pmod_return_field_nullability() { + let pmod_func = SparkPmod::new(); + + // Non-nullable inputs -> non-nullable output. + let args = ReturnFieldArgs { + arg_fields: &[ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + ], + scalar_arguments: &[], + }; + let field = pmod_func.return_field_from_args(args).unwrap(); + assert!(!field.is_nullable()); + + // Nullable input -> nullable output. + let args = ReturnFieldArgs { + arg_fields: &[ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, false)), + ], + scalar_arguments: &[], + }; + let field = pmod_func.return_field_from_args(args).unwrap(); + assert!(field.is_nullable()); + } } From 5c3a4679b957af1f2e5641454019f725f7d8a531 Mon Sep 17 00:00:00 2001 From: Justin O'Dwyer Date: Sun, 14 Dec 2025 15:03:48 -0500 Subject: [PATCH 3/4] Refactor for maintainability. --- datafusion/spark/src/function/math/modulus.rs | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs index ecf35982422f..21d5676beb75 100644 --- a/datafusion/spark/src/function/math/modulus.rs +++ b/datafusion/spark/src/function/math/modulus.rs @@ -89,11 +89,7 @@ impl ScalarUDFImpl for SparkMod { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - // The mod function output is nullable only in the case that the input is nullable - // (notably, a mod 0 returns an error, not null). Thus this check is sufficient. - let any_nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - let data_type = args.arg_fields[0].data_type().clone(); - Ok(Arc::new(Field::new(self.name(), data_type, any_nullable))) + return_field_for_binary_op(self.name(), args) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -139,11 +135,7 @@ impl ScalarUDFImpl for SparkPmod { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - // The mod function output is nullable only in the case that the input is nullable - // (notably, a mod 0 returns an error, not null). Thus this check is sufficient. - let any_nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - let data_type = args.arg_fields[0].data_type().clone(); - Ok(Arc::new(Field::new(self.name(), data_type, any_nullable))) + return_field_for_binary_op(self.name(), args) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -151,6 +143,14 @@ impl ScalarUDFImpl for SparkPmod { } } +fn return_field_for_binary_op(name: &str, args: ReturnFieldArgs) -> Result { + // The mod function output is nullable only in the case that the input is nullable + // (notably, a mod 0 returns an error, not null). Thus this check is sufficient. + let any_nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let data_type = args.arg_fields[0].data_type().clone(); + Ok(Arc::new(Field::new(name, data_type, any_nullable))) +} + #[cfg(test)] mod test { use std::sync::Arc; From 16959e51dadcfae3aa454eaac41f7eacbd5d9b4d Mon Sep 17 00:00:00 2001 From: Justin O'Dwyer Date: Wed, 24 Dec 2025 18:25:37 -0500 Subject: [PATCH 4/4] Make mod and pmod always nullable. --- datafusion/spark/src/function/math/modulus.rs | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs index 21d5676beb75..f8a30938bd19 100644 --- a/datafusion/spark/src/function/math/modulus.rs +++ b/datafusion/spark/src/function/math/modulus.rs @@ -89,7 +89,8 @@ impl ScalarUDFImpl for SparkMod { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - return_field_for_binary_op(self.name(), args) + let data_type = args.arg_fields[0].data_type().clone(); + Ok(Arc::new(Field::new(self.name(), data_type, true))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -97,7 +98,7 @@ impl ScalarUDFImpl for SparkMod { } } -/// SparkMod implements the Spark-compatible modulo function +/// SparkPMod implements the Spark-compatible modulo function #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkPmod { signature: Signature, @@ -135,7 +136,8 @@ impl ScalarUDFImpl for SparkPmod { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - return_field_for_binary_op(self.name(), args) + let data_type = args.arg_fields[0].data_type().clone(); + Ok(Arc::new(Field::new(self.name(), data_type, true))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -143,14 +145,6 @@ impl ScalarUDFImpl for SparkPmod { } } -fn return_field_for_binary_op(name: &str, args: ReturnFieldArgs) -> Result { - // The mod function output is nullable only in the case that the input is nullable - // (notably, a mod 0 returns an error, not null). Thus this check is sufficient. - let any_nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - let data_type = args.arg_fields[0].data_type().clone(); - Ok(Arc::new(Field::new(name, data_type, any_nullable))) -} - #[cfg(test)] mod test { use std::sync::Arc; @@ -620,7 +614,6 @@ mod test { fn test_mod_return_field_nullability() { let mod_func = SparkMod::new(); - // Non-nullable inputs -> non-nullable output. let args = ReturnFieldArgs { arg_fields: &[ Arc::new(Field::new("a", DataType::Int32, false)), @@ -629,13 +622,22 @@ mod test { scalar_arguments: &[], }; let field = mod_func.return_field_from_args(args).unwrap(); - assert!(!field.is_nullable()); + assert!(field.is_nullable()); + + let args = ReturnFieldArgs { + arg_fields: &[ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, true)), + ], + scalar_arguments: &[], + }; + let field = mod_func.return_field_from_args(args).unwrap(); + assert!(field.is_nullable()); - // Nullable input -> nullable output. let args = ReturnFieldArgs { arg_fields: &[ Arc::new(Field::new("a", DataType::Int32, true)), - Arc::new(Field::new("b", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, true)), ], scalar_arguments: &[], }; @@ -654,7 +656,6 @@ mod test { fn test_pmod_return_field_nullability() { let pmod_func = SparkPmod::new(); - // Non-nullable inputs -> non-nullable output. let args = ReturnFieldArgs { arg_fields: &[ Arc::new(Field::new("a", DataType::Int32, false)), @@ -663,13 +664,22 @@ mod test { scalar_arguments: &[], }; let field = pmod_func.return_field_from_args(args).unwrap(); - assert!(!field.is_nullable()); + assert!(field.is_nullable()); + + let args = ReturnFieldArgs { + arg_fields: &[ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, true)), + ], + scalar_arguments: &[], + }; + let field = pmod_func.return_field_from_args(args).unwrap(); + assert!(field.is_nullable()); - // Nullable input -> nullable output. let args = ReturnFieldArgs { arg_fields: &[ Arc::new(Field::new("a", DataType::Int32, true)), - Arc::new(Field::new("b", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, true)), ], scalar_arguments: &[], };