diff --git a/Cargo.lock b/Cargo.lock
index 594e63cd68efc..866c4a34c0959 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3054,6 +3054,7 @@ version = "0.1.0"
 dependencies = [
  "abi_stable",
  "datafusion",
+ "datafusion-ffi",
  "ffi_module_interface",
  "tokio",
 ]
diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs
index a83f15926f054..eb217ef9e4832 100644
--- a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs
+++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs
@@ -21,6 +21,7 @@ use abi_stable::{export_root_module, prefix_type::PrefixTypeTrait};
 use arrow::array::RecordBatch;
 use arrow::datatypes::{DataType, Field, Schema};
 use datafusion::{common::record_batch, datasource::MemTable};
+use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
 use datafusion_ffi::table_provider::FFI_TableProvider;
 use ffi_module_interface::{TableProviderModule, TableProviderModuleRef};
 
@@ -34,7 +35,9 @@ fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch {
 
 /// Here we only wish to create a simple table provider as an example.
 /// We create an in-memory table and convert it to its FFI counterpart.
-extern "C" fn construct_simple_table_provider() -> FFI_TableProvider {
+extern "C" fn construct_simple_table_provider(
+    codec: FFI_LogicalExtensionCodec,
+) -> FFI_TableProvider {
     let schema = Arc::new(Schema::new(vec![
         Field::new("a", DataType::Int32, true),
         Field::new("b", DataType::Float64, true),
@@ -50,7 +53,7 @@ extern "C" fn construct_simple_table_provider() -> FFI_TableProvider {
 
     let table_provider = MemTable::try_new(schema, vec![batches]).unwrap();
 
-    FFI_TableProvider::new(Arc::new(table_provider), true, None)
+    FFI_TableProvider::new_with_ffi_codec(Arc::new(table_provider), true, None, codec)
 }
 
 #[export_root_module]
diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs
index 1166eeb707a35..3b2b9e1871dae 100644
--- a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs
+++ b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs
@@ -21,6 +21,7 @@ use abi_stable::{
     package_version_strings,
     sabi_types::VersionStrings,
 };
+use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
 use datafusion_ffi::table_provider::FFI_TableProvider;
 
 #[repr(C)]
@@ -33,7 +34,8 @@ use datafusion_ffi::table_provider::FFI_TableProvider;
 /// how a user may wish to separate these concerns.
 pub struct TableProviderModule {
     /// Constructs the table provider
-    pub create_table: extern "C" fn() -> FFI_TableProvider,
+    pub create_table:
+        extern "C" fn(codec: FFI_LogicalExtensionCodec) -> FFI_TableProvider,
 }
 
 impl RootModule for TableProviderModuleRef {
diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml
index c60cd7c0294c3..823c9afddee2a 100644
--- a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml
+++ b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml
@@ -24,5 +24,6 @@ publish = false
 [dependencies]
 abi_stable = "0.11.3"
 datafusion = { workspace = true }
+datafusion-ffi = { workspace = true }
 ffi_module_interface = { path = "../ffi_module_interface" }
 tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] }
diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs
index d76af96058721..8ce5b156df3b1 100644
--- a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs
+++ b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs
@@ -24,6 +24,8 @@ use datafusion::{
 use abi_stable::library::{RootModule, development_utils::compute_library_path};
 use datafusion::datasource::TableProvider;
+use datafusion::execution::TaskContextProvider;
+use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
 use ffi_module_interface::TableProviderModuleRef;
 
 #[tokio::main]
@@ -39,6 +41,11 @@ async fn main() -> Result<()> {
         TableProviderModuleRef::load_from_directory(&library_path)
             .map_err(|e| DataFusionError::External(Box::new(e)))?;
 
+    let ctx = Arc::new(SessionContext::new());
+    let codec = FFI_LogicalExtensionCodec::new_default(
+        &(Arc::clone(&ctx) as Arc<dyn TaskContextProvider>),
+    );
+
     // By calling the code below, the table provided will be created within
     // the module's code.
     let ffi_table_provider =
@@ -46,14 +53,12 @@ async fn main() -> Result<()> {
             .create_table()
             .ok_or(DataFusionError::NotImplemented(
                 "External table provider failed to implement create_table".to_string(),
-            ))?();
+            ))?(codec);
 
     // In order to access the table provider within this executable, we need to
     // turn it into a `TableProvider`.
     let foreign_table_provider: Arc<dyn TableProvider> = (&ffi_table_provider).into();
 
-    let ctx = SessionContext::new();
-
     // Display the data to show the full cycle works.
     ctx.register_table("external_table", foreign_table_provider)?;
     let df = ctx.table("external_table").await?;
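The loader example above inverts the old flow: the host now builds the `SessionContext` first, derives an `FFI_LogicalExtensionCodec` from it, and hands that codec to the plugin's constructor. A condensed sketch of that host-side setup, using the same types and imports as the example (error handling elided):

```rust
use std::sync::Arc;

use datafusion::execution::TaskContextProvider;
use datafusion::prelude::SessionContext;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;

fn host_codec() -> (Arc<SessionContext>, FFI_LogicalExtensionCodec) {
    // The session context doubles as the TaskContextProvider the codec
    // uses to resolve function registries when decoding plans and exprs.
    let ctx = Arc::new(SessionContext::new());
    let codec = FFI_LogicalExtensionCodec::new_default(
        &(Arc::clone(&ctx) as Arc<dyn TaskContextProvider>),
    );
    (ctx, codec)
}
```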
diff --git a/datafusion/ffi/src/catalog_provider.rs b/datafusion/ffi/src/catalog_provider.rs
index fe3cce10652ad..25e398c4ddee7 100644
--- a/datafusion/ffi/src/catalog_provider.rs
+++ b/datafusion/ffi/src/catalog_provider.rs
@@ -15,22 +15,24 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{any::Any, ffi::c_void, sync::Arc};
-
-use abi_stable::{
-    StableAbi,
-    std_types::{ROption, RResult, RString, RVec},
+use std::any::Any;
+use std::ffi::c_void;
+use std::sync::Arc;
+
+use abi_stable::StableAbi;
+use abi_stable::std_types::{ROption, RResult, RString, RVec};
+use datafusion_catalog::{CatalogProvider, SchemaProvider};
+use datafusion_common::error::Result;
+use datafusion_proto::logical_plan::{
+    DefaultLogicalExtensionCodec, LogicalExtensionCodec,
 };
-use datafusion::catalog::{CatalogProvider, SchemaProvider};
 use tokio::runtime::Handle;
 
-use crate::{
-    df_result, rresult_return,
-    schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider},
-};
-
+use crate::execution::FFI_TaskContextProvider;
+use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
+use crate::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
 use crate::util::FFIResult;
-use datafusion::error::Result;
+use crate::{df_result, rresult_return};
 
 /// A stable struct for sharing [`CatalogProvider`] across FFI boundaries.
 #[repr(C)]
@@ -58,6 +60,8 @@ pub struct FFI_CatalogProvider {
     ) -> FFIResult<ROption<FFI_SchemaProvider>>,
 
+    pub logical_codec: FFI_LogicalExtensionCodec,
+
     /// Used to create a clone on the provider of the execution plan. This should
     /// only need to be called by the receiver of the plan.
     pub clone: unsafe extern "C" fn(plan: &Self) -> Self,
@@ -118,7 +122,13 @@ unsafe extern "C" fn schema_fn_wrapper(
     unsafe {
         let maybe_schema = provider.inner().schema(name.as_str());
         maybe_schema
-            .map(|schema| FFI_SchemaProvider::new(schema, provider.runtime()))
+            .map(|schema| {
+                FFI_SchemaProvider::new_with_ffi_codec(
+                    schema,
+                    provider.runtime(),
+                    provider.logical_codec.clone(),
+                )
+            })
             .into()
     }
 }
@@ -130,12 +140,18 @@ unsafe extern "C" fn register_schema_fn_wrapper(
 ) -> FFIResult<ROption<FFI_SchemaProvider>> {
     unsafe {
         let runtime = provider.runtime();
-        let provider = provider.inner();
+        let inner_provider = provider.inner();
 
         let schema: Arc<dyn SchemaProvider> = schema.into();
 
         let returned_schema =
-            rresult_return!(provider.register_schema(name.as_str(), schema))
-                .map(|schema| FFI_SchemaProvider::new(schema, runtime))
+            rresult_return!(inner_provider.register_schema(name.as_str(), schema))
+                .map(|schema| {
+                    FFI_SchemaProvider::new_with_ffi_codec(
+                        schema,
+                        runtime,
+                        provider.logical_codec.clone(),
+                    )
+                })
                 .into();
 
         RResult::ROk(returned_schema)
@@ -149,14 +165,20 @@ unsafe extern "C" fn deregister_schema_fn_wrapper(
 ) -> FFIResult<ROption<FFI_SchemaProvider>> {
     unsafe {
         let runtime = provider.runtime();
-        let provider = provider.inner();
+        let inner_provider = provider.inner();
 
         let maybe_schema =
-            rresult_return!(provider.deregister_schema(name.as_str(), cascade));
+            rresult_return!(inner_provider.deregister_schema(name.as_str(), cascade));
 
         RResult::ROk(
             maybe_schema
-                .map(|schema| FFI_SchemaProvider::new(schema, runtime))
+                .map(|schema| {
+                    FFI_SchemaProvider::new_with_ffi_codec(
+                        schema,
+                        runtime,
+                        provider.logical_codec.clone(),
+                    )
+                })
                 .into(),
         )
     }
@@ -189,6 +211,7 @@ unsafe extern "C" fn clone_fn_wrapper(
         schema: schema_fn_wrapper,
         register_schema: register_schema_fn_wrapper,
         deregister_schema: deregister_schema_fn_wrapper,
+        logical_codec: provider.logical_codec.clone(),
         clone: clone_fn_wrapper,
         release: release_fn_wrapper,
         version: super::version,
@@ -209,6 +232,24 @@ impl FFI_CatalogProvider {
     pub fn new(
         provider: Arc<dyn CatalogProvider>,
         runtime: Option<Handle>,
+        task_ctx_provider: impl Into<FFI_TaskContextProvider>,
+        logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
+    ) -> Self {
+        let task_ctx_provider = task_ctx_provider.into();
+        let logical_codec =
+            logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {}));
+        let logical_codec = FFI_LogicalExtensionCodec::new(
+            logical_codec,
+            runtime.clone(),
+            task_ctx_provider.clone(),
+        );
+        Self::new_with_ffi_codec(provider, runtime, logical_codec)
+    }
+
+    pub fn new_with_ffi_codec(
+        provider: Arc<dyn CatalogProvider>,
+        runtime: Option<Handle>,
+        logical_codec: FFI_LogicalExtensionCodec,
     ) -> Self {
         let private_data = Box::new(ProviderPrivateData { provider, runtime });
 
@@ -217,6 +258,7 @@ impl FFI_CatalogProvider {
             schema: schema_fn_wrapper,
             register_schema: register_schema_fn_wrapper,
             deregister_schema: deregister_schema_fn_wrapper,
+            logical_codec,
             clone: clone_fn_wrapper,
             release: release_fn_wrapper,
             version: super::version,
@@ -286,7 +328,11 @@ impl CatalogProvider for ForeignCatalogProvider {
         unsafe {
             let schema = match schema.as_any().downcast_ref::<FFI_SchemaProvider>() {
                 Some(s) => &s.0,
-                None => &FFI_SchemaProvider::new(schema, None),
+                None => &FFI_SchemaProvider::new_with_ffi_codec(
+                    schema,
+                    None,
+                    self.0.logical_codec.clone(),
+                ),
             };
             let returned_schema: Option<FFI_SchemaProvider> =
                 df_result!((self.0.register_schema)(&self.0, name.into(), schema))?
@@ -331,8 +377,10 @@ mod tests {
                 .unwrap()
                 .is_none()
         );
+        let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx();
 
-        let mut ffi_catalog = FFI_CatalogProvider::new(catalog, None);
+        let mut ffi_catalog =
+            FFI_CatalogProvider::new(catalog, None, task_ctx_provider, None);
         ffi_catalog.library_marker_id = crate::mock_foreign_marker_id;
 
         let foreign_catalog: Arc<dyn CatalogProvider> = (&ffi_catalog).into();
@@ -375,7 +423,9 @@ mod tests {
     fn test_ffi_catalog_provider_local_bypass() {
         let catalog = Arc::new(MemoryCatalogProvider::new());
 
-        let mut ffi_catalog = FFI_CatalogProvider::new(catalog, None);
+        let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx();
+        let mut ffi_catalog =
+            FFI_CatalogProvider::new(catalog, None, task_ctx_provider, None);
 
         // Verify local libraries can be downcast to their original
         let foreign_catalog: Arc<dyn CatalogProvider> = (&ffi_catalog).into();
diff --git a/datafusion/ffi/src/catalog_provider_list.rs b/datafusion/ffi/src/catalog_provider_list.rs
index a119dd9485509..c6ff5f7da7834 100644
--- a/datafusion/ffi/src/catalog_provider_list.rs
+++ b/datafusion/ffi/src/catalog_provider_list.rs
@@ -15,16 +15,21 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{any::Any, ffi::c_void, sync::Arc};
-
-use abi_stable::{
-    StableAbi,
-    std_types::{ROption, RString, RVec},
+use std::any::Any;
+use std::ffi::c_void;
+use std::sync::Arc;
+
+use abi_stable::StableAbi;
+use abi_stable::std_types::{ROption, RString, RVec};
+use datafusion_catalog::{CatalogProvider, CatalogProviderList};
+use datafusion_proto::logical_plan::{
+    DefaultLogicalExtensionCodec, LogicalExtensionCodec,
 };
-use datafusion::catalog::{CatalogProvider, CatalogProviderList};
 use tokio::runtime::Handle;
 
 use crate::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
+use crate::execution::FFI_TaskContextProvider;
+use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
 
 /// A stable struct for sharing [`CatalogProviderList`] across FFI boundaries.
 #[repr(C)]
@@ -45,8 +50,10 @@ pub struct FFI_CatalogProviderList {
     pub catalog: unsafe extern "C" fn(&Self, name: RString) -> ROption<FFI_CatalogProvider>,
 
-    /// Used to create a clone on the provider. This should only need to be called
-    /// by the receiver of the plan.
+    pub logical_codec: FFI_LogicalExtensionCodec,
+
+    /// Used to create a clone on the provider of the execution plan. This should
+    /// only need to be called by the receiver of the plan.
     pub clone: unsafe extern "C" fn(plan: &Self) -> Self,
 
     /// Release the memory of the private data when it is no longer being used.
@@ -105,12 +112,18 @@ unsafe extern "C" fn register_catalog_fn_wrapper(
 ) -> ROption<FFI_CatalogProvider> {
     unsafe {
         let runtime = provider.runtime();
-        let provider = provider.inner();
+        let inner_provider = provider.inner();
 
         let catalog: Arc<dyn CatalogProvider> = catalog.into();
 
-        provider
+        inner_provider
             .register_catalog(name.into(), catalog)
-            .map(|catalog| FFI_CatalogProvider::new(catalog, runtime))
+            .map(|catalog| {
+                FFI_CatalogProvider::new_with_ffi_codec(
+                    catalog,
+                    runtime,
+                    provider.logical_codec.clone(),
+                )
+            })
             .into()
     }
 }
@@ -121,10 +134,16 @@ unsafe extern "C" fn catalog_fn_wrapper(
 ) -> ROption<FFI_CatalogProvider> {
     unsafe {
         let runtime = provider.runtime();
-        let provider = provider.inner();
-        provider
+        let inner_provider = provider.inner();
+        inner_provider
             .catalog(name.as_str())
-            .map(|catalog| FFI_CatalogProvider::new(catalog, runtime))
+            .map(|catalog| {
+                FFI_CatalogProvider::new_with_ffi_codec(
+                    catalog,
+                    runtime,
+                    provider.logical_codec.clone(),
+                )
+            })
             .into()
     }
 }
@@ -155,6 +174,7 @@ unsafe extern "C" fn clone_fn_wrapper(
         register_catalog: register_catalog_fn_wrapper,
         catalog_names: catalog_names_fn_wrapper,
         catalog: catalog_fn_wrapper,
+        logical_codec: provider.logical_codec.clone(),
         clone: clone_fn_wrapper,
         release: release_fn_wrapper,
         version: super::version,
@@ -175,6 +195,23 @@ impl FFI_CatalogProviderList {
     pub fn new(
         provider: Arc<dyn CatalogProviderList>,
         runtime: Option<Handle>,
+        task_ctx_provider: impl Into<FFI_TaskContextProvider>,
+        logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
+    ) -> Self {
+        let task_ctx_provider = task_ctx_provider.into();
+        let logical_codec =
+            logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {}));
+        let logical_codec = FFI_LogicalExtensionCodec::new(
+            logical_codec,
+            runtime.clone(),
+            task_ctx_provider.clone(),
+        );
+        Self::new_with_ffi_codec(provider, runtime, logical_codec)
+    }
+    pub fn new_with_ffi_codec(
+        provider: Arc<dyn CatalogProviderList>,
+        runtime: Option<Handle>,
+        logical_codec: FFI_LogicalExtensionCodec,
     ) -> Self {
         let private_data = Box::new(ProviderPrivateData { provider, runtime });
 
@@ -182,6 +219,7 @@ impl FFI_CatalogProviderList {
         register_catalog: register_catalog_fn_wrapper,
         catalog_names: catalog_names_fn_wrapper,
         catalog: catalog_fn_wrapper,
+        logical_codec,
         clone: clone_fn_wrapper,
         release: release_fn_wrapper,
         version: super::version,
@@ -232,7 +270,11 @@ impl CatalogProviderList for ForeignCatalogProviderList {
             let catalog = match catalog.as_any().downcast_ref::<FFI_CatalogProvider>() {
                 Some(s) => &s.0,
-                None => &FFI_CatalogProvider::new_with_ffi_codec(
+                None => &FFI_CatalogProvider::new_with_ffi_codec(
                     catalog,
                     None,
                     self.0.logical_codec.clone(),
                 ),
             };
 
             (self.0.register_catalog)(&self.0, name.into(), catalog)
@@ -279,7 +321,9 @@ mod tests {
                 .is_none()
         );
 
-        let mut ffi_catalog_list = FFI_CatalogProviderList::new(catalog_list, None);
+        let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx();
+        let mut ffi_catalog_list =
+            FFI_CatalogProviderList::new(catalog_list, None, task_ctx_provider, None);
         ffi_catalog_list.library_marker_id = crate::mock_foreign_marker_id;
 
         let foreign_catalog_list: Arc<dyn CatalogProviderList> =
@@ -318,7 +362,9 @@ mod tests {
     fn test_ffi_catalog_provider_list_local_bypass() {
         let catalog_list = Arc::new(MemoryCatalogProviderList::new());
 
-        let mut ffi_catalog_list = FFI_CatalogProviderList::new(catalog_list, None);
+        let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx();
+        let mut ffi_catalog_list =
+            FFI_CatalogProviderList::new(catalog_list, None, task_ctx_provider, None);
 
         // Verify local libraries can be downcast to their original
         let foreign_catalog_list: Arc<dyn CatalogProviderList> =
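Every wrapper touched by this patch gains the same pair of constructors: `new`, which accepts an optional `LogicalExtensionCodec` and falls back to `DefaultLogicalExtensionCodec`, and `new_with_ffi_codec`, which reuses an already-wrapped codec. A sketch of both paths for the catalog list; the `MemoryCatalogProviderList` and `SessionContext` choices here are illustrative, not part of the patch:

```rust
use std::sync::Arc;

use datafusion::catalog::MemoryCatalogProviderList;
use datafusion::execution::TaskContextProvider;
use datafusion::prelude::SessionContext;
use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
use datafusion_ffi::execution::FFI_TaskContextProvider;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;

fn wrap_lists() -> (FFI_CatalogProviderList, FFI_CatalogProviderList) {
    let ctx = Arc::new(SessionContext::new()) as Arc<dyn TaskContextProvider>;
    let task_ctx_provider = FFI_TaskContextProvider::from(&ctx);

    // Path 1: passing `None` builds a DefaultLogicalExtensionCodec internally.
    let a = FFI_CatalogProviderList::new(
        Arc::new(MemoryCatalogProviderList::new()),
        None,
        task_ctx_provider,
        None,
    );

    // Path 2: build one FFI codec up front and share it across wrappers.
    let codec = FFI_LogicalExtensionCodec::new_default(&ctx);
    let b = FFI_CatalogProviderList::new_with_ffi_codec(
        Arc::new(MemoryCatalogProviderList::new()),
        None,
        codec,
    );
    (a, b)
}
```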
diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs
index d8ef1c98272db..fdd3d5649e80e 100644
--- a/datafusion/ffi/src/execution_plan.rs
+++ b/datafusion/ffi/src/execution_plan.rs
@@ -29,6 +29,7 @@ use datafusion::{
 use datafusion::{error::Result, physical_plan::DisplayFormatType};
 use tokio::runtime::Handle;
 
+use crate::execution::FFI_TaskContext;
 use crate::util::FFIResult;
 use crate::{
     df_result, plan_properties::FFI_PlanProperties,
@@ -54,6 +55,7 @@ pub struct FFI_ExecutionPlan {
     pub execute: unsafe extern "C" fn(
         plan: &Self,
         partition: usize,
+        context: FFI_TaskContext,
     ) -> FFIResult<FFI_RecordBatchStream>,
 
     /// Used to create a clone on the provider of the execution plan. This should
@@ -78,7 +80,6 @@ unsafe impl Sync for FFI_ExecutionPlan {}
 pub struct ExecutionPlanPrivateData {
     pub plan: Arc<dyn ExecutionPlan>,
-    pub context: Arc<TaskContext>,
     pub runtime: Option<Handle>,
 }
 
@@ -101,19 +102,12 @@ unsafe extern "C" fn children_fn_wrapper(
     unsafe {
         let private_data = plan.private_data as *const ExecutionPlanPrivateData;
         let plan = &(*private_data).plan;
-        let ctx = &(*private_data).context;
         let runtime = &(*private_data).runtime;
 
         let children: Vec<_> = plan
             .children()
             .into_iter()
-            .map(|child| {
-                FFI_ExecutionPlan::new(
-                    Arc::clone(child),
-                    Arc::clone(ctx),
-                    runtime.clone(),
-                )
-            })
+            .map(|child| FFI_ExecutionPlan::new(Arc::clone(child), runtime.clone()))
             .collect();
 
         children.into()
@@ -123,15 +117,16 @@
 unsafe extern "C" fn execute_fn_wrapper(
     plan: &FFI_ExecutionPlan,
     partition: usize,
+    context: FFI_TaskContext,
 ) -> FFIResult<FFI_RecordBatchStream> {
     unsafe {
+        let ctx = context.into();
         let private_data = plan.private_data as *const ExecutionPlanPrivateData;
         let plan = &(*private_data).plan;
-        let ctx = &(*private_data).context;
         let runtime = (*private_data).runtime.clone();
 
         rresult!(
-            plan.execute(partition, Arc::clone(ctx))
+            plan.execute(partition, ctx)
                 .map(|rbs| FFI_RecordBatchStream::new(rbs, runtime))
         )
     }
@@ -156,11 +151,7 @@ unsafe extern "C" fn clone_fn_wrapper(plan: &FFI_ExecutionPlan) -> FFI_Execution
     let private_data = plan.private_data as *const ExecutionPlanPrivateData;
     let plan_data = &(*private_data);
 
-    FFI_ExecutionPlan::new(
-        Arc::clone(&plan_data.plan),
-        Arc::clone(&plan_data.context),
-        plan_data.runtime.clone(),
-    )
+    FFI_ExecutionPlan::new(Arc::clone(&plan_data.plan), plan_data.runtime.clone())
 }
 }
 
@@ -172,16 +163,8 @@ impl Clone for FFI_ExecutionPlan {
 impl FFI_ExecutionPlan {
     /// This function is called on the provider's side.
-    pub fn new(
-        plan: Arc<dyn ExecutionPlan>,
-        context: Arc<TaskContext>,
-        runtime: Option<Handle>,
-    ) -> Self {
-        let private_data = Box::new(ExecutionPlanPrivateData {
-            plan,
-            context,
-            runtime,
-        });
+    pub fn new(plan: Arc<dyn ExecutionPlan>, runtime: Option<Handle>) -> Self {
+        let private_data = Box::new(ExecutionPlanPrivateData { plan, runtime });
 
         Self {
             properties: properties_fn_wrapper,
@@ -305,10 +288,11 @@ impl ExecutionPlan for ForeignExecutionPlan {
     fn execute(
         &self,
         partition: usize,
-        _context: Arc<TaskContext>,
+        context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        let context = FFI_TaskContext::from(context);
         unsafe {
-            df_result!((self.plan.execute)(&self.plan, partition))
+            df_result!((self.plan.execute)(&self.plan, partition, context))
                 .map(|stream| Pin::new(Box::new(stream)) as SendableRecordBatchStream)
         }
     }
@@ -318,12 +302,9 @@ pub(crate) mod tests {
     use super::*;
     use arrow::datatypes::{DataType, Field, Schema};
-    use datafusion::{
-        physical_plan::{
-            Partitioning,
-            execution_plan::{Boundedness, EmissionType},
-        },
-        prelude::SessionContext,
+    use datafusion::physical_plan::{
+        Partitioning,
+        execution_plan::{Boundedness, EmissionType},
     };
 
     #[derive(Debug)]
@@ -400,17 +381,16 @@
     fn test_round_trip_ffi_execution_plan() -> Result<()> {
         let schema =
             Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
-        let ctx = SessionContext::new();
 
         let original_plan = Arc::new(EmptyExec::new(schema));
         let original_name = original_plan.name().to_string();
 
-        let mut local_plan = FFI_ExecutionPlan::new(original_plan, ctx.task_ctx(), None);
+        let mut local_plan = FFI_ExecutionPlan::new(original_plan, None);
         local_plan.library_marker_id = crate::mock_foreign_marker_id;
 
         let foreign_plan: Arc<dyn ExecutionPlan> = (&local_plan).try_into()?;
 
-        assert!(original_name == foreign_plan.name());
+        assert_eq!(original_name, foreign_plan.name());
 
         let display = datafusion::physical_plan::display::DisplayableExecutionPlan::new(
             foreign_plan.as_ref(),
@@ -429,16 +409,15 @@
     fn test_ffi_execution_plan_children() -> Result<()> {
         let schema =
             Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
-        let ctx = SessionContext::new();
 
         // Version 1: Adding child to the foreign plan
         let child_plan = Arc::new(EmptyExec::new(Arc::clone(&schema)));
 
-        let mut child_local = FFI_ExecutionPlan::new(child_plan, ctx.task_ctx(), None);
+        let mut child_local = FFI_ExecutionPlan::new(child_plan, None);
         child_local.library_marker_id = crate::mock_foreign_marker_id;
         let child_foreign = <Arc<dyn ExecutionPlan>>::try_from(&child_local)?;
 
         let parent_plan = Arc::new(EmptyExec::new(Arc::clone(&schema)));
-        let mut parent_local = FFI_ExecutionPlan::new(parent_plan, ctx.task_ctx(), None);
+        let mut parent_local = FFI_ExecutionPlan::new(parent_plan, None);
         parent_local.library_marker_id = crate::mock_foreign_marker_id;
         let parent_foreign = <Arc<dyn ExecutionPlan>>::try_from(&parent_local)?;
 
@@ -450,13 +429,13 @@
         // Version 2: Adding child to the local plan
         let child_plan = Arc::new(EmptyExec::new(Arc::clone(&schema)));
 
-        let mut child_local = FFI_ExecutionPlan::new(child_plan, ctx.task_ctx(), None);
+        let mut child_local = FFI_ExecutionPlan::new(child_plan, None);
         child_local.library_marker_id = crate::mock_foreign_marker_id;
         let child_foreign = <Arc<dyn ExecutionPlan>>::try_from(&child_local)?;
 
         let parent_plan = Arc::new(EmptyExec::new(Arc::clone(&schema)));
         let parent_plan = parent_plan.with_new_children(vec![child_foreign])?;
-        let mut parent_local = FFI_ExecutionPlan::new(parent_plan, ctx.task_ctx(), None);
+        let mut parent_local = FFI_ExecutionPlan::new(parent_plan, None);
         parent_local.library_marker_id = crate::mock_foreign_marker_id;
         let parent_foreign = <Arc<dyn ExecutionPlan>>::try_from(&parent_local)?;
 
@@ -469,11 +448,10 @@
     fn test_ffi_execution_plan_local_bypass() {
         let schema =
             Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
-        let ctx = SessionContext::new();
 
         let plan = Arc::new(EmptyExec::new(schema));
 
-        let mut ffi_plan = FFI_ExecutionPlan::new(plan, ctx.task_ctx(), None);
+        let mut ffi_plan = FFI_ExecutionPlan::new(plan, None);
 
         // Verify local libraries can be downcast to their original
         let foreign_plan: Arc<dyn ExecutionPlan> = (&ffi_plan).try_into().unwrap();
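With the `TaskContext` removed from `ExecutionPlanPrivateData`, a provider-side wrapper needs no context at build time; the receiver supplies one per `execute` call. A round-trip sketch mirroring the tests above (local bypass applies when both sides share a library marker):

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_ffi::execution_plan::FFI_ExecutionPlan;

fn round_trip(plan: Arc<dyn ExecutionPlan>) -> Result<()> {
    // Provider side: only the plan and an optional runtime handle.
    let ffi_plan = FFI_ExecutionPlan::new(plan, None);

    // Receiver side: the context crosses the boundary at execute time,
    // wrapped into an FFI_TaskContext by ForeignExecutionPlan::execute.
    let foreign: Arc<dyn ExecutionPlan> = (&ffi_plan).try_into()?;
    let ctx = SessionContext::new();
    let _stream = foreign.execute(0, ctx.task_ctx())?;
    Ok(())
}
```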
diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs
index 430b34c984534..ac388a3653e12 100644
--- a/datafusion/ffi/src/plan_properties.rs
+++ b/datafusion/ffi/src/plan_properties.rs
@@ -15,43 +15,29 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{ffi::c_void, sync::Arc};
+use std::ffi::c_void;
+use std::sync::Arc;
 
-use abi_stable::{
-    StableAbi,
-    std_types::{RResult::ROk, RVec},
-};
+use abi_stable::StableAbi;
+use abi_stable::std_types::{ROption, RVec};
 use arrow::datatypes::SchemaRef;
-use datafusion::{
-    error::{DataFusionError, Result},
-    physical_expr::EquivalenceProperties,
-    physical_plan::{
-        PlanProperties,
-        execution_plan::{Boundedness, EmissionType},
-    },
-    prelude::SessionContext,
-};
-use datafusion_proto::{
-    physical_plan::{
-        DefaultPhysicalExtensionCodec,
-        from_proto::{parse_physical_sort_exprs, parse_protobuf_partitioning},
-        to_proto::{serialize_partitioning, serialize_physical_sort_exprs},
-    },
-    protobuf::{Partitioning, PhysicalSortExprNodeCollection},
-};
-use prost::Message;
-
-use crate::util::FFIResult;
-use crate::{arrow_wrappers::WrappedSchema, df_result, rresult_return};
+use datafusion_common::error::{DataFusionError, Result};
+use datafusion_physical_expr::EquivalenceProperties;
+use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
+use datafusion_physical_plan::PlanProperties;
+use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
+
+use crate::arrow_wrappers::WrappedSchema;
+use crate::physical_expr::partitioning::FFI_Partitioning;
+use crate::physical_expr::sort::FFI_PhysicalSortExpr;
 
 /// A stable struct for sharing [`PlanProperties`] across FFI boundaries.
 #[repr(C)]
 #[derive(Debug, StableAbi)]
 #[allow(non_camel_case_types)]
 pub struct FFI_PlanProperties {
-    /// The output partitioning is a [`Partitioning`] protobuf message serialized
-    /// into bytes to pass across the FFI boundary.
-    pub output_partitioning: unsafe extern "C" fn(plan: &Self) -> FFIResult<RVec<u8>>,
+    /// The output partitioning of the plan.
+    pub output_partitioning: unsafe extern "C" fn(plan: &Self) -> FFI_Partitioning,
 
     /// Return the emission type of the plan.
     pub emission_type: unsafe extern "C" fn(plan: &Self) -> FFI_EmissionType,
@@ -59,9 +45,9 @@ pub struct FFI_PlanProperties {
     /// Indicate boundedness of the plan and its memory requirements.
     pub boundedness: unsafe extern "C" fn(plan: &Self) -> FFI_Boundedness,
 
-    /// The output ordering is a [`PhysicalSortExprNodeCollection`] protobuf message
-    /// serialized into bytes to pass across the FFI boundary.
-    pub output_ordering: unsafe extern "C" fn(plan: &Self) -> FFIResult<RVec<u8>>,
+    /// The output ordering of the plan.
+    pub output_ordering:
+        unsafe extern "C" fn(plan: &Self) -> ROption<RVec<FFI_PhysicalSortExpr>>,
 
     /// Return the schema of the plan.
     pub schema: unsafe extern "C" fn(plan: &Self) -> WrappedSchema,
 
@@ -92,15 +78,8 @@ impl FFI_PlanProperties {
 
 unsafe extern "C" fn output_partitioning_fn_wrapper(
     properties: &FFI_PlanProperties,
-) -> FFIResult<RVec<u8>> {
-    let codec = DefaultPhysicalExtensionCodec {};
-    let partitioning_data = rresult_return!(serialize_partitioning(
-        properties.inner().output_partitioning(),
-        &codec
-    ));
-    let output_partitioning = partitioning_data.encode_to_vec();
-
-    ROk(output_partitioning.into())
+) -> FFI_Partitioning {
+    properties.inner().output_partitioning().into()
 }
 
 unsafe extern "C" fn emission_type_fn_wrapper(
@@ -117,22 +96,17 @@ unsafe extern "C" fn boundedness_fn_wrapper(
 
 unsafe extern "C" fn output_ordering_fn_wrapper(
     properties: &FFI_PlanProperties,
-) -> FFIResult<RVec<u8>> {
-    let codec = DefaultPhysicalExtensionCodec {};
-    let output_ordering = match properties.inner().output_ordering() {
-        Some(ordering) => {
-            let physical_sort_expr_nodes = rresult_return!(
-                serialize_physical_sort_exprs(ordering.to_owned(), &codec)
-            );
-            let ordering_data = PhysicalSortExprNodeCollection {
-                physical_sort_expr_nodes,
-            };
-
-            ordering_data.encode_to_vec()
-        }
-        None => Vec::default(),
-    };
-    ROk(output_ordering.into())
+) -> ROption<RVec<FFI_PhysicalSortExpr>> {
+    let ordering: Option<RVec<FFI_PhysicalSortExpr>> =
+        properties.inner().output_ordering().map(|lex_ordering| {
+            let vec_ordering: Vec<PhysicalSortExpr> = lex_ordering.clone().into();
+            vec_ordering
+                .iter()
+                .map(FFI_PhysicalSortExpr::from)
+                .collect()
+        });
+
+    ordering.into()
 }
 
 unsafe extern "C" fn schema_fn_wrapper(properties: &FFI_PlanProperties) -> WrappedSchema {
@@ -186,38 +160,18 @@ impl TryFrom<FFI_PlanProperties> for PlanProperties {
         let ffi_schema = unsafe { (ffi_props.schema)(&ffi_props) };
         let schema = (&ffi_schema.0).try_into()?;
 
-        // TODO Extend FFI to get the registry and codex
-        let default_ctx = SessionContext::new();
-        let task_context = default_ctx.task_ctx();
-        let codex = DefaultPhysicalExtensionCodec {};
-
-        let ffi_orderings = unsafe { (ffi_props.output_ordering)(&ffi_props) };
-
-        let proto_output_ordering =
-            PhysicalSortExprNodeCollection::decode(df_result!(ffi_orderings)?.as_ref())
-                .map_err(|e| DataFusionError::External(Box::new(e)))?;
-        let sort_exprs = parse_physical_sort_exprs(
-            &proto_output_ordering.physical_sort_expr_nodes,
-            &task_context,
-            &schema,
-            &codex,
-        )?;
-
-        let partitioning_vec =
-            unsafe { df_result!((ffi_props.output_partitioning)(&ffi_props))? };
-        let proto_output_partitioning =
-            Partitioning::decode(partitioning_vec.as_ref())
-                .map_err(|e| DataFusionError::External(Box::new(e)))?;
-        let partitioning = parse_protobuf_partitioning(
-            Some(&proto_output_partitioning),
-            &task_context,
-            &schema,
-            &codex,
-        )?
-        .ok_or(DataFusionError::Plan(
-            "Unable to deserialize partitioning protobuf in FFI_PlanProperties"
-                .to_string(),
-        ))?;
+        let ffi_orderings: Option<RVec<FFI_PhysicalSortExpr>> =
+            unsafe { (ffi_props.output_ordering)(&ffi_props) }.into();
+        let sort_exprs = ffi_orderings
+            .map(|ordering_vec| {
+                ordering_vec
+                    .iter()
+                    .map(PhysicalSortExpr::from)
+                    .collect::<Vec<_>>()
+            })
+            .unwrap_or_default();
+
+        let partitioning = unsafe { (ffi_props.output_partitioning)(&ffi_props) };
 
         let eq_properties = if sort_exprs.is_empty() {
             EquivalenceProperties::new(Arc::new(schema))
@@ -233,7 +187,7 @@ impl TryFrom<FFI_PlanProperties> for PlanProperties {
 
         Ok(PlanProperties::new(
             eq_properties,
-            partitioning,
+            (&partitioning).into(),
             emission_type,
             boundedness,
         ))
@@ -307,7 +261,8 @@ impl From<FFI_EmissionType> for EmissionType {
 
 #[cfg(test)]
 mod tests {
-    use datafusion::{physical_expr::PhysicalSortExpr, physical_plan::Partitioning};
+    use datafusion::physical_expr::PhysicalSortExpr;
+    use datafusion::physical_plan::Partitioning;
 
     use super::*;
 
@@ -331,6 +286,7 @@ mod tests {
     #[test]
     fn test_round_trip_ffi_plan_properties() -> Result<()> {
         let original_props = create_test_props()?;
+
         let mut local_props_ptr = FFI_PlanProperties::from(&original_props);
         local_props_ptr.library_marker_id = crate::mock_foreign_marker_id;
 
@@ -342,7 +298,7 @@ mod tests {
     }
 
     #[test]
-    fn test_ffi_execution_plan_local_bypass() -> Result<()> {
+    fn test_ffi_plan_properties_local_bypass() -> Result<()> {
         let props = create_test_props()?;
 
         let ffi_plan = FFI_PlanProperties::from(&props);
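The protobuf round trip, and the throwaway `SessionContext` it required, is gone: plan properties now convert directly through `FFI_Partitioning` and `FFI_PhysicalSortExpr`. The conversion pair the tests above exercise reduces to:

```rust
use datafusion::error::Result;
use datafusion::physical_plan::PlanProperties;
use datafusion_ffi::plan_properties::FFI_PlanProperties;

// Partitioning and sort expressions cross as FFI-stable values rather
// than serialized protobuf bytes, so no codec or registry is needed.
fn round_trip(props: &PlanProperties) -> Result<PlanProperties> {
    let ffi_props = FFI_PlanProperties::from(props);
    PlanProperties::try_from(ffi_props)
}
```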
diff --git a/datafusion/ffi/src/proto/logical_extension_codec.rs b/datafusion/ffi/src/proto/logical_extension_codec.rs
index 3f79cfc73248c..3cb62f4a53889 100644
--- a/datafusion/ffi/src/proto/logical_extension_codec.rs
+++ b/datafusion/ffi/src/proto/logical_extension_codec.rs
@@ -25,12 +25,14 @@ use datafusion_catalog::TableProvider;
 use datafusion_common::error::Result;
 use datafusion_common::{TableReference, not_impl_err};
 use datafusion_datasource::file_format::FileFormatFactory;
-use datafusion_execution::TaskContext;
+use datafusion_execution::{TaskContext, TaskContextProvider};
 use datafusion_expr::{
     AggregateUDF, AggregateUDFImpl, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl,
     WindowUDF, WindowUDFImpl,
 };
-use datafusion_proto::logical_plan::LogicalExtensionCodec;
+use datafusion_proto::logical_plan::{
+    DefaultLogicalExtensionCodec, LogicalExtensionCodec,
+};
 use tokio::runtime::Handle;
 
 use crate::arrow_wrappers::WrappedSchema;
@@ -95,7 +97,7 @@ pub struct FFI_LogicalExtensionCodec {
     try_encode_udwf:
         unsafe extern "C" fn(&Self, node: FFI_WindowUDF) -> FFIResult<RVec<u8>>,
 
-    task_ctx_provider: FFI_TaskContextProvider,
+    pub task_ctx_provider: FFI_TaskContextProvider,
 
     /// Used to create a clone on the provider of the execution plan. This should
     /// only need to be called by the receiver of the plan.
@@ -120,14 +122,14 @@ unsafe impl Send for FFI_LogicalExtensionCodec {}
 unsafe impl Sync for FFI_LogicalExtensionCodec {}
 
 struct LogicalExtensionCodecPrivateData {
-    provider: Arc<dyn LogicalExtensionCodec>,
+    codec: Arc<dyn LogicalExtensionCodec>,
     runtime: Option<Handle>,
 }
 
 impl FFI_LogicalExtensionCodec {
     fn inner(&self) -> &Arc<dyn LogicalExtensionCodec> {
         let private_data = self.private_data as *const LogicalExtensionCodecPrivateData;
-        unsafe { &(*private_data).provider }
+        unsafe { &(*private_data).codec }
     }
 
     fn runtime(&self) -> &Option<Handle> {
@@ -148,18 +150,23 @@ unsafe extern "C" fn try_decode_table_provider_fn_wrapper(
 ) -> FFIResult<FFI_TableProvider> {
     let ctx = rresult_return!(codec.task_ctx());
     let runtime = codec.runtime().clone();
-    let codec = codec.inner();
+    let codec_inner = codec.inner();
 
     let table_ref = TableReference::from(table_ref.as_str());
     let schema: SchemaRef = schema.into();
-    let table_provider = rresult_return!(codec.try_decode_table_provider(
+    let table_provider = rresult_return!(codec_inner.try_decode_table_provider(
         buf.as_ref(),
         &table_ref,
         schema,
         ctx.as_ref()
     ));
 
-    RResult::ROk(FFI_TableProvider::new(table_provider, true, runtime))
+    RResult::ROk(FFI_TableProvider::new_with_ffi_codec(
+        table_provider,
+        true,
+        runtime,
+        codec.clone(),
+    ))
 }
 
 unsafe extern "C" fn try_encode_table_provider_fn_wrapper(
@@ -286,13 +293,12 @@ impl Drop for FFI_LogicalExtensionCodec {
 impl FFI_LogicalExtensionCodec {
     /// Creates a new [`FFI_LogicalExtensionCodec`].
     pub fn new(
-        provider: Arc<dyn LogicalExtensionCodec>,
+        codec: Arc<dyn LogicalExtensionCodec>,
         runtime: Option<Handle>,
         task_ctx_provider: impl Into<FFI_TaskContextProvider>,
     ) -> Self {
         let task_ctx_provider = task_ctx_provider.into();
-        let private_data =
-            Box::new(LogicalExtensionCodecPrivateData { provider, runtime });
+        let private_data = Box::new(LogicalExtensionCodecPrivateData { codec, runtime });
 
         Self {
             try_decode_table_provider: try_decode_table_provider_fn_wrapper,
@@ -312,6 +318,13 @@ impl FFI_LogicalExtensionCodec {
             library_marker_id: crate::get_library_marker_id,
         }
     }
+
+    pub fn new_default(task_ctx_provider: &Arc<dyn TaskContextProvider>) -> Self {
+        let task_ctx_provider = FFI_TaskContextProvider::from(task_ctx_provider);
+        let codec = Arc::new(DefaultLogicalExtensionCodec {});
+
+        Self::new(codec, None, task_ctx_provider)
+    }
 }
 
 /// This wrapper struct exists on the receiver side of the FFI interface, so it has
@@ -383,7 +396,8 @@ impl LogicalExtensionCodec for ForeignLogicalExtensionCodec {
         buf: &mut Vec<u8>,
     ) -> Result<()> {
         let table_ref = table_ref.to_string();
-        let node = FFI_TableProvider::new(node, true, None);
+        let node =
+            FFI_TableProvider::new_with_ffi_codec(node, true, None, self.0.clone());
 
         let bytes = df_result!(unsafe {
             (self.0.try_encode_table_provider)(&self.0, table_ref.as_str().into(), node)
diff --git a/datafusion/ffi/src/proto/physical_extension_codec.rs b/datafusion/ffi/src/proto/physical_extension_codec.rs
index 89a9a2cead007..5025c962d3989 100644
--- a/datafusion/ffi/src/proto/physical_extension_codec.rs
+++ b/datafusion/ffi/src/proto/physical_extension_codec.rs
@@ -145,7 +145,7 @@ unsafe extern "C" fn try_decode_fn_wrapper(
     let plan =
         rresult_return!(codec.try_decode(buf.as_ref(), &inputs, task_ctx.as_ref()));
 
-    RResult::ROk(FFI_ExecutionPlan::new(plan, task_ctx, None))
+    RResult::ROk(FFI_ExecutionPlan::new(plan, None))
 }
 
 unsafe extern "C" fn try_encode_fn_wrapper(
@@ -329,12 +329,9 @@ impl PhysicalExtensionCodec for ForeignPhysicalExtensionCodec {
         inputs: &[Arc<dyn ExecutionPlan>],
         _ctx: &TaskContext,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let task_ctx = (&self.0.task_ctx_provider).try_into()?;
         let inputs = inputs
             .iter()
-            .map(|plan| {
-                FFI_ExecutionPlan::new(Arc::clone(plan), Arc::clone(&task_ctx), None)
-            })
+            .map(|plan| FFI_ExecutionPlan::new(Arc::clone(plan), None))
             .collect();
 
         let plan =
@@ -345,8 +342,7 @@ impl PhysicalExtensionCodec for ForeignPhysicalExtensionCodec {
     }
 
     fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
-        let task_ctx = (&self.0.task_ctx_provider).try_into()?;
-        let plan = FFI_ExecutionPlan::new(node, task_ctx, None);
+        let plan = FFI_ExecutionPlan::new(node, None);
 
         let bytes = df_result!(unsafe { (self.0.try_encode)(&self.0, plan) })?;
         buf.extend(bytes);
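For libraries that do have custom logical extension nodes, the wrapper accepts any `LogicalExtensionCodec` through `new`. A sketch of that path; the `custom` argument stands in for whatever codec implementation the caller already has, and the `SessionContext` here is illustrative:

```rust
use std::sync::Arc;

use datafusion::execution::TaskContextProvider;
use datafusion::prelude::SessionContext;
use datafusion_ffi::execution::FFI_TaskContextProvider;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_proto::logical_plan::LogicalExtensionCodec;

fn wrap_codec(custom: Arc<dyn LogicalExtensionCodec>) -> FFI_LogicalExtensionCodec {
    let ctx = Arc::new(SessionContext::new()) as Arc<dyn TaskContextProvider>;
    // Arguments: the Rust codec, an optional tokio runtime handle, and
    // anything convertible into an FFI_TaskContextProvider.
    FFI_LogicalExtensionCodec::new(custom, None, FFI_TaskContextProvider::from(&ctx))
}
```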
diff --git a/datafusion/ffi/src/schema_provider.rs b/datafusion/ffi/src/schema_provider.rs
index cea41ae30fcc9..84cff6e8b3c2e 100644
--- a/datafusion/ffi/src/schema_provider.rs
+++ b/datafusion/ffi/src/schema_provider.rs
@@ -15,27 +15,26 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{any::Any, ffi::c_void, sync::Arc};
+use std::any::Any;
+use std::ffi::c_void;
+use std::sync::Arc;
 
-use abi_stable::{
-    StableAbi,
-    std_types::{ROption, RResult, RString, RVec},
-};
+use abi_stable::StableAbi;
+use abi_stable::std_types::{ROption, RResult, RString, RVec};
 use async_ffi::{FfiFuture, FutureExt};
 use async_trait::async_trait;
-use datafusion::{
-    catalog::{SchemaProvider, TableProvider},
-    error::DataFusionError,
+use datafusion_catalog::{SchemaProvider, TableProvider};
+use datafusion_common::error::{DataFusionError, Result};
+use datafusion_proto::logical_plan::{
+    DefaultLogicalExtensionCodec, LogicalExtensionCodec,
 };
 use tokio::runtime::Handle;
 
-use crate::{
-    df_result, rresult_return,
-    table_provider::{FFI_TableProvider, ForeignTableProvider},
-};
-
+use crate::execution::FFI_TaskContextProvider;
+use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
+use crate::table_provider::{FFI_TableProvider, ForeignTableProvider};
 use crate::util::FFIResult;
-use datafusion::error::Result;
+use crate::{df_result, rresult_return};
 
 /// A stable struct for sharing [`SchemaProvider`] across FFI boundaries.
 #[repr(C)]
@@ -67,6 +66,8 @@ pub struct FFI_SchemaProvider {
     pub table_exist: unsafe extern "C" fn(provider: &Self, name: RString) -> bool,
 
+    pub logical_codec: FFI_LogicalExtensionCodec,
+
     /// Used to create a clone on the provider of the execution plan. This should
     /// only need to be called by the receiver of the plan.
     pub clone: unsafe extern "C" fn(plan: &Self) -> Self,
 
@@ -128,11 +129,14 @@ unsafe extern "C" fn table_fn_wrapper(
 ) -> FfiFuture<FFIResult<ROption<FFI_TableProvider>>> {
     unsafe {
         let runtime = provider.runtime();
+        let logical_codec = provider.logical_codec.clone();
         let provider = Arc::clone(provider.inner());
 
         async move {
             let table = rresult_return!(provider.table(name.as_str()).await)
-                .map(|t| FFI_TableProvider::new(t, true, runtime))
+                .map(|t| {
+                    FFI_TableProvider::new_with_ffi_codec(t, true, runtime, logical_codec)
+                })
                .into();
 
             RResult::ROk(table)
@@ -148,12 +152,15 @@ unsafe extern "C" fn register_table_fn_wrapper(
 ) -> FFIResult<ROption<FFI_TableProvider>> {
     unsafe {
         let runtime = provider.runtime();
+        let logical_codec = provider.logical_codec.clone();
         let provider = provider.inner();
 
         let table = Arc::new(ForeignTableProvider(table));
 
         let returned_table = rresult_return!(provider.register_table(name.into(), table))
-            .map(|t| FFI_TableProvider::new(t, true, runtime));
+            .map(|t| {
+                FFI_TableProvider::new_with_ffi_codec(t, true, runtime, logical_codec)
+            });
 
         RResult::ROk(returned_table.into())
     }
@@ -165,10 +172,13 @@ unsafe extern "C" fn deregister_table_fn_wrapper(
 ) -> FFIResult<ROption<FFI_TableProvider>> {
     unsafe {
         let runtime = provider.runtime();
+        let logical_codec = provider.logical_codec.clone();
         let provider = provider.inner();
 
         let returned_table = rresult_return!(provider.deregister_table(name.as_str()))
-            .map(|t| FFI_TableProvider::new(t, true, runtime));
+            .map(|t| {
+                FFI_TableProvider::new_with_ffi_codec(t, true, runtime, logical_codec)
+            });
 
         RResult::ROk(returned_table.into())
     }
@@ -206,14 +216,15 @@ unsafe extern "C" fn clone_fn_wrapper(
     FFI_SchemaProvider {
         owner_name: provider.owner_name.clone(),
         table_names: table_names_fn_wrapper,
-        clone: clone_fn_wrapper,
-        release: release_fn_wrapper,
-        version: super::version,
-        private_data,
         table: table_fn_wrapper,
         register_table: register_table_fn_wrapper,
         deregister_table: deregister_table_fn_wrapper,
         table_exist: table_exist_fn_wrapper,
+        logical_codec: provider.logical_codec.clone(),
+        clone: clone_fn_wrapper,
+        release: release_fn_wrapper,
+        version: super::version,
+        private_data,
         library_marker_id: crate::get_library_marker_id,
     }
 }
@@ -230,6 +241,24 @@ impl FFI_SchemaProvider {
     pub fn new(
         provider: Arc<dyn SchemaProvider>,
         runtime: Option<Handle>,
+        task_ctx_provider: impl Into<FFI_TaskContextProvider>,
+        logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
+    ) -> Self {
+        let task_ctx_provider = task_ctx_provider.into();
+        let logical_codec =
+            logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {}));
+        let logical_codec = FFI_LogicalExtensionCodec::new(
+            logical_codec,
+            runtime.clone(),
+            task_ctx_provider.clone(),
+        );
+        Self::new_with_ffi_codec(provider, runtime, logical_codec)
+    }
+
+    pub fn new_with_ffi_codec(
+        provider: Arc<dyn SchemaProvider>,
+        runtime: Option<Handle>,
+        logical_codec: FFI_LogicalExtensionCodec,
     ) -> Self {
         let owner_name = provider.owner_name().map(|s| s.into()).into();
         let private_data = Box::new(ProviderPrivateData { provider, runtime });
 
         Self {
             owner_name,
             table_names: table_names_fn_wrapper,
-            clone: clone_fn_wrapper,
-            release: release_fn_wrapper,
-            version: super::version,
-            private_data: Box::into_raw(private_data) as *mut c_void,
             table: table_fn_wrapper,
             register_table: register_table_fn_wrapper,
             deregister_table: deregister_table_fn_wrapper,
             table_exist: table_exist_fn_wrapper,
+            logical_codec,
+            clone: clone_fn_wrapper,
+            release: release_fn_wrapper,
+            version: super::version,
+            private_data: Box::into_raw(private_data) as *mut c_void,
             library_marker_id: crate::get_library_marker_id,
         }
     }
@@ -319,7 +349,12 @@ impl SchemaProvider for ForeignSchemaProvider {
     unsafe {
         let ffi_table = match table.as_any().downcast_ref::<FFI_TableProvider>() {
             Some(t) => t.0.clone(),
-            None => FFI_TableProvider::new(table, true, None),
+            None => FFI_TableProvider::new_with_ffi_codec(
+                table,
+                true,
+                None,
+                self.0.logical_codec.clone(),
+            ),
         };
 
         let returned_provider: Option<FFI_TableProvider> =
@@ -348,9 +383,11 @@ impl SchemaProvider for ForeignSchemaProvider {
 
 #[cfg(test)]
 mod tests {
-    use super::*;
     use arrow::datatypes::Schema;
-    use datafusion::{catalog::MemorySchemaProvider, datasource::empty::EmptyTable};
+    use datafusion::catalog::MemorySchemaProvider;
+    use datafusion::datasource::empty::EmptyTable;
+
+    use super::*;
 
     fn empty_table() -> Arc<dyn TableProvider> {
         Arc::new(EmptyTable::new(Arc::new(Schema::empty())))
@@ -367,7 +404,10 @@ mod tests {
             .is_none()
         );
 
-        let mut ffi_schema_provider = FFI_SchemaProvider::new(schema_provider, None);
+        let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx();
+
+        let mut ffi_schema_provider =
+            FFI_SchemaProvider::new(schema_provider, None, task_ctx_provider, None);
         ffi_schema_provider.library_marker_id = crate::mock_foreign_marker_id;
 
         let foreign_schema_provider: Arc<dyn SchemaProvider> =
@@ -418,7 +458,9 @@ mod tests {
     fn test_ffi_schema_provider_local_bypass() {
         let schema_provider = Arc::new(MemorySchemaProvider::new());
 
-        let mut ffi_schema = FFI_SchemaProvider::new(schema_provider, None);
+        let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx();
+        let mut ffi_schema =
+            FFI_SchemaProvider::new(schema_provider, None, task_ctx_provider, None);
 
         // Verify local libraries can be downcast to their original
         let foreign_schema: Arc<dyn SchemaProvider> = (&ffi_schema).into();
diff --git a/datafusion/ffi/src/session/mod.rs b/datafusion/ffi/src/session/mod.rs
index 1694f46dc7447..4e1ef96d301e1 100644
--- a/datafusion/ffi/src/session/mod.rs
+++ b/datafusion/ffi/src/session/mod.rs
@@ -76,7 +76,7 @@ pub mod config;
 #[repr(C)]
 #[derive(Debug, StableAbi)]
 #[allow(non_camel_case_types)]
-pub struct FFI_SessionRef {
+pub(crate) struct FFI_SessionRef {
     session_id: unsafe extern "C" fn(&Self) -> RStr,
 
     config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig,
 
@@ -177,9 +177,7 @@ unsafe extern "C" fn create_physical_plan_fn_wrapper(
 
         let physical_plan = session.create_physical_plan(&logical_plan).await;
 
-        rresult!(
-            physical_plan.map(|plan| FFI_ExecutionPlan::new(plan, task_ctx, runtime))
-        )
+        rresult!(physical_plan.map(|plan| FFI_ExecutionPlan::new(plan, runtime)))
     }
     .into_ffi()
 }
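Because `FFI_SessionRef` forwards the caller's actual `Session` (with a `ForeignSession` shim when the call crosses library boundaries), the scan and insert paths in the table provider below no longer rebuild a `SessionState` from a serialized config. An end-to-end sketch from the host's perspective, mirroring the table-provider tests further down:

```rust
use std::sync::Arc;

use datafusion::catalog::TableProvider;
use datafusion::error::Result;
use datafusion::execution::TaskContextProvider;
use datafusion::prelude::SessionContext;
use datafusion_ffi::execution::FFI_TaskContextProvider;
use datafusion_ffi::table_provider::FFI_TableProvider;

async fn query(provider: Arc<dyn TableProvider>) -> Result<()> {
    let ctx = Arc::new(SessionContext::new());
    let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;
    let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider);

    let ffi_provider =
        FFI_TableProvider::new(provider, true, None, task_ctx_provider, None);
    let foreign: Arc<dyn TableProvider> = (&ffi_provider).into();

    // The session used here is handed to the provider as an FFI_SessionRef
    // when scan() runs, rather than being reconstructed on the other side.
    ctx.register_table("external_table", foreign)?;
    ctx.table("external_table").await?.show().await
}
```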
diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs
index 8929d02b0ed3e..a5f9940128c59 100644
--- a/datafusion/ffi/src/table_provider.rs
+++ b/datafusion/ffi/src/table_provider.rs
@@ -15,46 +15,39 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{any::Any, ffi::c_void, sync::Arc};
+use std::any::Any;
+use std::ffi::c_void;
+use std::sync::Arc;
 
-use abi_stable::{
-    StableAbi,
-    std_types::{ROption, RResult, RVec},
-};
+use abi_stable::StableAbi;
+use abi_stable::std_types::{ROption, RResult, RVec};
 use arrow::datatypes::SchemaRef;
 use async_ffi::{FfiFuture, FutureExt};
 use async_trait::async_trait;
-use datafusion::{
-    catalog::{Session, TableProvider},
-    datasource::TableType,
-    error::DataFusionError,
-    execution::{TaskContext, session_state::SessionStateBuilder},
-    logical_expr::{TableProviderFilterPushDown, logical_plan::dml::InsertOp},
-    physical_plan::ExecutionPlan,
-    prelude::{Expr, SessionContext},
-};
-use datafusion_proto::{
-    logical_plan::{
-        DefaultLogicalExtensionCodec, from_proto::parse_exprs, to_proto::serialize_exprs,
-    },
-    protobuf::LogicalExprList,
+use datafusion_catalog::{Session, TableProvider};
+use datafusion_common::error::{DataFusionError, Result};
+use datafusion_execution::TaskContext;
+use datafusion_expr::dml::InsertOp;
+use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
+use datafusion_physical_plan::ExecutionPlan;
+use datafusion_proto::logical_plan::from_proto::parse_exprs;
+use datafusion_proto::logical_plan::to_proto::serialize_exprs;
+use datafusion_proto::logical_plan::{
+    DefaultLogicalExtensionCodec, LogicalExtensionCodec,
 };
+use datafusion_proto::protobuf::LogicalExprList;
 use prost::Message;
 use tokio::runtime::Handle;
 
-use crate::{
-    arrow_wrappers::WrappedSchema,
-    df_result, rresult_return,
-    table_source::{FFI_TableProviderFilterPushDown, FFI_TableType},
-};
-
-use super::{
-    execution_plan::FFI_ExecutionPlan, insert_op::FFI_InsertOp,
-    session::config::FFI_SessionConfig,
-};
+use super::execution_plan::FFI_ExecutionPlan;
+use super::insert_op::FFI_InsertOp;
+use crate::arrow_wrappers::WrappedSchema;
+use crate::execution::FFI_TaskContextProvider;
+use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
+use crate::session::{FFI_SessionRef, ForeignSession};
+use crate::table_source::{FFI_TableProviderFilterPushDown, FFI_TableType};
 use crate::util::FFIResult;
-use datafusion::error::Result;
-use datafusion_execution::config::SessionConfig;
+use crate::{df_result, rresult_return};
 
 /// A stable struct for sharing [`TableProvider`] across FFI boundaries.
 ///
@@ -100,60 +93,62 @@ use datafusion_execution::config::SessionConfig;
 #[allow(non_camel_case_types)]
 pub struct FFI_TableProvider {
     /// Return the table schema
-    pub schema: unsafe extern "C" fn(provider: &Self) -> WrappedSchema,
+    schema: unsafe extern "C" fn(provider: &Self) -> WrappedSchema,
 
     /// Perform a scan on the table. See [`TableProvider`] for detailed usage information.
     ///
     /// # Arguments
     ///
     /// * `provider` - the table provider
-    /// * `session_config` - session configuration
+    /// * `session` - session
     /// * `projections` - if specified, only a subset of the columns are returned
     /// * `filters_serialized` - filters to apply to the scan, which are a
     ///   [`LogicalExprList`] protobuf message serialized into bytes to pass
     ///   across the FFI boundary.
     /// * `limit` - if specified, limit the number of rows returned
-    pub scan: unsafe extern "C" fn(
+    scan: unsafe extern "C" fn(
         provider: &Self,
-        session_config: &FFI_SessionConfig,
+        session: FFI_SessionRef,
         projections: RVec<usize>,
         filters_serialized: RVec<u8>,
         limit: ROption<usize>,
     ) -> FfiFuture<FFIResult<FFI_ExecutionPlan>>,
 
     /// Return the type of table. See [`TableType`] for options.
-    pub table_type: unsafe extern "C" fn(provider: &Self) -> FFI_TableType,
+    table_type: unsafe extern "C" fn(provider: &Self) -> FFI_TableType,
 
     /// Based upon the input filters, identify which are supported. The filters
     /// are a [`LogicalExprList`] protobuf message serialized into bytes to pass
     /// across the FFI boundary.
-    pub supports_filters_pushdown: Option<
+    supports_filters_pushdown: Option<
         unsafe extern "C" fn(
             provider: &FFI_TableProvider,
             filters_serialized: RVec<u8>,
         ) -> FFIResult<RVec<FFI_TableProviderFilterPushDown>>,
     >,
 
-    pub insert_into: unsafe extern "C" fn(
+    insert_into: unsafe extern "C" fn(
         provider: &Self,
-        session_config: &FFI_SessionConfig,
+        session: FFI_SessionRef,
         input: &FFI_ExecutionPlan,
         insert_op: FFI_InsertOp,
     ) -> FfiFuture<FFIResult<FFI_ExecutionPlan>>,
 
+    pub logical_codec: FFI_LogicalExtensionCodec,
+
     /// Used to create a clone on the provider of the execution plan. This should
     /// only need to be called by the receiver of the plan.
-    pub clone: unsafe extern "C" fn(plan: &Self) -> Self,
+    clone: unsafe extern "C" fn(plan: &Self) -> Self,
 
     /// Release the memory of the private data when it is no longer being used.
-    pub release: unsafe extern "C" fn(arg: &mut Self),
+    release: unsafe extern "C" fn(arg: &mut Self),
 
     /// Return the major DataFusion version number of this provider.
     pub version: unsafe extern "C" fn() -> u64,
 
     /// Internal data. This is only to be accessed by the provider of the plan.
     /// A [`ForeignTableProvider`] should never attempt to access this data.
-    pub private_data: *mut c_void,
+    private_data: *mut c_void,
 
     /// Utility to identify when FFI objects are accessed locally through
     /// the foreign interface. See [`crate::get_library_marker_id`] and
@@ -194,17 +189,16 @@ unsafe extern "C" fn table_type_fn_wrapper(
 
 fn supports_filters_pushdown_internal(
     provider: &Arc<dyn TableProvider>,
     filters_serialized: &[u8],
+    task_ctx: &Arc<TaskContext>,
+    codec: &dyn LogicalExtensionCodec,
 ) -> Result<Vec<TableProviderFilterPushDown>> {
-    let default_ctx = SessionContext::new();
-    let codec = DefaultLogicalExtensionCodec {};
-
     let filters = match filters_serialized.is_empty() {
         true => vec![],
         false => {
             let proto_filters = LogicalExprList::decode(filters_serialized)
                 .map_err(|e| DataFusionError::Plan(e.to_string()))?;
 
-            parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)?
+            parse_exprs(proto_filters.expr.iter(), task_ctx.as_ref(), codec)?
         }
     };
     let filters_borrowed: Vec<&Expr> = filters.iter().collect();
@@ -222,43 +216,56 @@ unsafe extern "C" fn supports_filters_pushdown_fn_wrapper(
     provider: &FFI_TableProvider,
     filters_serialized: RVec<u8>,
 ) -> FFIResult<RVec<FFI_TableProviderFilterPushDown>> {
-    supports_filters_pushdown_internal(provider.inner(), &filters_serialized)
-        .map_err(|e| e.to_string().into())
-        .into()
+    let logical_codec: Arc<dyn LogicalExtensionCodec> = (&provider.logical_codec).into();
+    let task_ctx = rresult_return!(<Arc<TaskContext>>::try_from(
+        &provider.logical_codec.task_ctx_provider
+    ));
+    supports_filters_pushdown_internal(
+        provider.inner(),
+        &filters_serialized,
+        &task_ctx,
+        logical_codec.as_ref(),
+    )
+    .map_err(|e| e.to_string().into())
+    .into()
 }
 
 unsafe extern "C" fn scan_fn_wrapper(
     provider: &FFI_TableProvider,
-    session_config: &FFI_SessionConfig,
+    session: FFI_SessionRef,
     projections: RVec<usize>,
     filters_serialized: RVec<u8>,
     limit: ROption<usize>,
 ) -> FfiFuture<FFIResult<FFI_ExecutionPlan>> {
+    let task_ctx: Result<Arc<TaskContext>, DataFusionError> =
+        (&provider.logical_codec.task_ctx_provider).try_into();
     let runtime = provider.runtime().clone();
+    let logical_codec: Arc<dyn LogicalExtensionCodec> = (&provider.logical_codec).into();
     let internal_provider = Arc::clone(provider.inner());
-    let session_config = session_config.clone();
 
     async move {
-        let config = rresult_return!(SessionConfig::try_from(&session_config));
-        let session = SessionStateBuilder::new()
-            .with_default_features()
-            .with_config(config)
-            .build();
-        let ctx = SessionContext::new_with_state(session);
+        let mut foreign_session = None;
+        let session = rresult_return!(
+            session
+                .as_local()
+                .map(Ok::<&(dyn Session + Send + Sync), DataFusionError>)
+                .unwrap_or_else(|| {
+                    foreign_session = Some(ForeignSession::try_from(&session)?);
+                    Ok(foreign_session.as_ref().unwrap())
+                })
+        );
+        let task_ctx = rresult_return!(task_ctx);
 
         let filters = match filters_serialized.is_empty() {
             true => vec![],
             false => {
-                let default_ctx = SessionContext::new();
-                let codec = DefaultLogicalExtensionCodec {};
-
                 let proto_filters =
                     rresult_return!(LogicalExprList::decode(filters_serialized.as_ref()));
 
                 rresult_return!(parse_exprs(
                     proto_filters.expr.iter(),
-                    &default_ctx,
-                    &codec
+                    task_ctx.as_ref(),
+                    logical_codec.as_ref(),
                 ))
             }
         };
@@ -267,37 +274,36 @@ unsafe extern "C" fn scan_fn_wrapper(
 
         let plan = rresult_return!(
             internal_provider
-                .scan(&ctx.state(), Some(&projections), &filters, limit.into())
+                .scan(session, Some(&projections), &filters, limit.into())
                 .await
         );
 
-        RResult::ROk(FFI_ExecutionPlan::new(
-            plan,
-            ctx.task_ctx(),
-            runtime.clone(),
-        ))
+        RResult::ROk(FFI_ExecutionPlan::new(plan, runtime.clone()))
     }
     .into_ffi()
 }
 
 unsafe extern "C" fn insert_into_fn_wrapper(
     provider: &FFI_TableProvider,
-    session_config: &FFI_SessionConfig,
+    session: FFI_SessionRef,
    input: &FFI_ExecutionPlan,
     insert_op: FFI_InsertOp,
 ) -> FfiFuture<FFIResult<FFI_ExecutionPlan>> {
     let runtime = provider.runtime().clone();
     let internal_provider = Arc::clone(provider.inner());
-    let session_config = session_config.clone();
     let input = input.clone();
 
     async move {
-        let config = rresult_return!(SessionConfig::try_from(&session_config));
-        let session = SessionStateBuilder::new()
-            .with_default_features()
-            .with_config(config)
-            .build();
-        let ctx = SessionContext::new_with_state(session);
+        let mut foreign_session = None;
+        let session = rresult_return!(
+            session
+                .as_local()
+                .map(Ok::<&(dyn Session + Send + Sync), DataFusionError>)
+                .unwrap_or_else(|| {
+                    foreign_session = Some(ForeignSession::try_from(&session)?);
+                    Ok(foreign_session.as_ref().unwrap())
+                })
+        );
 
         let input = rresult_return!(<Arc<dyn ExecutionPlan>>::try_from(&input));
 
@@ -305,15 +311,11 @@ unsafe extern "C" fn insert_into_fn_wrapper(
         let plan = rresult_return!(
             internal_provider
-                .insert_into(&ctx.state(), input, insert_op)
+                .insert_into(session, input, insert_op)
                 .await
         );
 
-        RResult::ROk(FFI_ExecutionPlan::new(
-            plan,
-            ctx.task_ctx(),
-            runtime.clone(),
-        ))
+        RResult::ROk(FFI_ExecutionPlan::new(plan, runtime.clone()))
     }
     .into_ffi()
 }
@@ -343,6 +345,7 @@ unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_TableProvider {
         table_type: table_type_fn_wrapper,
         supports_filters_pushdown: provider.supports_filters_pushdown,
         insert_into: provider.insert_into,
+        logical_codec: provider.logical_codec.clone(),
         clone: clone_fn_wrapper,
         release: release_fn_wrapper,
         version: super::version,
@@ -363,6 +366,30 @@ impl FFI_TableProvider {
         provider: Arc<dyn TableProvider>,
         can_support_pushdown_filters: bool,
         runtime: Option<Handle>,
+        task_ctx_provider: impl Into<FFI_TaskContextProvider>,
+        logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
+    ) -> Self {
+        let task_ctx_provider = task_ctx_provider.into();
+        let logical_codec =
+            logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {}));
+        let logical_codec = FFI_LogicalExtensionCodec::new(
+            logical_codec,
+            runtime.clone(),
+            task_ctx_provider.clone(),
+        );
+        Self::new_with_ffi_codec(
+            provider,
+            can_support_pushdown_filters,
+            runtime,
+            logical_codec,
+        )
+    }
+
+    pub fn new_with_ffi_codec(
+        provider: Arc<dyn TableProvider>,
+        can_support_pushdown_filters: bool,
+        runtime: Option<Handle>,
+        logical_codec: FFI_LogicalExtensionCodec,
     ) -> Self {
         let private_data = Box::new(ProviderPrivateData { provider, runtime });
 
@@ -375,6 +402,7 @@ impl FFI_TableProvider {
                 false => None,
             },
             insert_into: insert_into_fn_wrapper,
+            logical_codec,
             clone: clone_fn_wrapper,
             release: release_fn_wrapper,
             version: super::version,
@@ -432,21 +460,21 @@ impl TableProvider for ForeignTableProvider {
         filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let session_config: FFI_SessionConfig = session.config().into();
+        let session = FFI_SessionRef::new(session, None, self.0.logical_codec.clone());
 
         let projections: Option<Vec<_>> =
             projection.map(|p| p.iter().map(|v| v.to_owned()).collect());
 
-        let codec = DefaultLogicalExtensionCodec {};
+        let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
         let filter_list = LogicalExprList {
-            expr: serialize_exprs(filters, &codec)?,
+            expr: serialize_exprs(filters, codec.as_ref())?,
         };
         let filters_serialized = filter_list.encode_to_vec().into();
 
         let plan = unsafe {
             let maybe_plan = (self.0.scan)(
                 &self.0,
-                &session_config,
+                session,
                 projections.unwrap_or_default(),
                 filters_serialized,
                 limit.into(),
@@ -476,10 +504,13 @@ impl TableProvider for ForeignTableProvider {
             }
         };
 
-        let codec = DefaultLogicalExtensionCodec {};
+        let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
 
         let expr_list = LogicalExprList {
-            expr: serialize_exprs(filters.iter().map(|f| f.to_owned()), &codec)?,
+            expr: serialize_exprs(
+                filters.iter().map(|f| f.to_owned()),
+                codec.as_ref(),
+            )?,
         };
         let serialized_filters = expr_list.encode_to_vec();
 
@@ -495,16 +526,15 @@ impl TableProvider for ForeignTableProvider {
         input: Arc<dyn ExecutionPlan>,
         insert_op: InsertOp,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let session_config: FFI_SessionConfig = session.config().into();
+        let session = FFI_SessionRef::new(session, None, self.0.logical_codec.clone());
 
         let rc = Handle::try_current().ok();
-        let input =
-            FFI_ExecutionPlan::new(input, Arc::new(TaskContext::from(session)), rc);
+        let input = FFI_ExecutionPlan::new(input, rc);
         let insert_op: FFI_InsertOp = insert_op.into();
 
         let plan = unsafe {
             let maybe_plan =
-                (self.0.insert_into)(&self.0, &session_config, &input, insert_op).await;
+ (self.0.insert_into)(&self.0, session, &input, insert_op).await; >::try_from(&df_result!(maybe_plan)?)? }; @@ -515,15 +545,17 @@ impl TableProvider for ForeignTableProvider { #[cfg(test)] mod tests { - use super::*; use arrow::datatypes::Schema; - use datafusion::prelude::{col, lit}; + use datafusion::prelude::{SessionContext, col, lit}; + use datafusion_execution::TaskContextProvider; + + use super::*; fn create_test_table_provider() -> Result> { use arrow::datatypes::Field; - use datafusion::arrow::{ - array::Float32Array, datatypes::DataType, record_batch::RecordBatch, - }; + use datafusion::arrow::array::Float32Array; + use datafusion::arrow::datatypes::DataType; + use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; let schema = @@ -548,9 +580,12 @@ mod tests { #[tokio::test] async fn test_round_trip_ffi_table_provider_scan() -> Result<()> { let provider = create_test_table_provider()?; - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider); - let mut ffi_provider = FFI_TableProvider::new(provider, true, None); + let mut ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_provider, None); ffi_provider.library_marker_id = crate::mock_foreign_marker_id; let foreign_table_provider: Arc = (&ffi_provider).into(); @@ -570,9 +605,12 @@ mod tests { #[tokio::test] async fn test_round_trip_ffi_table_provider_insert_into() -> Result<()> { let provider = create_test_table_provider()?; - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider); - let mut ffi_provider = FFI_TableProvider::new(provider, true, None); + let mut ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_provider, None); ffi_provider.library_marker_id = crate::mock_foreign_marker_id; let foreign_table_provider: Arc = (&ffi_provider).into(); @@ -600,9 +638,9 @@ mod tests { #[tokio::test] async fn test_aggregation() -> Result<()> { use arrow::datatypes::Field; - use datafusion::arrow::{ - array::Float32Array, datatypes::DataType, record_batch::RecordBatch, - }; + use datafusion::arrow::array::Float32Array; + use datafusion::arrow::datatypes::DataType; + use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::assert_batches_eq; use datafusion::datasource::MemTable; @@ -615,11 +653,14 @@ mod tests { vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], )?; - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider); let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch1]])?); - let ffi_provider = FFI_TableProvider::new(provider, true, None); + let ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_provider, None); let foreign_table_provider: Arc = (&ffi_provider).into(); @@ -646,7 +687,10 @@ mod tests { fn test_ffi_table_provider_local_bypass() -> Result<()> { let table_provider = create_test_table_provider()?; - let mut ffi_table = FFI_TableProvider::new(table_provider, false, None); + let ctx = Arc::new(SessionContext::new()) as Arc; + let task_ctx_provider = FFI_TaskContextProvider::from(&ctx); + let mut ffi_table = + FFI_TableProvider::new(table_provider, false, 
None, task_ctx_provider, None); // Verify local libraries can be downcast to their original let foreign_table: Arc<dyn TableProvider> = (&ffi_table).into(); diff --git a/datafusion/ffi/src/tests/async_provider.rs b/datafusion/ffi/src/tests/async_provider.rs index 67421f58805a0..5e43a3e8ad7b4 100644 --- a/datafusion/ffi/src/tests/async_provider.rs +++ b/datafusion/ffi/src/tests/async_provider.rs @@ -27,6 +27,8 @@ use std::{any::Any, fmt::Debug, sync::Arc}; +use super::create_record_batch; +use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use crate::table_provider::FFI_TableProvider; use arrow::array::RecordBatch; use arrow::datatypes::Schema; @@ -46,8 +48,6 @@ use tokio::{ sync::{broadcast, mpsc}, }; -use super::create_record_batch; - #[derive(Debug)] pub struct AsyncTableProvider { batch_request: mpsc::Sender, @@ -277,7 +277,14 @@ impl Stream for AsyncTestRecordBatchStream { } } -pub(crate) fn create_async_table_provider() -> FFI_TableProvider { +pub(crate) fn create_async_table_provider( + codec: FFI_LogicalExtensionCodec, +) -> FFI_TableProvider { let (table_provider, tokio_rt) = start_async_provider(); - FFI_TableProvider::new(Arc::new(table_provider), true, Some(tokio_rt)) + FFI_TableProvider::new_with_ffi_codec( + Arc::new(table_provider), + true, + Some(tokio_rt), + codec, + ) } diff --git a/datafusion/ffi/src/tests/catalog.rs b/datafusion/ffi/src/tests/catalog.rs index de012659f2beb..122971eb0ca85 100644 --- a/datafusion/ffi/src/tests/catalog.rs +++ b/datafusion/ffi/src/tests/catalog.rs @@ -29,6 +29,7 @@ use std::{any::Any, fmt::Debug, sync::Arc}; use crate::catalog_provider::FFI_CatalogProvider; use crate::catalog_provider_list::FFI_CatalogProviderList; +use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use arrow::datatypes::Schema; use async_trait::async_trait; use datafusion::{ @@ -180,9 +181,11 @@ impl CatalogProvider for FixedCatalogProvider { } } -pub(crate) extern "C" fn create_catalog_provider() -> FFI_CatalogProvider { +pub(crate) extern "C" fn create_catalog_provider( + codec: FFI_LogicalExtensionCodec, +) -> FFI_CatalogProvider { let catalog_provider = Arc::new(FixedCatalogProvider::default()); - FFI_CatalogProvider::new(catalog_provider, None) + FFI_CatalogProvider::new_with_ffi_codec(catalog_provider, None, codec) } /// This catalog provider list is intended only for unit tests. It prepopulates with one @@ -234,7 +237,9 @@ impl CatalogProviderList for FixedCatalogProviderList { } } -pub(crate) extern "C" fn create_catalog_provider_list() -> FFI_CatalogProviderList { +pub(crate) extern "C" fn create_catalog_provider_list( + codec: FFI_LogicalExtensionCodec, +) -> FFI_CatalogProviderList { let catalog_provider_list = Arc::new(FixedCatalogProviderList::default()); - FFI_CatalogProviderList::new(catalog_provider_list, None) + FFI_CatalogProviderList::new_with_ffi_codec(catalog_provider_list, None, codec) } diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index dfb47561077f9..e87d465db0d10 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -34,6 +34,7 @@ use crate::udwf::FFI_WindowUDF; use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use crate::catalog_provider_list::FFI_CatalogProviderList; +use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use crate::tests::catalog::create_catalog_provider_list; use arrow::array::RecordBatch; use async_provider::create_async_table_provider; @@ -61,20 +62,26 @@ pub mod utils; /// module.
pub struct ForeignLibraryModule { /// Construct an opinionated catalog provider - pub create_catalog: extern "C" fn() -> FFI_CatalogProvider, + pub create_catalog: + extern "C" fn(codec: FFI_LogicalExtensionCodec) -> FFI_CatalogProvider, /// Construct an opinionated catalog provider list - pub create_catalog_list: extern "C" fn() -> FFI_CatalogProviderList, + pub create_catalog_list: + extern "C" fn(codec: FFI_LogicalExtensionCodec) -> FFI_CatalogProviderList, /// Constructs the table provider - pub create_table: extern "C" fn(synchronous: bool) -> FFI_TableProvider, + pub create_table: extern "C" fn( + synchronous: bool, + codec: FFI_LogicalExtensionCodec, + ) -> FFI_TableProvider, /// Create a scalar UDF pub create_scalar_udf: extern "C" fn() -> FFI_ScalarUDF, pub create_nullary_udf: extern "C" fn() -> FFI_ScalarUDF, - pub create_table_function: extern "C" fn() -> FFI_TableFunction, + pub create_table_function: + extern "C" fn(FFI_LogicalExtensionCodec) -> FFI_TableFunction, /// Create an aggregate UDAF using sum pub create_sum_udaf: extern "C" fn() -> FFI_AggregateUDF, @@ -115,10 +122,13 @@ pub fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { /// Here we only wish to create a simple table provider as an example. /// We create an in-memory table and convert it to its FFI counterpart. -extern "C" fn construct_table_provider(synchronous: bool) -> FFI_TableProvider { +extern "C" fn construct_table_provider( + synchronous: bool, + codec: FFI_LogicalExtensionCodec, +) -> FFI_TableProvider { match synchronous { - true => create_sync_table_provider(), - false => create_async_table_provider(), + true => create_sync_table_provider(codec), + false => create_async_table_provider(codec), } } diff --git a/datafusion/ffi/src/tests/sync_provider.rs b/datafusion/ffi/src/tests/sync_provider.rs index ff85e0b15b395..cbad8343dd000 100644 --- a/datafusion/ffi/src/tests/sync_provider.rs +++ b/datafusion/ffi/src/tests/sync_provider.rs @@ -17,12 +17,14 @@ use std::sync::Arc; +use super::{create_record_batch, create_test_schema}; +use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use crate::table_provider::FFI_TableProvider; use datafusion::datasource::MemTable; -use super::{create_record_batch, create_test_schema}; - -pub(crate) fn create_sync_table_provider() -> FFI_TableProvider { +pub(crate) fn create_sync_table_provider( + codec: FFI_LogicalExtensionCodec, +) -> FFI_TableProvider { let schema = create_test_schema(); // It is useful to create these as multiple record batches @@ -35,5 +37,5 @@ pub(crate) fn create_sync_table_provider() -> FFI_TableProvider { let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); - FFI_TableProvider::new(Arc::new(table_provider), true, None) + FFI_TableProvider::new_with_ffi_codec(Arc::new(table_provider), true, None, codec) } diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index 55e31ef3ab770..5eee18f00aa05 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -27,25 +27,61 @@ use datafusion::{ functions_window::rank::Rank, logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}, }; +use std::any::Any; +use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use arrow_schema::DataType; +use datafusion::logical_expr::{ColumnarValue, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; pub(crate) extern "C" fn create_ffi_abs_func() -> FFI_ScalarUDF { - let udf: Arc<ScalarUDF>
= Arc::new(AbsFunc::new().into()); + let inner = WrappedAbs(Arc::new(AbsFunc::new().into())); + let udf: Arc<ScalarUDF> = Arc::new(inner.into()); udf.into() } +#[derive(Debug, Hash, Eq, PartialEq)] +struct WrappedAbs(Arc<ScalarUDF>); + +impl ScalarUDFImpl for WrappedAbs { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ffi_abs" + } + + fn signature(&self) -> &Signature { + self.0.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> { + self.0.return_type(arg_types) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result<ColumnarValue> { + self.0.invoke_with_args(args) + } +} + pub(crate) extern "C" fn create_ffi_random_func() -> FFI_ScalarUDF { let udf: Arc<ScalarUDF> = Arc::new(RandomFunc::new().into()); udf.into() } -pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { +pub(crate) extern "C" fn create_ffi_table_func( + codec: FFI_LogicalExtensionCodec, +) -> FFI_TableFunction { let udtf: Arc<dyn TableFunctionImpl> = Arc::new(RangeFunc {}); - FFI_TableFunction::new(udtf, None) + FFI_TableFunction::new_with_ffi_codec(udtf, None, codec) } pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index a5cc37043621d..c5b7b7110e028 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -17,29 +17,21 @@ use std::sync::Arc; -use crate::arrow_wrappers::WrappedSchema; use abi_stable::{ StableAbi, std_types::{RString, RVec}, }; use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; use arrow_schema::FieldRef; -use datafusion::{ - error::DataFusionError, - logical_expr::function::AccumulatorArgs, - physical_expr::{PhysicalExpr, PhysicalSortExpr}, - prelude::SessionContext, +use datafusion_common::error::DataFusionError; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; + +use crate::{ + arrow_wrappers::WrappedSchema, + physical_expr::{FFI_PhysicalExpr, sort::FFI_PhysicalSortExpr}, + util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}, }; -use datafusion_common::ffi_datafusion_err; -use datafusion_proto::{ - physical_plan::{ - DefaultPhysicalExtensionCodec, - from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, - to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs}, - }, - protobuf::PhysicalAggregateExprNode, -}; -use prost::Message; /// A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries.
/// For an explanation of each field, see the corresponding field @@ -50,42 +42,47 @@ use prost::Message; pub struct FFI_AccumulatorArgs { return_field: WrappedSchema, schema: WrappedSchema, + ignore_nulls: bool, + order_bys: RVec<FFI_PhysicalSortExpr>, is_reversed: bool, name: RString, - physical_expr_def: RVec<u8>, + is_distinct: bool, + exprs: RVec<FFI_PhysicalExpr>, + expr_fields: RVec<WrappedSchema>, } impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs { type Error = DataFusionError; - - fn try_from(args: AccumulatorArgs) -> Result<Self, Self::Error> { + fn try_from(args: AccumulatorArgs) -> Result<Self, Self::Error> { let return_field = WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); - let codec = DefaultPhysicalExtensionCodec {}; - let ordering_req = - serialize_physical_sort_exprs(args.order_bys.to_owned(), &codec)?; + let order_bys: RVec<_> = args + .order_bys + .iter() + .map(FFI_PhysicalSortExpr::from) + .collect(); - let expr = serialize_physical_exprs(args.exprs, &codec)?; + let exprs = args + .exprs + .iter() + .map(Arc::clone) + .map(FFI_PhysicalExpr::from) + .collect(); - let physical_expr_def = PhysicalAggregateExprNode { - expr, - ordering_req, - distinct: args.is_distinct, - ignore_nulls: args.ignore_nulls, - fun_definition: None, - aggregate_function: None, - human_display: args.name.to_string(), - }; - let physical_expr_def = physical_expr_def.encode_to_vec().into(); + let expr_fields = vec_fieldref_to_rvec_wrapped(args.expr_fields)?; Ok(Self { return_field, schema, + ignore_nulls: args.ignore_nulls, + order_bys, is_reversed: args.is_reversed, name: args.name.into(), - physical_expr_def, + is_distinct: args.is_distinct, + exprs, + expr_fields, }) } } @@ -110,43 +107,28 @@ impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs { type Error = DataFusionError; fn try_from(value: FFI_AccumulatorArgs) -> Result<Self, Self::Error> { - let proto_def = PhysicalAggregateExprNode::decode( - value.physical_expr_def.as_ref(), - ) - .map_err(|e| { - ffi_datafusion_err!("Failed to decode PhysicalAggregateExprNode: {e}") - })?; - let return_field = Arc::new((&value.return_field.0).try_into()?); let schema = Schema::try_from(&value.schema.0)?; - let default_ctx = SessionContext::new(); - let task_ctx = default_ctx.task_ctx(); - let codex = DefaultPhysicalExtensionCodec {}; - - let order_bys = parse_physical_sort_exprs( - &proto_def.ordering_req, - &task_ctx, - &schema, - &codex, - )?; - - let exprs = parse_physical_exprs(&proto_def.expr, &task_ctx, &schema, &codex)?; + let order_bys = value.order_bys.iter().map(PhysicalSortExpr::from).collect(); - let expr_fields = exprs + let exprs = value + .exprs .iter() - .map(|e| e.return_field(&schema)) - .collect::<Result<Vec<_>, _>>()?; + .map(<Arc<dyn PhysicalExpr>>::from) + .collect(); + + let expr_fields = rvec_wrapped_to_vec_fieldref(&value.expr_fields)?; Ok(Self { return_field, schema, expr_fields, - ignore_nulls: proto_def.ignore_nulls, + ignore_nulls: value.ignore_nulls, order_bys, is_reversed: value.is_reversed, name: value.name.to_string(), - is_distinct: proto_def.distinct, + is_distinct: value.is_distinct, exprs, }) } @@ -170,13 +152,14 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { #[cfg(test)] mod tests { - use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ error::Result, logical_expr::function::AccumulatorArgs, physical_expr::PhysicalSortExpr, physical_plan::expressions::col, }; + use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; + #[test] fn test_round_trip_accumulator_args() -> Result<()> { let schema =
Schema::new(vec![Field::new("a", DataType::Int32, true)]); @@ -193,7 +176,7 @@ mod tests { }; let orig_str = format!("{orig_args:?}"); - let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; + let ffi_args = FFI_AccumulatorArgs::try_from(orig_args)?; let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; let round_trip_args: AccumulatorArgs = (&foreign_args).into(); diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs index ccecee2217dac..bd35839d965ed 100644 --- a/datafusion/ffi/src/udtf.rs +++ b/datafusion/ffi/src/udtf.rs @@ -15,29 +15,29 @@ // specific language governing permissions and limitations // under the License. -use std::{ffi::c_void, sync::Arc}; - -use abi_stable::{ - StableAbi, - std_types::{RResult, RVec}, -}; - -use datafusion::error::Result; -use datafusion::{ - catalog::{TableFunctionImpl, TableProvider}, - prelude::{Expr, SessionContext}, -}; -use datafusion_proto::{ - logical_plan::{ - DefaultLogicalExtensionCodec, from_proto::parse_exprs, to_proto::serialize_exprs, - }, - protobuf::LogicalExprList, +use std::ffi::c_void; +use std::sync::Arc; + +use abi_stable::StableAbi; +use abi_stable::std_types::{RResult, RVec}; +use datafusion_catalog::{TableFunctionImpl, TableProvider}; +use datafusion_common::error::Result; +use datafusion_execution::TaskContext; +use datafusion_expr::Expr; +use datafusion_proto::logical_plan::from_proto::parse_exprs; +use datafusion_proto::logical_plan::to_proto::serialize_exprs; +use datafusion_proto::logical_plan::{ + DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; +use datafusion_proto::protobuf::LogicalExprList; use prost::Message; use tokio::runtime::Handle; +use crate::execution::FFI_TaskContextProvider; +use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use crate::table_provider::FFI_TableProvider; use crate::util::FFIResult; -use crate::{df_result, rresult_return, table_provider::FFI_TableProvider}; +use crate::{df_result, rresult_return}; /// A stable struct for sharing a [`TableFunctionImpl`] across FFI boundaries. #[repr(C)] @@ -49,6 +49,8 @@ pub struct FFI_TableFunction { pub call: unsafe extern "C" fn(udtf: &Self, args: RVec<u8>) -> FFIResult<FFI_TableProvider>, + pub logical_codec: FFI_LogicalExtensionCodec, + /// Used to create a clone of the provider of the udtf. This should /// only need to be called by the receiver of the udtf.
pub clone: unsafe extern "C" fn(udtf: &Self) -> Self, @@ -91,18 +93,27 @@ unsafe extern "C" fn call_fn_wrapper( args: RVec<u8>, ) -> FFIResult<FFI_TableProvider> { let runtime = udtf.runtime(); - let udtf = udtf.inner(); + let udtf_inner = udtf.inner(); - let default_ctx = SessionContext::new(); - let codec = DefaultLogicalExtensionCodec {}; + let ctx: Arc<TaskContext> = + rresult_return!((&udtf.logical_codec.task_ctx_provider).try_into()); + let codec: Arc<dyn LogicalExtensionCodec> = (&udtf.logical_codec).into(); let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref())); - let args = - rresult_return!(parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)); - - let table_provider = rresult_return!(udtf.call(&args)); - RResult::ROk(FFI_TableProvider::new(table_provider, false, runtime)) + let args = rresult_return!(parse_exprs( + proto_filters.expr.iter(), + ctx.as_ref(), + codec.as_ref() + )); + + let table_provider = rresult_return!(udtf_inner.call(&args)); + RResult::ROk(FFI_TableProvider::new_with_ffi_codec( + table_provider, + false, + runtime, + udtf.logical_codec.clone(), + )) } unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) { @@ -117,9 +128,13 @@ unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) { unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction { let runtime = udtf.runtime(); - let udtf = udtf.inner(); + let udtf_inner = udtf.inner(); - FFI_TableFunction::new(Arc::clone(udtf), runtime) + FFI_TableFunction::new_with_ffi_codec( + Arc::clone(udtf_inner), + runtime, + udtf.logical_codec.clone(), + ) } impl Clone for FFI_TableFunction { @@ -129,28 +144,34 @@ impl Clone for FFI_TableFunction { } impl FFI_TableFunction { - pub fn new(udtf: Arc<dyn TableFunctionImpl>, runtime: Option<Handle>) -> Self { - let private_data = Box::new(TableFunctionPrivateData { udtf, runtime }); - - Self { - call: call_fn_wrapper, - clone: clone_fn_wrapper, - release: release_fn_wrapper, - private_data: Box::into_raw(private_data) as *mut c_void, - library_marker_id: crate::get_library_marker_id, - } + pub fn new( + udtf: Arc<dyn TableFunctionImpl>, + runtime: Option<Handle>, + task_ctx_provider: impl Into<FFI_TaskContextProvider>, + logical_codec: Option<Arc<dyn LogicalExtensionCodec>>, + ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); + let logical_codec = + logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {})); + let logical_codec = FFI_LogicalExtensionCodec::new( + logical_codec, + runtime.clone(), + task_ctx_provider.clone(), + ); + + Self::new_with_ffi_codec(udtf, runtime, logical_codec) } -} -impl From<Arc<dyn TableFunctionImpl>> for FFI_TableFunction { - fn from(udtf: Arc<dyn TableFunctionImpl>) -> Self { - let private_data = Box::new(TableFunctionPrivateData { - udtf, - runtime: None, - }); + pub fn new_with_ffi_codec( + udtf: Arc<dyn TableFunctionImpl>, + runtime: Option<Handle>, + logical_codec: FFI_LogicalExtensionCodec, + ) -> Self { + let private_data = Box::new(TableFunctionPrivateData { udtf, runtime }); Self { call: call_fn_wrapper, + logical_codec, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, @@ -189,9 +210,9 @@ impl From<FFI_TableFunction> for Arc<dyn TableFunctionImpl> { impl TableFunctionImpl for ForeignTableFunction { fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> { - let codec = DefaultLogicalExtensionCodec {}; + let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into(); let expr_list = LogicalExprList { - expr: serialize_exprs(args, &codec)?, + expr: serialize_exprs(args, codec.as_ref())?, }; let filters_serialized = expr_list.encode_to_vec().into(); @@ -206,17 +227,18 @@ impl TableFunctionImpl for ForeignTableFunction { #[cfg(test)] mod tests { - use super::*; - use arrow::{ - array::{ -
ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, record_batch, - }, - datatypes::{DataType, Field, Schema}, + use arrow::array::{ + ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, record_batch, }; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::catalog::MemTable; + use datafusion::common::exec_err; use datafusion::logical_expr::ptr_eq::arc_ptr_eq; - use datafusion::{ - catalog::MemTable, common::exec_err, prelude::lit, scalar::ScalarValue, - }; + use datafusion::prelude::{SessionContext, lit}; + use datafusion::scalar::ScalarValue; + use datafusion_execution::TaskContextProvider; + + use super::*; #[derive(Debug)] struct TestUDTF {} @@ -299,16 +321,22 @@ mod tests { #[tokio::test] async fn test_round_trip_udtf() -> Result<()> { let original_udtf = Arc::new(TestUDTF {}) as Arc<dyn TableFunctionImpl>; - - let mut local_udtf: FFI_TableFunction = - FFI_TableFunction::new(Arc::clone(&original_udtf), None); + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>; + let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider); + + let mut local_udtf: FFI_TableFunction = FFI_TableFunction::new( + Arc::clone(&original_udtf), + None, + task_ctx_provider, + None, + ); local_udtf.library_marker_id = crate::mock_foreign_marker_id; let foreign_udf: Arc<dyn TableFunctionImpl> = local_udtf.into(); let table = foreign_udf.call(&[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?; - let ctx = SessionContext::default(); let _ = ctx.register_table("test-table", table)?; let returned_batches = ctx.table("test-table").await?.collect().await?; @@ -335,7 +363,14 @@ mod tests { fn test_ffi_udtf_local_bypass() -> Result<()> { let original_udtf = Arc::new(TestUDTF {}) as Arc<dyn TableFunctionImpl>; - let mut ffi_udtf = FFI_TableFunction::from(Arc::clone(&original_udtf)); + let ctx = Arc::new(SessionContext::default()) as Arc<dyn TaskContextProvider>; + let task_ctx_provider = FFI_TaskContextProvider::from(&ctx); + let mut ffi_udtf = FFI_TableFunction::new( + Arc::clone(&original_udtf), + None, + task_ctx_provider, + None, + ); // Verify local libraries can be downcast to their original let foreign_udtf: Arc<dyn TableFunctionImpl> = ffi_udtf.clone().into(); diff --git a/datafusion/ffi/src/udwf/partition_evaluator_args.rs b/datafusion/ffi/src/udwf/partition_evaluator_args.rs index e272b1416606e..20238fdb8b7f0 100644 --- a/datafusion/ffi/src/udwf/partition_evaluator_args.rs +++ b/datafusion/ffi/src/udwf/partition_evaluator_args.rs @@ -15,31 +15,19 @@ // specific language governing permissions and limitations // under the License.
-use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; -use crate::arrow_wrappers::WrappedSchema; use abi_stable::{StableAbi, std_types::RVec}; -use arrow::{ - datatypes::{DataType, Field, Schema, SchemaRef}, - error::ArrowError, - ffi::FFI_ArrowSchema, -}; +use arrow::{error::ArrowError, ffi::FFI_ArrowSchema}; use arrow_schema::FieldRef; -use datafusion::{ - error::{DataFusionError, Result}, - logical_expr::function::PartitionEvaluatorArgs, - physical_plan::{PhysicalExpr, expressions::Column}, - prelude::SessionContext, -}; -use datafusion_common::ffi_datafusion_err; -use datafusion_proto::{ - physical_plan::{ - DefaultPhysicalExtensionCodec, from_proto::parse_physical_expr, - to_proto::serialize_physical_exprs, - }, - protobuf::PhysicalExprNode, +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::function::PartitionEvaluatorArgs; +use datafusion_physical_plan::PhysicalExpr; + +use crate::{ + arrow_wrappers::WrappedSchema, physical_expr::FFI_PhysicalExpr, + util::rvec_wrapped_to_vec_fieldref, }; -use prost::Message; /// A stable struct for sharing [`PartitionEvaluatorArgs`] across FFI boundaries. /// For an explanation of each field, see the corresponding function @@ -48,58 +36,21 @@ use prost::Message; #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] pub struct FFI_PartitionEvaluatorArgs { - input_exprs: RVec<RVec<u8>>, + input_exprs: RVec<FFI_PhysicalExpr>, input_fields: RVec<WrappedSchema>, is_reversed: bool, ignore_nulls: bool, - schema: WrappedSchema, } impl TryFrom<PartitionEvaluatorArgs<'_>> for FFI_PartitionEvaluatorArgs { type Error = DataFusionError; + fn try_from(args: PartitionEvaluatorArgs) -> Result<Self> { - // This is a bit of a hack. Since PartitionEvaluatorArgs does not carry a schema - // around, and instead passes the data types directly we are unable to decode the - // protobuf PhysicalExpr correctly. In evaluating the code the only place these - // appear to be really used are the Column data types. So here we will find all - // of the required columns and create a schema that has empty fields except for - // the ones we require. Ideally we would enhance PartitionEvaluatorArgs to just - // pass along the schema, but that is a larger breaking change. - let required_columns: HashMap<usize, (&str, &DataType)> = args + let input_exprs = args .input_exprs() .iter() - .zip(args.input_fields()) - .filter_map(|(expr, field)| { - expr.as_any() - .downcast_ref::<Column>() - .map(|column| (column.index(), (column.name(), field.data_type()))) - }) - .collect(); - - let max_column = required_columns.keys().max(); - let fields: Vec<_> = max_column - .map(|max_column| { - (0..(max_column + 1)) - .map(|idx| match required_columns.get(&idx) { - Some((name, data_type)) => { - Field::new(*name, (*data_type).clone(), true) - } - None => Field::new( - format!("ffi_partition_evaluator_col_{idx}"), - DataType::Null, - true, - ), - }) - .collect() - }) - .unwrap_or_default(); - - let schema = Arc::new(Schema::new(fields)); - - let codec = DefaultPhysicalExtensionCodec {}; - let input_exprs = serialize_physical_exprs(args.input_exprs(), &codec)? - .into_iter() - .map(|expr_node| expr_node.encode_to_vec().into()) + .map(Arc::clone) + .map(FFI_PhysicalExpr::from) .collect(); let input_fields = args @@ -109,12 +60,9 @@ impl TryFrom<PartitionEvaluatorArgs<'_>> for FFI_PartitionEvaluatorArgs { .collect::<Result<Vec<_>, ArrowError>>()?
.into(); - let schema: WrappedSchema = schema.into(); - Ok(Self { input_exprs, input_fields, - schema, is_reversed: args.is_reversed(), ignore_nulls: args.ignore_nulls(), }) @@ -136,27 +84,9 @@ impl TryFrom<FFI_PartitionEvaluatorArgs> for ForeignPartitionEvaluatorArgs { type Error = DataFusionError; fn try_from(value: FFI_PartitionEvaluatorArgs) -> Result<Self> { - let default_ctx = SessionContext::new(); - let codec = DefaultPhysicalExtensionCodec {}; - - let schema: SchemaRef = value.schema.into(); + let input_exprs = value.input_exprs.iter().map(Into::into).collect(); - let input_exprs = value - .input_exprs - .into_iter() - .map(|input_expr_bytes| PhysicalExprNode::decode(input_expr_bytes.as_ref())) - .collect::<Result<Vec<_>, prost::DecodeError>>() - .map_err(|e| ffi_datafusion_err!("Failed to decode PhysicalExprNode: {e}"))? - .iter() - .map(|expr_node| { - parse_physical_expr(expr_node, &default_ctx.task_ctx(), &schema, &codec) - }) - .collect::<Result<Vec<_>>>()?; - - let input_fields = input_exprs - .iter() - .map(|expr| expr.return_field(&schema)) - .collect::<Result<Vec<_>>>()?; + let input_fields = rvec_wrapped_to_vec_fieldref(&value.input_fields)?; Ok(Self { input_exprs, diff --git a/datafusion/ffi/tests/ffi_catalog.rs b/datafusion/ffi/tests/ffi_catalog.rs index c45a62ee5093b..a464d3df5a0bf 100644 --- a/datafusion/ffi/tests/ffi_catalog.rs +++ b/datafusion/ffi/tests/ffi_catalog.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +mod utils; + /// Add an additional module here for convenience to scope this to only /// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { use datafusion::catalog::{CatalogProvider, CatalogProviderList}; - use datafusion::prelude::SessionContext; use datafusion_common::DataFusionError; use datafusion_ffi::tests::utils::get_module; use std::sync::Arc; @@ -28,6 +29,7 @@ mod tests { #[tokio::test] async fn test_catalog() -> datafusion_common::Result<()> { let module = get_module()?; + let (ctx, codec) = super::utils::ctx_and_codec(); let ffi_catalog = module @@ -35,10 +37,9 @@ mod tests { .create_catalog() .ok_or(DataFusionError::NotImplemented( "External catalog provider failed to implement create_catalog" .to_string(), - ))?(); + ))?(codec); let foreign_catalog: Arc<dyn CatalogProvider> = (&ffi_catalog).into(); - let ctx = SessionContext::default(); let _ = ctx.register_catalog("fruit", foreign_catalog); let df = ctx.table("fruit.apple.purchases").await?; @@ -55,6 +56,7 @@ mod tests { #[tokio::test] async fn test_catalog_list() -> datafusion_common::Result<()> { let module = get_module()?; + let (ctx, codec) = super::utils::ctx_and_codec(); let ffi_catalog_list = module @@ -62,11 +64,10 @@ mod tests { .create_catalog_list() .ok_or(DataFusionError::NotImplemented( "External catalog provider failed to implement create_catalog_list" .to_string(), - ))?(); + ))?(codec); let foreign_catalog_list: Arc<dyn CatalogProviderList> = (&ffi_catalog_list).into(); - let ctx = SessionContext::default(); ctx.register_catalog_list(foreign_catalog_list); let df = ctx.table("blue.apple.purchases").await?; diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index 216d6576a8216..78650abf2f264 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License.
+mod utils; + /// Add an additional module here for convenience to scope this to only /// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { use datafusion::catalog::TableProvider; use datafusion::error::{DataFusionError, Result}; - use datafusion::prelude::SessionContext; use datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; use std::sync::Arc; @@ -31,6 +32,7 @@ mod tests { /// testing it via a different executable. async fn test_table_provider(synchronous: bool) -> Result<()> { let table_provider_module = get_module()?; + let (ctx, codec) = super::utils::ctx_and_codec(); // By calling the code below, the table provider will be created within // the module's code. @@ -38,14 +40,12 @@ mod tests { DataFusionError::NotImplemented( "External table provider failed to implement create_table".to_string(), ), )?(synchronous, codec); // In order to access the table provider within this executable, we need to // turn it into a `TableProvider`. let foreign_table_provider: Arc<dyn TableProvider> = (&ffi_table_provider).into(); - let ctx = SessionContext::new(); - // Display the data to show the full cycle works. ctx.register_table("external_table", foreign_table_provider)?; let df = ctx.table("external_table").await?; diff --git a/datafusion/ffi/tests/ffi_udaf.rs b/datafusion/ffi/tests/ffi_udaf.rs index 6aac88d494d4a..f219979a85062 100644 --- a/datafusion/ffi/tests/ffi_udaf.rs +++ b/datafusion/ffi/tests/ffi_udaf.rs @@ -19,13 +19,15 @@ /// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { + use std::sync::Arc; + use arrow::array::Float64Array; use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{AggregateUDF, AggregateUDFImpl}; use datafusion::prelude::{SessionContext, col}; - use std::sync::Arc; - + use datafusion_catalog::MemTable; + use datafusion_expr::{ScalarUDF, ScalarUDFImpl}; use datafusion_ffi::tests::utils::get_module; #[tokio::test] @@ -126,4 +128,69 @@ mod tests { Ok(()) } + + /// This tests that FFI UDFs can be used as inputs to FFI Aggregate UDFs. + /// Really this is a test of the Protobuf serialization and deserialization + /// using the TaskContextProvider, demonstrated end-to-end through the + /// UDAF accumulator arguments.
+ #[tokio::test] + async fn udf_as_input_to_udf() -> Result<()> { + let module = get_module()?; + + let ffi_abs_func = + module + .create_scalar_udf() + .ok_or(DataFusionError::NotImplemented( + "External module failed to implement create_scalar_udf" + .to_string(), + ))?(); + let foreign_abs_func: Arc<dyn ScalarUDFImpl> = (&ffi_abs_func).into(); + let abs_udf = ScalarUDF::new_from_shared_impl(foreign_abs_func); + + let ctx = SessionContext::new(); + ctx.deregister_udf("abs"); + + let ffi_sum_func = + module + .create_sum_udaf() + .ok_or(DataFusionError::NotImplemented( + "External module failed to implement create_sum_udaf".to_string(), + ))?(); + let foreign_sum_func: Arc<dyn AggregateUDFImpl> = (&ffi_sum_func).into(); + + let udaf = AggregateUDF::new_from_shared_impl(foreign_sum_func); + + // We need at least 2 record batches so we get an accumulator + let rb1 = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ("b", Float64, vec![-1.0, 2.0, -2.0, 4.0, -4.0, -4.0, -4.0]) + ) + .unwrap(); + let rb2 = rb1.clone(); + + let table = Arc::new(MemTable::try_new(rb1.schema(), vec![vec![rb1, rb2]])?); + + let df = ctx.read_table(table)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![abs_udf.call(vec![col("b")])]).alias("sum_b")], + )? + .sort_by(vec![col("a")])?; + + df.clone().show().await?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![1, 2, 4]), + ("sum_b", Float64, vec![2.0, 8.0, 32.0]) + )?; + + assert_eq!(result[0], expected); + + Ok(()) + } } diff --git a/datafusion/ffi/tests/ffi_udtf.rs b/datafusion/ffi/tests/ffi_udtf.rs index a0135a903f1ea..097ed75f35361 100644 --- a/datafusion/ffi/tests/ffi_udtf.rs +++ b/datafusion/ffi/tests/ffi_udtf.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +mod utils; + /// Add an additional module here for convenience to scope this to only /// when the feature integration-tests is built #[cfg(feature = "integration-tests")] @@ -25,7 +27,6 @@ mod tests { use arrow::array::{ArrayRef, create_array}; use datafusion::catalog::TableFunctionImpl; use datafusion::error::{DataFusionError, Result}; - use datafusion::prelude::SessionContext; use datafusion_ffi::tests::utils::get_module; @@ -35,16 +36,16 @@ mod tests { #[tokio::test] async fn test_user_defined_table_function() -> Result<()> { let module = get_module()?; + let (ctx, codec) = super::utils::ctx_and_codec(); let ffi_table_func = module .create_table_function() .ok_or(DataFusionError::NotImplemented( "External table function provider failed to implement create_table_function" .to_string(), ))?(codec); let foreign_table_func: Arc<dyn TableFunctionImpl> = ffi_table_func.into(); - let ctx = SessionContext::default(); ctx.register_udtf("my_range", foreign_table_func); let result = ctx diff --git a/datafusion/ffi/tests/utils/mod.rs b/datafusion/ffi/tests/utils/mod.rs new file mode 100644 index 0000000000000..25d1464811f3f --- /dev/null +++ b/datafusion/ffi/tests/utils/mod.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::prelude::SessionContext; +use datafusion_execution::TaskContextProvider; +use datafusion_ffi::execution::FFI_TaskContextProvider; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; +use std::sync::Arc; + +pub fn ctx_and_codec() -> (Arc<SessionContext>, FFI_LogicalExtensionCodec) { + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>; + let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider); + let codec = FFI_LogicalExtensionCodec::new( + Arc::new(DefaultLogicalExtensionCodec {}), + None, + task_ctx_provider, + ); + + (ctx, codec) +} diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index c12768c95aaf3..39d52bd5903a4 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -396,6 +396,47 @@ Instead this should now be: let foreign_udf = ScalarUDF::new_from_shared_impl(foreign_udf); ``` +When creating any of the following structs, we now require the user to +provide a `TaskContextProvider` and optionally a `LogicalExtensionCodec`: + +- `FFI_CatalogProviderList` +- `FFI_CatalogProvider` +- `FFI_SchemaProvider` +- `FFI_TableProvider` +- `FFI_TableFunction` + +Each of these structs has a `new()` and a `new_with_ffi_codec()` method for +instantiation. For example, where you previously would write + +```rust,ignore + let table = Arc::new(MyTableProvider::new()); + let ffi_table = FFI_TableProvider::new(table, true, None); +``` + +Now you will also need to provide a `TaskContextProvider`. The most common +implementation of this trait is `SessionContext`. + +```rust,ignore + let ctx = Arc::new(SessionContext::default()); + let table = Arc::new(MyTableProvider::new()); + let ffi_table = FFI_TableProvider::new(table, true, None, ctx, None); +``` + +The alternative constructor, `new_with_ffi_codec()`, may be more convenient +if you are creating many of these structures, since a +`FFI_LogicalExtensionCodec` stores the `TaskContextProvider` as well. + +```rust,ignore + let codec = Arc::new(DefaultLogicalExtensionCodec {}); + let ctx = Arc::new(SessionContext::default()); + let ffi_codec = FFI_LogicalExtensionCodec::new(codec, None, ctx); + let table = Arc::new(MyTableProvider::new()); + let ffi_table = FFI_TableProvider::new_with_ffi_codec(table, true, None, ffi_codec); +``` + +Additional information about the usage of the `TaskContextProvider` can be +found in the crate README. + Additionally, the FFI structure for Scalar UDFs no longer contains a `return_type` call. That code was unused since the `ForeignScalarUDF` struct implements `return_field_from_args` instead.
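
The same pattern applies to `FFI_TableFunction`, whose `new()` likewise takes the
`TaskContextProvider` and an optional `LogicalExtensionCodec` after the runtime
handle. A minimal sketch, where `MyTableFunction` is a hypothetical stand-in for
your own `TableFunctionImpl`:

```rust,ignore
    let ctx = Arc::new(SessionContext::default());
    // `MyTableFunction` is a placeholder for your `TableFunctionImpl` type.
    let udtf = Arc::new(MyTableFunction::new());
    // Arguments: the table function, an optional tokio runtime handle,
    // the `TaskContextProvider`, and an optional `LogicalExtensionCodec`.
    let ffi_udtf = FFI_TableFunction::new(udtf, None, ctx, None);
```

As with the providers, passing `None` for the codec falls back to the
`DefaultLogicalExtensionCodec`.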