diff --git a/Cargo.lock b/Cargo.lock index 5a124fedd65..6910acc7165 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8814,6 +8814,7 @@ dependencies = [ "arrow-array", "arrow-buffer", "async-compat", + "async-fs", "bindgen", "bitvec", "cbindgen", @@ -8831,8 +8832,6 @@ dependencies = [ "reqwest", "rstest", "tempfile", - "tokio", - "tokio-stream", "url", "vortex", "vortex-utils", @@ -9066,6 +9065,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-compat", + "async-fs", "async-stream", "async-trait", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 3411739247c..10e55b861de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ arrow-schema = "56" arrow-select = "56" arrow-string = "56" async-compat = "0.2.5" +async-fs = "2.2.0" async-stream = "0.3.6" async-trait = "0.1.89" bindgen = "0.72.0" diff --git a/vortex-duckdb/Cargo.toml b/vortex-duckdb/Cargo.toml index 112cff32ec6..25cfc5965bc 100644 --- a/vortex-duckdb/Cargo.toml +++ b/vortex-duckdb/Cargo.toml @@ -24,6 +24,7 @@ anyhow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } async-compat = { workspace = true } +async-fs = { workspace = true } bitvec = { workspace = true } futures = { workspace = true } glob = { workspace = true } @@ -34,8 +35,6 @@ object_store = { workspace = true, features = ["aws"] } parking_lot = { workspace = true } paste = { workspace = true } tempfile = { workspace = true } -tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } -tokio-stream = { workspace = true } url = { workspace = true } vortex = { workspace = true, features = ["files", "tokio", "object_store"] } vortex-utils = { workspace = true, features = ["dashmap"] } diff --git a/vortex-duckdb/src/copy.rs b/vortex-duckdb/src/copy.rs index 0083688225a..fee0f1e35af 100644 --- a/vortex-duckdb/src/copy.rs +++ b/vortex-duckdb/src/copy.rs @@ -3,21 +3,21 @@ use std::fmt::Debug; use std::iter; -use std::sync::LazyLock; - -use tokio::fs::File; -use tokio::runtime::{self, Runtime}; -use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; -use tokio::task::JoinHandle; -use tokio_stream::wrappers::ReceiverStream; + +use futures::channel::mpsc; +use futures::channel::mpsc::Sender; +use futures::{SinkExt, TryStreamExt}; +use parking_lot::Mutex; use vortex::ArrayRef; use vortex::dtype::Nullability::{NonNullable, Nullable}; use vortex::dtype::{DType, StructFields}; use vortex::error::{VortexExpect, VortexResult, vortex_err}; use vortex::file::{VortexWriteOptions, WriteSummary}; +use vortex::io::runtime::current::CurrentThreadWorkerPool; +use vortex::io::runtime::{BlockingRuntime, Task}; use vortex::stream::ArrayStreamAdapter; +use crate::RUNTIME; use crate::convert::{data_chunk_to_arrow, from_duckdb_table}; use crate::duckdb::{CopyFunction, DataChunk, LogicalType}; @@ -29,21 +29,19 @@ pub struct BindData { fields: StructFields, } -static COPY_RUNTIME: LazyLock = LazyLock::new(|| { - runtime::Builder::new_current_thread() - .enable_all() - .build() - .vortex_expect("Cannot start runtime") -}); - /// Write to a file has two phases, writing data chunks and then closing the file. /// We use a spawned tokio task to actually compress arrays are write it to disk. /// Each chunk is pushed into the sink and read from the task. /// Once finished we can close all sinks and then the task can be awaited and the file /// flushed to disk. pub struct GlobalState { - write_task: Option>>, + write_task: Mutex>>>, sink: Option>>, + // Pool of background workers helping to drive the write task. + // Note that this is optional and without it, we would only drive the task when DuckDB calls + // into us, and we call `RUNTIME.block_on`. + #[allow(dead_code)] + worker_pool: CurrentThreadWorkerPool, } impl CopyFunction for VortexCopyFunction { @@ -76,10 +74,10 @@ impl CopyFunction for VortexCopyFunction { chunk: &mut DataChunk, ) -> VortexResult<()> { let chunk = data_chunk_to_arrow(bind_data.fields.names(), chunk); - COPY_RUNTIME.block_on(async { + RUNTIME.block_on(|_h| async { init_global .sink - .as_ref() + .as_mut() .vortex_expect("sink closed early") .send(chunk) .await @@ -93,15 +91,16 @@ impl CopyFunction for VortexCopyFunction { _bind_data: &Self::BindData, init_global: &mut Self::GlobalState, ) -> VortexResult<()> { - COPY_RUNTIME.block_on(async { + RUNTIME.block_on(|_h| async { if let Some(sink) = init_global.sink.take() { drop(sink) } - init_global + let task = init_global .write_task + .lock() .take() - .vortex_expect("no file to close") - .await??; + .vortex_expect("no file to close"); + task.await?; Ok(()) }) } @@ -112,18 +111,22 @@ impl CopyFunction for VortexCopyFunction { ) -> VortexResult { // The channel size 32 was chosen arbitrarily. let (sink, rx) = mpsc::channel(32); - let array_stream = - ArrayStreamAdapter::new(bind_data.dtype.clone(), ReceiverStream::new(rx)); + let array_stream = ArrayStreamAdapter::new(bind_data.dtype.clone(), rx.into_stream()); - let writer = COPY_RUNTIME.spawn(async move { - let mut file = File::create(file_path).await?; + let writer = RUNTIME.handle().spawn_nested(|h| async move { + let mut file = async_fs::File::create(file_path).await?; VortexWriteOptions::default() + .with_handle(h) .write(&mut file, array_stream) .await }); + let worker_pool = RUNTIME.new_pool(); + worker_pool.set_workers_to_available_parallelism(); + Ok(GlobalState { - write_task: Some(writer), + worker_pool, + write_task: Mutex::new(Some(writer)), sink: Some(sink), }) } diff --git a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs index 8c4ba2eb2a0..52b7c229c0a 100644 --- a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs +++ b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs @@ -20,12 +20,13 @@ use vortex::arrays::{ }; use vortex::buffer::buffer; use vortex::file::VortexWriteOptions; +use vortex::io::runtime::{BlockingRuntime, Handle}; use vortex::scalar::Scalar; use vortex::validity::Validity; -use crate::cpp; use crate::cpp::{duckdb_string_t, duckdb_timestamp}; use crate::duckdb::{Connection, Database}; +use crate::{RUNTIME, cpp}; fn database_connection() -> Connection { let db = Database::open_in_memory().unwrap(); @@ -38,18 +39,24 @@ fn create_temp_file() -> NamedTempFile { NamedTempFile::new().unwrap() } -async fn write_single_column_vortex_file(field_name: &str, array: impl IntoArray) -> NamedTempFile { - write_vortex_file([(field_name, array)].into_iter()).await +async fn write_single_column_vortex_file( + handle: Handle, + field_name: &str, + array: impl IntoArray, +) -> NamedTempFile { + write_vortex_file(handle, [(field_name, array)].into_iter()).await } async fn write_vortex_file( + handle: Handle, iter: impl Iterator, impl IntoArray)>, ) -> NamedTempFile { let temp_file_path = create_temp_file(); let struct_array = StructArray::try_from_iter(iter).unwrap(); - let mut file = tokio::fs::File::create(&temp_file_path).await.unwrap(); + let mut file = async_fs::File::create(&temp_file_path).await.unwrap(); VortexWriteOptions::default() + .with_handle(handle) .write(&mut file, struct_array.to_array_stream()) .await .unwrap(); @@ -131,6 +138,7 @@ fn scan_vortex_file>( } async fn write_vortex_file_to_dir( + handle: Handle, dir: &Path, field_name: &str, array: impl IntoArray, @@ -141,8 +149,9 @@ async fn write_vortex_file_to_dir( .tempfile_in(dir) .unwrap(); - let mut file = tokio::fs::File::create(&temp_file_path).await.unwrap(); + let mut file = async_fs::File::create(&temp_file_path).await.unwrap(); VortexWriteOptions::default() + .with_handle(handle) .write(&mut file, struct_array.to_array_stream()) .await .unwrap(); @@ -167,10 +176,9 @@ fn test_scan_function_registration() { #[test] fn test_vortex_scan_strings() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { let strings = VarBinArray::from(vec!["Hello", "Hi", "Hey"]); - write_single_column_vortex_file("strings", strings).await + write_single_column_vortex_file(h, "strings", strings).await }); let result: String = scan_vortex_file_single_row( @@ -184,10 +192,9 @@ fn test_vortex_scan_strings() { #[test] fn test_vortex_scan_strings_contains() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { let strings = VarBinArray::from(vec!["Hello", "Hi", "Hey"]); - write_single_column_vortex_file("strings", strings).await + write_single_column_vortex_file(h, "strings", strings).await }); let result: String = scan_vortex_file_single_row( file, @@ -200,10 +207,9 @@ fn test_vortex_scan_strings_contains() { #[test] fn test_vortex_scan_integers() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { let numbers = buffer![1i32, 42, 100, -5, 0]; - write_single_column_vortex_file("number", numbers).await + write_single_column_vortex_file(h, "number", numbers).await }); let sum: i64 = scan_vortex_file_single_row::(file, "SELECT SUM(number) FROM vortex_scan(?)", 0); @@ -212,10 +218,9 @@ fn test_vortex_scan_integers() { #[test] fn test_vortex_scan_integers_in_list() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { let numbers = buffer![1i32, 42, 100, -5, 0]; - write_single_column_vortex_file("number", numbers).await + write_single_column_vortex_file(h, "number", numbers).await }); let sum: i64 = scan_vortex_file_single_row::( file, @@ -227,10 +232,9 @@ fn test_vortex_scan_integers_in_list() { #[test] fn test_vortex_scan_integers_between() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { let numbers = buffer![1i32, 42, 100, -5, 0]; - write_single_column_vortex_file("number", numbers).await + write_single_column_vortex_file(h, "number", numbers).await }); let sum: i64 = scan_vortex_file_single_row::( file, @@ -242,10 +246,9 @@ fn test_vortex_scan_integers_between() { #[test] fn test_vortex_scan_floats() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { let values = buffer![1.5f64, -2.5, 0.0, 42.42]; - write_single_column_vortex_file("value", values).await + write_single_column_vortex_file(h, "value", values).await }); let count: i64 = scan_vortex_file_single_row::( file, @@ -257,10 +260,9 @@ fn test_vortex_scan_floats() { #[test] fn test_vortex_scan_constant() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { let constant = ConstantArray::new(Scalar::from(42i32), 100); - write_single_column_vortex_file("constant", constant).await + write_single_column_vortex_file(h, "constant", constant).await }); let value: i32 = scan_vortex_file_single_row::( file, @@ -272,11 +274,10 @@ fn test_vortex_scan_constant() { #[test] fn test_vortex_scan_booleans() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { let flags = vec![true, false, true, true, false]; let flags_array = BoolArray::from_bit_buffer(flags.into(), Validity::NonNullable); - write_single_column_vortex_file("flag", flags_array).await + write_single_column_vortex_file(h, "flag", flags_array).await }); let true_count: i64 = scan_vortex_file_single_row::( file, @@ -288,8 +289,7 @@ fn test_vortex_scan_booleans() { #[test] fn test_vortex_multi_column() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { let f1 = BoolArray::from_bit_buffer( vec![true, false, true, true, false].into(), Validity::NonNullable, @@ -297,7 +297,7 @@ fn test_vortex_multi_column() { .to_array(); let f2 = (0..5).collect::().to_array(); let f3 = (100..105).collect::().to_array(); - write_vortex_file([("f1", f1), ("f2", f2), ("f3", f3)].into_iter()).await + write_vortex_file(h, [("f1", f1), ("f2", f2), ("f3", f3)].into_iter()).await }); let result: Vec = scan_vortex_file::( @@ -312,13 +312,15 @@ fn test_vortex_multi_column() { #[test] fn test_vortex_scan_multiple_files() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let (tempdir, _file1, _file2) = runtime.block_on(async { + let (tempdir, _file1, _file2) = RUNTIME.block_on(|h| async { let tempdir = tempfile::tempdir().unwrap(); - let file1 = write_vortex_file_to_dir(tempdir.path(), "numbers", buffer![1i32, 2, 3]).await; + let file1 = + write_vortex_file_to_dir(h.clone(), tempdir.path(), "numbers", buffer![1i32, 2, 3]) + .await; - let file2 = write_vortex_file_to_dir(tempdir.path(), "numbers", buffer![4i32, 5, 6]).await; + let file2 = + write_vortex_file_to_dir(h, tempdir.path(), "numbers", buffer![4i32, 5, 6]).await; (tempdir, file1, file2) }); @@ -394,8 +396,7 @@ fn test_write_timestamps() { fn test_vortex_scan_fixed_size_list_utf8() { // Test a simple FixedSizeList of Utf8 strings to ensure proper materialization. - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { // Create a large number of strings to stress test. let strings: Vec<&str> = (0..24) .map(|i| match i % 6 { @@ -418,7 +419,7 @@ fn test_vortex_scan_fixed_size_list_utf8() { 6, // 6 lists total ); - write_single_column_vortex_file("string_lists", fsl).await + write_single_column_vortex_file(h, "string_lists", fsl).await }); let conn = database_connection(); @@ -446,9 +447,7 @@ fn test_vortex_scan_nested_fixed_size_list_utf8() { // when running with `FixedSizeList` instead of `List`. // Test FixedSizeList of FixedSizeList of Utf8 to ensure proper materialization. - - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { // Create a large number of strings to stress test. let strings: Vec<&str> = (0..24) .map(|i| match i % 6 { @@ -479,7 +478,7 @@ fn test_vortex_scan_nested_fixed_size_list_utf8() { 2, // 2 outer lists ); - write_single_column_vortex_file("nested_string_lists", outer_fsl).await + write_single_column_vortex_file(h, "nested_string_lists", outer_fsl).await }); let conn = database_connection(); @@ -504,8 +503,8 @@ fn test_vortex_scan_nested_fixed_size_list_utf8() { #[test] fn test_vortex_scan_list_of_ints() { // Test a simple List of integers. - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + + let file = RUNTIME.block_on(|h| async { // Create integers that will be grouped into lists. let integers = PrimitiveArray::from_iter([ 10i32, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, @@ -525,7 +524,7 @@ fn test_vortex_scan_list_of_ints() { ) .unwrap(); - write_single_column_vortex_file("int_list", list_array).await + write_single_column_vortex_file(h, "int_list", list_array).await }); let conn = database_connection(); @@ -556,8 +555,8 @@ fn test_vortex_scan_list_of_ints() { #[test] fn test_vortex_scan_list_of_utf8() { // Test a simple List of UTF8 strings. - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + + let file = RUNTIME.block_on(|h| async { // Create UTF8 strings that will be grouped into lists. let strings = VarBinViewArray::from_iter_str(vec![ "apple", @@ -587,7 +586,7 @@ fn test_vortex_scan_list_of_utf8() { ) .unwrap(); - write_single_column_vortex_file("string_list", list_array).await + write_single_column_vortex_file(h, "string_list", list_array).await }); let conn = database_connection(); @@ -622,8 +621,7 @@ fn test_vortex_scan_ultra_deep_nesting() { // Test ultra-deep nesting: Multiple levels of FSL and List combinations with UTF8. // FSL[List[FSL[List[FSL[UTF8]]]]] - let runtime = tokio::runtime::Runtime::new().unwrap(); - let file = runtime.block_on(async { + let file = RUNTIME.block_on(|h| async { // Level 1: Create base UTF8 strings - need a lot for deep nesting. let strings = VarBinViewArray::from_iter_str( (0..360) @@ -684,7 +682,7 @@ fn test_vortex_scan_ultra_deep_nesting() { 1, // 1 outermost FSL ); - write_single_column_vortex_file("ultra_deep", outermost_fsl).await + write_single_column_vortex_file(h, "ultra_deep", outermost_fsl).await }); let conn = database_connection(); diff --git a/vortex-duckdb/src/lib.rs b/vortex-duckdb/src/lib.rs index 3e7d350cbfc..b85930163eb 100644 --- a/vortex-duckdb/src/lib.rs +++ b/vortex-duckdb/src/lib.rs @@ -2,10 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors #![allow(clippy::missing_safety_doc)] +// **WARNING end use std::ffi::{CStr, c_char}; +use std::sync::LazyLock; -// **WARNING end use vortex::error::{VortexExpect, VortexResult}; +use vortex::io::runtime::current::CurrentThreadRuntime; use crate::copy::VortexCopyFunction; use crate::duckdb::Config; @@ -28,6 +30,9 @@ mod copy; #[cfg(test)] mod e2e_test; +// A global runtime for Vortex operations within DuckDB. +static RUNTIME: LazyLock = LazyLock::new(CurrentThreadRuntime::new); + /// Register Vortex extension configuration options with DuckDB. /// This must be called before `register_table_functions` to take effect. pub fn register_extension_options(config: &Config) { diff --git a/vortex-duckdb/src/scan.rs b/vortex-duckdb/src/scan.rs index 7e29a3284e5..10b7583f6d4 100644 --- a/vortex-duckdb/src/scan.rs +++ b/vortex-duckdb/src/scan.rs @@ -21,9 +21,10 @@ use vortex::error::{VortexExpect, VortexResult, vortex_bail, vortex_err}; use vortex::expr::{ExprRef, and, and_collect, col, lit, root, select}; use vortex::file::{VortexFile, VortexOpenOptions}; use vortex::io::runtime::BlockingRuntime; -use vortex::io::runtime::current::{CurrentThreadRuntime, ThreadSafeIterator}; +use vortex::io::runtime::current::ThreadSafeIterator; use vortex::{ArrayRef, ToCanonical}; +use crate::RUNTIME; use crate::convert::{try_from_bound_expression, try_from_table_filter}; use crate::duckdb::footer_cache::FooterCache; use crate::duckdb::{ @@ -40,7 +41,6 @@ pub struct VortexBindData { file_urls: Vec, column_names: Vec, column_types: Vec, - runtime: CurrentThreadRuntime, max_threads: u64, } @@ -54,7 +54,6 @@ impl Clone for VortexBindData { file_urls: self.file_urls.clone(), column_names: self.column_names.clone(), column_types: self.column_types.clone(), - runtime: self.runtime.clone(), max_threads: self.max_threads, } } @@ -224,8 +223,6 @@ impl TableFunction for VortexTableFunction { input: &BindInput, result: &mut BindResult, ) -> VortexResult { - let runtime = CurrentThreadRuntime::new(); - let file_glob_string = input .get_parameter(0) .ok_or_else(|| vortex_err!("Missing file glob parameter"))?; @@ -248,7 +245,7 @@ impl TableFunction for VortexTableFunction { log::trace!("running scan with max_threads {max_threads}"); - let (file_urls, _metadata) = runtime + let (file_urls, _metadata) = RUNTIME .block_on(|_h| Compat::new(expand_glob(file_glob_string.as_ref().as_string())))?; // The first file is skipped in `create_file_paths_queue`. @@ -258,7 +255,7 @@ impl TableFunction for VortexTableFunction { let footer_cache = FooterCache::new(ctx.object_cache()); let entry = footer_cache.entry(first_file_url.as_ref()); - let first_file = runtime.block_on(|h| async move { + let first_file = RUNTIME.block_on(|h| async move { let options = entry.apply_to_file(VortexOpenOptions::new().with_handle(h)); let file = open_file(first_file_url.clone(), options).await?; entry.put_if_absent(|| file.footer().clone()); @@ -278,7 +275,6 @@ impl TableFunction for VortexTableFunction { filter_exprs: vec![], column_names, column_types, - runtime, max_threads: max_threads as u64, }) } @@ -347,7 +343,7 @@ impl TableFunction for VortexTableFunction { let client_context = init_input.client_context()?; let object_cache = client_context.object_cache(); - let handle = bind_data.runtime.handle(); + let handle = RUNTIME.handle(); let first_file = bind_data.first_file.clone(); let scan_streams = stream::iter(bind_data.file_urls.clone()) .enumerate() @@ -400,14 +396,12 @@ impl TableFunction for VortexTableFunction { .filter_map(|result| async move { result.transpose() }); Ok(VortexGlobalData { - iterator: bind_data - .runtime - .block_on_stream_thread_safe(move |_| MultiScan { - streams: scan_streams.boxed(), - streams_finished: false, - select_all: Default::default(), - max_concurrency: num_workers * 2, - }), + iterator: RUNTIME.block_on_stream_thread_safe(move |_| MultiScan { + streams: scan_streams.boxed(), + streams_finished: false, + select_all: Default::default(), + max_concurrency: num_workers * 2, + }), batch_id: AtomicU64::new(0), }) } diff --git a/vortex-io/Cargo.toml b/vortex-io/Cargo.toml index 2b237339dab..9faf001b5c0 100644 --- a/vortex-io/Cargo.toml +++ b/vortex-io/Cargo.toml @@ -18,6 +18,7 @@ all-features = true [dependencies] async-compat = { workspace = true } +async-fs = { workspace = true } async-stream = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } diff --git a/vortex-io/src/runtime/current.rs b/vortex-io/src/runtime/current.rs index cb7807738c0..7c8b14c7582 100644 --- a/vortex-io/src/runtime/current.rs +++ b/vortex-io/src/runtime/current.rs @@ -7,51 +7,44 @@ use futures::stream::BoxStream; use futures::{Stream, StreamExt}; use smol::block_on; +pub use crate::runtime::pool::CurrentThreadWorkerPool; use crate::runtime::{BlockingRuntime, Executor, Handle}; -/// A current thread runtime allows users to explicitly drive Vortex futures from multiple worker -/// threads that they manage. This is useful in environments where the user already has a thread -/// pool and wants to integrate Vortex into that pool, for example query engines. +/// A current thread runtime allows callers to much more explicitly drive Vortex futures than with +/// a Tokio runtime. +/// +/// The current thread runtime will do no work unless `block_on` is called. In other words, the +/// default behavior is single-threaded with code running on the thread that called `block_on`. +/// +/// It's also possible to clone the runtime onto other threads, each of which can call `block_on` +/// to drive work on that thread. Each thread shares the same underlying executor with the same +/// set of tasks, allowing work to be driven in parallel. +/// +/// For automatic driving of work, a [`CurrentThreadWorkerPool`] can be created from the runtime +/// by calling [`new_pool`](CurrentThreadRuntime::new_pool). The returned pool can be configured +/// with the desired number of worker threads that will drive work on behalf of the runtime. #[derive(Clone, Default)] pub struct CurrentThreadRuntime { executor: Arc>, } -impl BlockingRuntime for CurrentThreadRuntime { - type BlockingIterator<'a, R: 'a> = CurrentThreadIterator<'a, R>; - - fn handle(&self) -> Handle { - let executor: Arc = self.executor.clone(); - Handle::new(Arc::downgrade(&executor)) - } - - fn block_on(&self, f: F) -> R - where - F: FnOnce(Handle) -> Fut, - Fut: Future, - { - block_on(self.executor.run(f(self.handle()))) - } - - fn block_on_stream<'a, F, S, R>(&self, f: F) -> Self::BlockingIterator<'a, R> - where - F: FnOnce(Handle) -> S, - S: Stream + Send + 'a, - R: Send + 'a, - { - CurrentThreadIterator { - executor: self.executor.clone(), - stream: f(self.handle()).boxed(), - } - } -} - impl CurrentThreadRuntime { /// Create a new current thread runtime. pub fn new() -> Self { Self::default() } + /// Create a new worker pool for driving the runtime in the background. + /// + /// This pool can be used to offload work from the current thread to a set of worker threads + /// that will drive the runtime's executor. + /// + /// By default, the pool has no worker threads; the caller must set the desired number of + /// worker threads using the `set_workers` method on the returned pool. + pub fn new_pool(&self) -> CurrentThreadWorkerPool { + CurrentThreadWorkerPool::new(self.executor.clone()) + } + /// Returns an iterator wrapper around a stream, blocking the current thread for each item. /// /// ## Multi-threaded Usage @@ -91,6 +84,35 @@ impl CurrentThreadRuntime { } } +impl BlockingRuntime for CurrentThreadRuntime { + type BlockingIterator<'a, R: 'a> = CurrentThreadIterator<'a, R>; + + fn handle(&self) -> Handle { + let executor: Arc = self.executor.clone(); + Handle::new(Arc::downgrade(&executor)) + } + + fn block_on(&self, f: F) -> R + where + F: FnOnce(Handle) -> Fut, + Fut: Future, + { + block_on(self.executor.run(f(self.handle()))) + } + + fn block_on_stream<'a, F, S, R>(&self, f: F) -> Self::BlockingIterator<'a, R> + where + F: FnOnce(Handle) -> S, + S: Stream + Send + 'a, + R: Send + 'a, + { + CurrentThreadIterator { + executor: self.executor.clone(), + stream: f(self.handle()).boxed(), + } + } +} + /// An iterator that wraps up a stream to drive it using the current thread execution. pub struct CurrentThreadIterator<'a, T> { executor: Arc>, @@ -142,6 +164,38 @@ mod tests { use super::*; + #[test] + fn test_worker_thread() { + let runtime = CurrentThreadRuntime::new(); + + // We spawn a future that sets a value on a separate thread. + let value = Arc::new(AtomicUsize::new(0)); + let value2 = value.clone(); + runtime + .handle() + .spawn(async move { + value2.store(42, Ordering::SeqCst); + }) + .detach(); + + // By default, nothing has driven the executor, so the value should still be 0. + assert_eq!(value.load(Ordering::SeqCst), 0); + + // An empty pool still does nothing. + let pool = runtime.new_pool(); + assert_eq!(value.load(Ordering::SeqCst), 0); + + // Adding a worker thread should drive the executor. + pool.set_workers(1); + for _ in 0..10 { + if value.load(Ordering::SeqCst) == 42 { + break; + } + thread::sleep(Duration::from_millis(10)); + } + assert_eq!(value.load(Ordering::SeqCst), 42); + } + #[test] fn test_block_on_stream_single_thread() { let mut iter = CurrentThreadRuntime::new() diff --git a/vortex-io/src/runtime/mod.rs b/vortex-io/src/runtime/mod.rs index f30806e7e45..df291dc7960 100644 --- a/vortex-io/src/runtime/mod.rs +++ b/vortex-io/src/runtime/mod.rs @@ -26,6 +26,8 @@ pub use handle::*; #[cfg(not(target_arch = "wasm32"))] pub mod current; #[cfg(not(target_arch = "wasm32"))] +mod pool; +#[cfg(not(target_arch = "wasm32"))] pub mod single; #[cfg(not(target_arch = "wasm32"))] mod smol; diff --git a/vortex-io/src/runtime/pool.rs b/vortex-io/src/runtime/pool.rs new file mode 100644 index 00000000000..85e082835dc --- /dev/null +++ b/vortex-io/src/runtime/pool.rs @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; + +use parking_lot::Mutex; +use smol::block_on; +use vortex_error::VortexExpect; + +#[derive(Clone)] +pub struct CurrentThreadWorkerPool { + executor: Arc>, + state: Arc>, +} + +impl CurrentThreadWorkerPool { + pub(super) fn new(executor: Arc>) -> Self { + Self { + executor, + state: Arc::new(Mutex::new(PoolState::default())), + } + } + + /// Set the number of worker threads to the available system parallelism as reported by + /// `std::thread::available_parallelism()` minus 1, to leave a slot open for the calling thread. + pub fn set_workers_to_available_parallelism(&self) { + let n = std::thread::available_parallelism() + .map(|n| n.get().saturating_sub(1).max(1)) + .unwrap_or(1); + self.set_workers(n); + } + + /// Set the number of worker threads + /// - If n > current: spawns additional workers + /// - If n < current: signals extra workers to shut down + pub fn set_workers(&self, n: usize) { + let mut state = self.state.lock(); + let current = state.workers.len(); + + if n > current { + // Spawn new workers + for _ in current..n { + let shutdown = Arc::new(AtomicBool::new(false)); + let executor = self.executor.clone(); + let shutdown_clone = shutdown.clone(); + + std::thread::Builder::new() + .name("vortex-current-thread-worker".to_string()) + .spawn(move || { + // Run the executor with a sleeping future that checks for shutdown + block_on(executor.run(async move { + while !shutdown_clone.load(Ordering::Relaxed) { + smol::Timer::after(Duration::from_millis(100)).await; + } + })) + }) + .vortex_expect("Failed to spawn current thread worker"); + + state.workers.push(WorkerHandle { shutdown }); + } + } else if n < current { + // Signal extra workers to shutdown and remove them + while state.workers.len() > n { + if let Some(worker) = state.workers.pop() { + worker.shutdown.store(true, Ordering::Relaxed); + } + } + } + } + + /// Get the current number of worker threads + pub fn worker_count(&self) -> usize { + self.state.lock().workers.len() + } +} + +#[derive(Default)] +struct PoolState { + /// The set of worker handles for the background threads. + workers: Vec, +} + +struct WorkerHandle { + /// The shutdown flag indicating that the worker should stop. + shutdown: Arc, +} + +impl Drop for CurrentThreadWorkerPool { + fn drop(&mut self) { + let mut state = self.state.lock(); + + // Signal all workers to shut down + for worker in state.workers.drain(..) { + worker.shutdown.store(true, Ordering::Relaxed); + } + } +} diff --git a/vortex-io/src/write.rs b/vortex-io/src/write.rs index e63009c7491..71b2c5fbc59 100644 --- a/vortex-io/src/write.rs +++ b/vortex-io/src/write.rs @@ -91,6 +91,21 @@ impl VortexWrite for &mut W { } } +impl VortexWrite for async_fs::File { + async fn write_all(&mut self, buffer: B) -> io::Result { + AsyncWriteExt::write_all(self, buffer.as_slice()).await?; + Ok(buffer) + } + + fn flush(&mut self) -> impl Future> { + AsyncWriteExt::flush(self) + } + + fn shutdown(&mut self) -> impl Future> { + AsyncWriteExt::close(self) + } +} + /// An adapter to use an `AsyncWrite` as a `VortexWrite`. pub struct AsyncWriteAdapter(pub W); impl VortexWrite for AsyncWriteAdapter {