diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 617003dae00c..f5a323361865 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -42,7 +42,10 @@ datafusion = { path = "../../../datafusion/core", version = "7.0.0" } datafusion-proto = { path = "../../../datafusion/proto", version = "7.0.0" } futures = "0.3" hashbrown = "0.12" + +libloading = "0.7.3" log = "0.4" +once_cell = "1.9.0" parking_lot = "0.12" parse_arg = "0.1.3" @@ -53,9 +56,11 @@ sqlparser = "0.15" tokio = "1.0" tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } +walkdir = "2.3.2" [dev-dependencies] tempfile = "3" [build-dependencies] +rustc_version = "0.4.0" tonic-build = { version = "0.6" } diff --git a/ballista/rust/core/build.rs b/ballista/rust/core/build.rs index b5110f8f574b..c2acde108a2b 100644 --- a/ballista/rust/core/build.rs +++ b/ballista/rust/core/build.rs @@ -20,6 +20,8 @@ fn main() -> Result<(), String> { println!("cargo:rerun-if-env-changed=FORCE_REBUILD"); println!("cargo:rerun-if-changed=proto/ballista.proto"); + let version = rustc_version::version().unwrap(); + println!("cargo:rustc-env=RUSTC_VERSION={}", version); println!("cargo:rerun-if-changed=proto/datafusion.proto"); tonic_build::configure() .extern_path(".datafusion", "::datafusion_proto::protobuf") diff --git a/ballista/rust/core/src/config.rs b/ballista/rust/core/src/config.rs index fffe0ead3d75..1ff02c16d069 100644 --- a/ballista/rust/core/src/config.rs +++ b/ballista/rust/core/src/config.rs @@ -34,6 +34,8 @@ pub const BALLISTA_REPARTITION_AGGREGATIONS: &str = "ballista.repartition.aggreg pub const BALLISTA_REPARTITION_WINDOWS: &str = "ballista.repartition.windows"; pub const BALLISTA_PARQUET_PRUNING: &str = "ballista.parquet.pruning"; pub const BALLISTA_WITH_INFORMATION_SCHEMA: &str = "ballista.with_information_schema"; +/// give a plugin files dir, and then the dynamic library files in this dir will be load when scheduler state init. 
+pub const BALLISTA_PLUGIN_DIR: &str = "ballista.plugin_dir"; pub type ParseResult = result::Result; @@ -139,6 +141,9 @@ impl BallistaConfig { .parse::() .map_err(|e| format!("{:?}", e))?; } + DataType::Utf8 => { + val.to_string(); + } _ => { return Err(format!("not support data type: {}", data_type)); } @@ -171,6 +176,9 @@ impl BallistaConfig { ConfigEntry::new(BALLISTA_WITH_INFORMATION_SCHEMA.to_string(), "Sets whether enable information_schema".to_string(), DataType::Boolean,Some("false".to_string())), + ConfigEntry::new(BALLISTA_PLUGIN_DIR.to_string(), + "Sets the plugin dir".to_string(), + DataType::Utf8,Some("".to_string())), ]; entries .iter() @@ -186,6 +194,10 @@ impl BallistaConfig { self.get_usize_setting(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS) } + pub fn default_plugin_dir(&self) -> String { + self.get_string_setting(BALLISTA_PLUGIN_DIR) + } + pub fn default_batch_size(&self) -> usize { self.get_usize_setting(BALLISTA_DEFAULT_BATCH_SIZE) } @@ -233,6 +245,17 @@ impl BallistaConfig { v.parse::().unwrap() } } + fn get_string_setting(&self, key: &str) -> String { + if let Some(v) = self.settings.get(key) { + // infallible because we validate all configs in the constructor + v.to_string() + } else { + let entries = Self::valid_entries(); + // infallible because we validate all configs in the constructor + let v = entries.get(key).unwrap().default_value.as_ref().unwrap(); + v.to_string() + } + } } // an enum used to configure the scheduler policy @@ -266,6 +289,7 @@ mod tests { let config = BallistaConfig::new()?; assert_eq!(2, config.default_shuffle_partitions()); assert!(!config.default_with_information_schema()); + assert_eq!("", config.default_plugin_dir().as_str()); Ok(()) } @@ -284,6 +308,7 @@ mod tests { fn custom_config_invalid() -> Result<()> { let config = BallistaConfig::builder() .set(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, "true") + .set(BALLISTA_PLUGIN_DIR, "test_dir") .build(); assert!(config.is_err()); assert_eq!("General(\"Failed to parse 
user-supplied value 'ballista.shuffle.partitions' for configuration setting 'true': ParseIntError { kind: InvalidDigit }\")", format!("{:?}", config.unwrap_err())); @@ -293,7 +318,6 @@ mod tests { .build(); assert!(config.is_err()); assert_eq!("General(\"Failed to parse user-supplied value 'ballista.with_information_schema' for configuration setting '123': ParseBoolError\")", format!("{:?}", config.unwrap_err())); - Ok(()) } } diff --git a/ballista/rust/core/src/lib.rs b/ballista/rust/core/src/lib.rs index c452a45b1087..34f4699e115a 100644 --- a/ballista/rust/core/src/lib.rs +++ b/ballista/rust/core/src/lib.rs @@ -27,6 +27,8 @@ pub mod config; pub mod error; pub mod event_loop; pub mod execution_plans; +/// some plugins +pub mod plugin; pub mod utils; #[macro_use] diff --git a/ballista/rust/core/src/plugin/mod.rs b/ballista/rust/core/src/plugin/mod.rs new file mode 100644 index 000000000000..a7012af479bc --- /dev/null +++ b/ballista/rust/core/src/plugin/mod.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::error::Result; +use crate::plugin::udf::UDFPluginManager; +use libloading::Library; +use std::any::Any; +use std::env; +use std::sync::Arc; + +/// plugin manager +pub mod plugin_manager; +/// udf plugin +pub mod udf; + +/// CARGO_PKG_VERSION +pub static CORE_VERSION: &str = env!("CARGO_PKG_VERSION"); +/// RUSTC_VERSION +pub static RUSTC_VERSION: &str = env!("RUSTC_VERSION"); + +/// Top plugin trait +pub trait Plugin { + /// Returns the plugin as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; +} + +/// The enum of Plugin +#[derive(PartialEq, std::cmp::Eq, std::hash::Hash, Copy, Clone)] +pub enum PluginEnum { + /// UDF/UDAF plugin + UDF, +} + +impl PluginEnum { + /// new a struct which impl the PluginRegistrar trait + pub fn init_plugin_manager(&self) -> Box { + match self { + PluginEnum::UDF => Box::new(UDFPluginManager::default()), + } + } +} + +/// Every plugin need a PluginDeclaration +#[derive(Copy, Clone)] +pub struct PluginDeclaration { + /// Rust doesn’t have a stable ABI, meaning different compiler versions can generate incompatible code. + /// For these reasons, the UDF plug-in must be compiled using the same version of rustc as datafusion. + pub rustc_version: &'static str, + + /// core version of the plugin. The plugin's core_version need same as plugin manager. + pub core_version: &'static str, + + /// One of PluginEnum + pub plugin_type: unsafe extern "C" fn() -> PluginEnum, +} + +/// Plugin Registrar , Every plugin need implement this trait +pub trait PluginRegistrar: Send + Sync + 'static { + /// # Safety + /// load plugin from library + unsafe fn load(&mut self, library: Arc) -> Result<()>; + + /// Returns the plugin as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; +} + +/// Declare a plugin's PluginDeclaration. 
+/// +/// # Notes +/// +/// This works by automatically generating an `extern "C"` function named `get_plugin_type` with a +/// pre-defined signature and symbol name. And then generating a PluginDeclaration. +/// Therefore you will only be able to declare one plugin per library. +#[macro_export] +macro_rules! declare_plugin { + ($plugin_type:expr) => { + #[no_mangle] + pub extern "C" fn get_plugin_type() -> $crate::plugin::PluginEnum { + $plugin_type + } + + #[no_mangle] + pub static plugin_declaration: $crate::plugin::PluginDeclaration = + $crate::plugin::PluginDeclaration { + rustc_version: $crate::plugin::RUSTC_VERSION, + core_version: $crate::plugin::CORE_VERSION, + plugin_type: get_plugin_type, + }; + }; +} + +/// get the plugin dir +pub fn plugin_dir() -> String { + let current_exe_dir = match env::current_exe() { + Ok(exe_path) => exe_path.display().to_string(), + Err(_e) => "".to_string(), + }; + + // If current_exe_dir contain `deps` the root dir is the parent dir + // eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/deps/plugins_app-067452b3ff2af70e + // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/deps + // else eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/plugins_app + // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/ + if current_exe_dir.contains("/deps/") { + let i = current_exe_dir.find("/deps/").unwrap(); + String::from(&current_exe_dir.as_str()[..i + 6]) + } else { + let i = current_exe_dir.rfind('/').unwrap(); + String::from(&current_exe_dir.as_str()[..i]) + } +} diff --git a/ballista/rust/core/src/plugin/plugin_manager.rs b/ballista/rust/core/src/plugin/plugin_manager.rs new file mode 100644 index 000000000000..e238383b4620 --- /dev/null +++ b/ballista/rust/core/src/plugin/plugin_manager.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +use crate::error::{BallistaError, Result}; +use libloading::Library; +use log::info; +use std::collections::HashMap; +use std::io; +use std::sync::{Arc, Mutex}; +use walkdir::{DirEntry, WalkDir}; + +use crate::plugin::{ + PluginDeclaration, PluginEnum, PluginRegistrar, CORE_VERSION, RUSTC_VERSION, +}; +use once_cell::sync::OnceCell; + +/// To prevent the library from being loaded multiple times, we use once_cell defines a Arc> +/// Because datafusion is a library, not a service, users may not need to load all plug-ins in the process. +/// So fn global_plugin_manager return Arc>. 
In this way, users can load the required library through the load method of GlobalPluginManager when needed +static INSTANCE: OnceCell>> = OnceCell::new(); + +/// global_plugin_manager +pub fn global_plugin_manager( + plugin_path: &str, +) -> &'static Arc> { + INSTANCE.get_or_init(move || unsafe { + let mut gpm = GlobalPluginManager::default(); + gpm.load(plugin_path).unwrap(); + Arc::new(Mutex::new(gpm)) + }) +} + +#[derive(Default)] +/// manager all plugin_type's plugin_manager +pub struct GlobalPluginManager { + /// every plugin need a plugin registrar + pub plugin_managers: HashMap>, + + /// loaded plugin files + pub plugin_files: Vec, +} + +impl GlobalPluginManager { + /// # Safety + /// find plugin file from `plugin_path` and load it . + unsafe fn load(&mut self, plugin_path: &str) -> Result<()> { + if "".eq(plugin_path) { + return Ok(()); + } + // find library file from udaf_plugin_path + info!("load plugin from dir:{}", plugin_path); + + let plugin_files = self.get_all_plugin_files(plugin_path)?; + + for plugin_file in plugin_files { + let library = Library::new(plugin_file.path()).map_err(|e| { + BallistaError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("load library error: {}", e), + )) + })?; + + let library = Arc::new(library); + + let dec = library.get::<*mut PluginDeclaration>(b"plugin_declaration\0"); + if dec.is_err() { + info!( + "not found plugin_declaration in the library: {}", + plugin_file.path().to_str().unwrap() + ); + continue; + } + + let dec = dec.unwrap().read(); + + // Version checks to prevent accidental ABI incompatibilities + if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION { + return Err(BallistaError::IoError(io::Error::new( + io::ErrorKind::Other, + "Version mismatch", + ))); + } + + let plugin_enum = (dec.plugin_type)(); + let curr_plugin_manager = match self.plugin_managers.get_mut(&plugin_enum) { + None => { + let plugin_manager = plugin_enum.init_plugin_manager(); + 
self.plugin_managers.insert(plugin_enum, plugin_manager); + self.plugin_managers.get_mut(&plugin_enum).unwrap() + } + Some(manager) => manager, + }; + curr_plugin_manager.load(library)?; + self.plugin_files + .push(plugin_file.path().to_str().unwrap().to_string()); + } + + Ok(()) + } + + /// get all plugin file in the dir + fn get_all_plugin_files(&self, plugin_path: &str) -> io::Result> { + let mut plugin_files = Vec::new(); + for entry in WalkDir::new(plugin_path).into_iter().filter_map(|e| { + let item = e.unwrap(); + // every file only load once + if self + .plugin_files + .contains(&item.path().to_str().unwrap().to_string()) + { + return None; + } + + let file_type = item.file_type(); + if !file_type.is_file() { + return None; + } + + if let Some(path) = item.path().extension() { + if let Some(suffix) = path.to_str() { + if suffix == "dylib" || suffix == "so" || suffix == "dll" { + info!( + "load plugin from library file:{}", + item.path().to_str().unwrap() + ); + return Some(item); + } + } + } + + None + }) { + plugin_files.push(entry); + } + Ok(plugin_files) + } +} diff --git a/ballista/rust/core/src/plugin/udf.rs b/ballista/rust/core/src/plugin/udf.rs new file mode 100644 index 000000000000..ea82742fb868 --- /dev/null +++ b/ballista/rust/core/src/plugin/udf.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +use crate::error::{BallistaError, Result}; +use crate::plugin::plugin_manager::global_plugin_manager; +use crate::plugin::{Plugin, PluginEnum, PluginRegistrar}; +use datafusion::physical_plan::udaf::AggregateUDF; +use datafusion::physical_plan::udf::ScalarUDF; +use libloading::{Library, Symbol}; +use std::any::Any; +use std::collections::HashMap; +use std::io; +use std::sync::Arc; + +/// UDF plugin trait +pub trait UDFPlugin: Plugin { + /// get a ScalarUDF by name + fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result; + + /// return all udf names in the plugin + fn udf_names(&self) -> Result>; + + /// get a aggregate udf by name + fn get_aggregate_udf_by_name(&self, fun_name: &str) -> Result; + + /// return all udaf names + fn udaf_names(&self) -> Result>; +} + +/// UDFPluginManager +#[derive(Default, Clone)] +pub struct UDFPluginManager { + /// scalar udfs + pub scalar_udfs: HashMap>, + + /// aggregate udfs + pub aggregate_udfs: HashMap>, + + /// All libraries load from the plugin dir. 
+ pub libraries: Vec>, +} + +impl PluginRegistrar for UDFPluginManager { + unsafe fn load(&mut self, library: Arc) -> Result<()> { + type PluginRegister = unsafe fn() -> Box; + let register_fun: Symbol = + library.get(b"registrar_udf_plugin\0").map_err(|e| { + BallistaError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("not found fn registrar_udf_plugin in the library: {}", e), + )) + })?; + + let udf_plugin: Box = register_fun(); + udf_plugin + .udf_names() + .unwrap() + .iter() + .try_for_each(|udf_name| { + if self.scalar_udfs.contains_key(udf_name) { + Err(BallistaError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("udf name: {} already exists", udf_name), + ))) + } else { + let scalar_udf = udf_plugin.get_scalar_udf_by_name(udf_name)?; + self.scalar_udfs + .insert(udf_name.to_string(), Arc::new(scalar_udf)); + Ok(()) + } + })?; + + udf_plugin + .udaf_names() + .unwrap() + .iter() + .try_for_each(|udaf_name| { + if self.aggregate_udfs.contains_key(udaf_name) { + Err(BallistaError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("udaf name: {} already exists", udaf_name), + ))) + } else { + let aggregate_udf = + udf_plugin.get_aggregate_udf_by_name(udaf_name)?; + self.aggregate_udfs + .insert(udaf_name.to_string(), Arc::new(aggregate_udf)); + Ok(()) + } + })?; + self.libraries.push(library); + Ok(()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +/// Declare a udf plugin registrar callback +/// +/// # Notes +/// +/// This works by automatically generating an `extern "C"` function named `registrar_udf_plugin` with a +/// pre-defined signature and symbol name. +/// Therefore you will only be able to declare one plugin per library. +#[macro_export] +macro_rules! declare_udf_plugin { + ($curr_plugin_type:ty, $constructor:path) => { + #[no_mangle] + pub extern "C" fn registrar_udf_plugin() -> Box { + // make sure the constructor is the correct type. 
+ let constructor: fn() -> $curr_plugin_type = $constructor; + let object = constructor(); + Box::new(object) + } + + $crate::declare_plugin!($crate::plugin::PluginEnum::UDF); + }; +} + +/// get a Option of Immutable UDFPluginManager +pub fn get_udf_plugin_manager(path: &str) -> Option { + let udf_plugin_manager_opt = { + let gpm = global_plugin_manager(path).lock().unwrap(); + let plugin_registrar_opt = gpm.plugin_managers.get(&PluginEnum::UDF); + if let Some(plugin_registrar) = plugin_registrar_opt { + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + return Some(udf_plugin_manager.clone()); + } else { + return None; + } + } + None + }; + udf_plugin_manager_opt +} diff --git a/ballista/rust/executor/executor_config_spec.toml b/ballista/rust/executor/executor_config_spec.toml index 167ec20d2e4a..86e712bda284 100644 --- a/ballista/rust/executor/executor_config_spec.toml +++ b/ballista/rust/executor/executor_config_spec.toml @@ -96,3 +96,9 @@ name = "executor_cleanup_ttl" type = "u64" doc = "The number of seconds to retain job directories on each worker 604800 (7 days, 7 * 24 * 3600), In other words, after job done, how long the resulting data is retained" default = "604800" + +[[param]] +name = "plugin_dir" +type = "String" +doc = "plugin dir" +default = "std::string::String::from(\"\")" \ No newline at end of file diff --git a/ballista/rust/scheduler/scheduler_config_spec.toml b/ballista/rust/scheduler/scheduler_config_spec.toml index 000d74e7d32d..cca96edfa954 100644 --- a/ballista/rust/scheduler/scheduler_config_spec.toml +++ b/ballista/rust/scheduler/scheduler_config_spec.toml @@ -64,4 +64,10 @@ abbr = "s" name = "scheduler_policy" type = "ballista_core::config::TaskSchedulingPolicy" doc = "The scheduing policy for the scheduler, see TaskSchedulingPolicy::variants() for options. 
Default: PullStaged" -default = "ballista_core::config::TaskSchedulingPolicy::PullStaged" \ No newline at end of file +default = "ballista_core::config::TaskSchedulingPolicy::PullStaged" + +[[param]] +name = "plugin_dir" +type = "String" +doc = "plugin dir" +default = "std::string::String::from(\"\")" \ No newline at end of file