From 87577ef3965f41409f86247859b594b656639b29 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 18 Nov 2024 19:38:59 -0600 Subject: [PATCH] feat: many a feature I did that thing again! Features in this commit: - `ThreadManager` allows you to define custom thread creation functions for environments & sessions. - Sessions can now opt-out of using the environment's global thread pool. - Implemented the safe `ShapeInferenceContext` wrapper for custom operators. - Prepacked weights allow the CPU execution provider to share one allocation for identical weights between sessions. - Customize workload type to prioritize efficiency; useful for background tasks. - Configurable per-session log identifiers - Dynamic dimension overrides Breaking changes: - `EnvironmentGlobalThreadPoolOptions` is now `GlobalThreadPoolOptions` and uses the builder pattern instead of exposed struct fields. --- examples/custom-ops/examples/custom-ops.rs | 8 + src/environment.rs | 165 +++++++++++++++++---- src/operator/bound.rs | 9 +- src/operator/mod.rs | 49 +++++- src/session/builder/impl_commit.rs | 41 +++-- src/session/builder/impl_options.rs | 79 ++++++++++ src/session/builder/mod.rs | 18 ++- src/session/mod.rs | 40 +++++ src/session/run_options.rs | 18 ++- src/value/mod.rs | 2 +- src/value/type.rs | 51 ++++++- tests/thread_manager.rs | 110 ++++++++++++++ 12 files changed, 536 insertions(+), 54 deletions(-) create mode 100644 tests/thread_manager.rs diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index 0b8c986a..f74035da 100644 --- a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -30,6 +30,14 @@ impl Operator for CustomOpOne { fn outputs() -> Vec { vec![OperatorOutput::required(TensorElementType::Float32)] } + + fn get_infer_shape_function() -> Option> { + Some(Box::new(|ctx| { + let inputs = ctx.inputs(); + ctx.set_output(0, &inputs[0])?; + Ok(()) + })) + } } impl Kernel for CustomOpOneKernel { diff --git a/src/environment.rs b/src/environment.rs index ae86aa58..2611f87c 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -12,7 +12,9 @@ //! ``` use std::{ + any::Any, ffi::{self, CStr, CString}, + os::raw::c_void, ptr::{self, NonNull}, sync::{Arc, RwLock} }; @@ -47,7 +49,8 @@ static G_ENV: EnvironmentSingleton = EnvironmentSingleton { lock: RwLock::new(No pub struct Environment { pub(crate) execution_providers: Vec, ptr: NonNull, - pub(crate) has_global_threadpool: bool + pub(crate) has_global_threadpool: bool, + _thread_manager: Option> } unsafe impl Send for Environment {} @@ -83,12 +86,126 @@ pub fn get_environment() -> Result> { } } -#[derive(Debug, Default, Clone)] -pub struct EnvironmentGlobalThreadPoolOptions { - pub inter_op_parallelism: Option, - pub intra_op_parallelism: Option, - pub spin_control: Option, - pub intra_op_thread_affinity: Option +#[derive(Debug)] +pub struct GlobalThreadPoolOptions { + ptr: *mut ort_sys::OrtThreadingOptions, + thread_manager: Option> +} + +impl Default for GlobalThreadPoolOptions { + fn default() -> Self { + let mut ptr = ptr::null_mut(); + ortsys![unsafe CreateThreadingOptions(&mut ptr)]; + Self { ptr, thread_manager: None } + } +} + +impl GlobalThreadPoolOptions { + pub fn with_inter_threads(mut self, num_threads: usize) -> Result { + ortsys![unsafe SetGlobalInterOpNumThreads(self.ptr_mut(), num_threads as _)?]; + Ok(self) + } + + pub fn with_intra_threads(mut self, num_threads: usize) -> Result { + ortsys![unsafe SetGlobalIntraOpNumThreads(self.ptr_mut(), num_threads as _)?]; + Ok(self) + } + + pub fn with_spin_control(mut self, spin_control: bool) -> Result { + ortsys![unsafe SetGlobalSpinControl(self.ptr_mut(), if spin_control { 1 } else { 0 })?]; + Ok(self) + } + + pub fn with_intra_affinity(mut self, affinity: impl AsRef) -> Result { + let affinity = CString::new(affinity.as_ref())?; + ortsys![unsafe SetGlobalIntraOpThreadAffinity(self.ptr_mut(), affinity.as_ptr())?]; + Ok(self) + } + + pub fn with_flush_to_zero(mut self) -> Result { + ortsys![unsafe SetGlobalDenormalAsZero(self.ptr_mut())?]; + Ok(self) + } + + pub fn with_thread_manager(mut self, manager: T) -> Result { + let mut manager = Box::new(manager); + ortsys![unsafe SetGlobalCustomThreadCreationOptions(self.ptr_mut(), (&mut *manager as *mut T).cast())?]; + ortsys![unsafe SetGlobalCustomCreateThreadFn(self.ptr_mut(), Some(thread_create::))?]; + ortsys![unsafe SetGlobalCustomJoinThreadFn(self.ptr_mut(), Some(thread_join::))?]; + self.thread_manager = Some(manager as Box); + Ok(self) + } +} + +impl AsPointer for GlobalThreadPoolOptions { + type Sys = ort_sys::OrtThreadingOptions; + + fn ptr(&self) -> *const Self::Sys { + self.ptr + } +} + +impl Drop for GlobalThreadPoolOptions { + fn drop(&mut self) { + ortsys![unsafe ReleaseThreadingOptions(self.ptr)]; + } +} + +pub struct ThreadWorker { + data: *mut c_void, + worker: ort_sys::OrtThreadWorkerFn +} + +unsafe impl Send for ThreadWorker {} + +impl ThreadWorker { + pub fn work(self) { + unsafe { self.worker.unwrap_unchecked()(self.data) } + } +} + +pub trait ThreadManager { + type Thread; + + fn create(&mut self, worker: ThreadWorker) -> crate::Result; + + fn join(thread: Self::Thread) -> crate::Result<()>; +} + +pub(crate) unsafe extern "C" fn thread_create( + ort_custom_thread_creation_options: *mut c_void, + ort_thread_worker_fn: ort_sys::OrtThreadWorkerFn, + ort_worker_fn_param: *mut c_void +) -> ort_sys::OrtCustomThreadHandle { + let thread_worker = ThreadWorker { + data: ort_worker_fn_param, + worker: ort_thread_worker_fn + }; + + let res = std::panic::catch_unwind(|| { + let manager = unsafe { &mut *ort_custom_thread_creation_options.cast::() }; + ::create(manager, thread_worker) + }); + match res { + Ok(Ok(thread)) => (Box::leak(Box::new(thread)) as *mut ::Thread) + .cast_const() + .cast::(), + Ok(Err(e)) => { + tracing::error!("Failed to create thread using manager: {e}"); + ptr::null() + } + Err(e) => { + tracing::error!("Thread manager panicked: {e:?}"); + ptr::null() + } + } +} + +pub(crate) unsafe extern "C" fn thread_join(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) { + let handle = Box::from_raw(ort_custom_thread_handle.cast_mut().cast::<::Thread>()); + if let Err(e) = ::join(*handle) { + tracing::error!("Failed to join thread using manager: {e}"); + } } /// Struct used to build an [`Environment`]; see [`crate::init`]. @@ -96,7 +213,7 @@ pub struct EnvironmentBuilder { name: String, telemetry: bool, execution_providers: Vec, - global_thread_pool_options: Option + global_thread_pool_options: Option } impl EnvironmentBuilder { @@ -153,48 +270,33 @@ impl EnvironmentBuilder { /// Enables the global thread pool for this environment. #[must_use = "commit() must be called in order for the environment to take effect"] - pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> Self { + pub fn with_global_thread_pool(mut self, options: GlobalThreadPoolOptions) -> Self { self.global_thread_pool_options = Some(options); self } /// Commit the environment configuration and set the global environment. pub fn commit(self) -> Result> { - let (env_ptr, has_global_threadpool) = if let Some(global_thread_pool) = self.global_thread_pool_options { + let (env_ptr, thread_manager, has_global_threadpool) = if let Some(mut thread_pool_options) = self.global_thread_pool_options { let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!()); - let mut thread_options: *mut ort_sys::OrtThreadingOptions = std::ptr::null_mut(); - ortsys![unsafe CreateThreadingOptions(&mut thread_options)?; nonNull(thread_options)]; - if let Some(inter_op_parallelism) = global_thread_pool.inter_op_parallelism { - ortsys![unsafe SetGlobalInterOpNumThreads(thread_options, inter_op_parallelism)?]; - } - if let Some(intra_op_parallelism) = global_thread_pool.intra_op_parallelism { - ortsys![unsafe SetGlobalIntraOpNumThreads(thread_options, intra_op_parallelism)?]; - } - if let Some(spin_control) = global_thread_pool.spin_control { - ortsys![unsafe SetGlobalSpinControl(thread_options, i32::from(spin_control))?]; - } - if let Some(intra_op_thread_affinity) = global_thread_pool.intra_op_thread_affinity { - let cstr = CString::new(intra_op_thread_affinity).unwrap_or_else(|_| unreachable!()); - ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr())?]; - } - ortsys![ unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( logging_function, logger_param, ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), - thread_options, + thread_pool_options.ptr(), &mut env_ptr )?; nonNull(env_ptr) ]; - ortsys![unsafe ReleaseThreadingOptions(thread_options)]; - (env_ptr, true) + + let thread_manager = thread_pool_options.thread_manager.take(); + (env_ptr, thread_manager, true) } else { let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); @@ -211,7 +313,7 @@ impl EnvironmentBuilder { )?; nonNull(env_ptr) ]; - (env_ptr, false) + (env_ptr, None, false) }; debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created"); @@ -230,7 +332,8 @@ impl EnvironmentBuilder { execution_providers: self.execution_providers, // we already asserted the env pointer is non-null in the `CreateEnvWithCustomLogger` call ptr: unsafe { NonNull::new_unchecked(env_ptr) }, - has_global_threadpool + has_global_threadpool, + _thread_manager: thread_manager }); env_lock.replace(Arc::clone(&env)); diff --git a/src/operator/bound.rs b/src/operator/bound.rs index c128f7c5..9f6c7375 100644 --- a/src/operator/bound.rs +++ b/src/operator/bound.rs @@ -5,7 +5,7 @@ use std::{ }; use super::{ - DummyOperator, Operator, + DummyOperator, Operator, ShapeInferenceContext, io::InputOutputCharacteristic, kernel::{Kernel, KernelAttributes, KernelContext} }; @@ -203,8 +203,11 @@ impl BoundOperator { } extern_system_fn! { - pub(crate) unsafe fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus { - O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status() + pub(crate) unsafe fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus { + let mut ctx = ShapeInferenceContext { + ptr: ctx + }; + O::get_infer_shape_function().expect("missing infer shape function")(&mut ctx).into_status() } } } diff --git a/src/operator/mod.rs b/src/operator/mod.rs index c24865f9..7f58ddd5 100644 --- a/src/operator/mod.rs +++ b/src/operator/mod.rs @@ -16,9 +16,12 @@ use self::{ io::{OperatorInput, OperatorOutput}, kernel::{DummyKernel, Kernel, KernelAttributes} }; -use crate::{AsPointer, error::Result, ortsys}; - -pub type InferShapeFn = dyn FnMut(*mut ort_sys::OrtShapeInferContext) -> crate::Result<()>; +use crate::{ + AsPointer, Error, + error::Result, + ortsys, + value::{ValueType, r#type::extract_data_type_from_tensor_info} +}; /// A custom operator descriptor, which describes the expected inputs & outputs of a graph operator. /// @@ -84,6 +87,46 @@ impl Operator for DummyOperator { } } +pub type InferShapeFn = dyn FnMut(&mut ShapeInferenceContext) -> crate::Result<()> + 'static; + +pub struct ShapeInferenceContext { + ptr: *mut ort_sys::OrtShapeInferContext +} + +impl ShapeInferenceContext { + pub fn inputs(&self) -> Vec { + let mut count = 0; + ortsys![unsafe ShapeInferContext_GetInputCount(self.ptr(), &mut count).expect("failed to get input count")]; + + let mut tys = Vec::with_capacity(count); + for i in 0..count { + let mut ty_info = ptr::null_mut(); + ortsys![unsafe ShapeInferContext_GetInputTypeShape(self.ptr(), i, &mut ty_info).expect("failed to get info type")]; + tys.push(unsafe { extract_data_type_from_tensor_info(ty_info) }); + } + tys + } + + pub fn set_output(&mut self, idx: usize, ty: &ValueType) -> Result<()> { + match ty.to_tensor_type_info() { + Some(ty_ptr) => { + ortsys![unsafe ShapeInferContext_SetOutputTypeShape(self.ptr(), idx, ty_ptr)?]; + ortsys![unsafe ReleaseTensorTypeAndShapeInfo(ty_ptr)]; + Ok(()) + } + None => Err(Error::new("only tensors are supported")) + } + } +} + +impl AsPointer for ShapeInferenceContext { + type Sys = ort_sys::OrtShapeInferContext; + + fn ptr(&self) -> *const Self::Sys { + self.ptr + } +} + pub struct OperatorDomain { ptr: NonNull, _name: CString, diff --git a/src/session/builder/impl_commit.rs b/src/session/builder/impl_commit.rs index aa21873a..f586c169 100644 --- a/src/session/builder/impl_commit.rs +++ b/src/session/builder/impl_commit.rs @@ -76,12 +76,16 @@ impl SessionBuilder { let env = get_environment()?; apply_execution_providers(&mut self, env.execution_providers.iter().cloned())?; - if env.has_global_threadpool { + if env.has_global_threadpool && !self.no_global_thread_pool { ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?]; } let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); - ortsys![unsafe CreateSession(env.ptr(), model_path.as_ptr(), self.ptr(), &mut session_ptr)?; nonNull(session_ptr)]; + if let Some(prepacked_weights) = self.prepacked_weights.as_ref() { + ortsys![unsafe CreateSessionWithPrepackedWeightsContainer(env.ptr(), model_path.as_ptr(), self.ptr(), prepacked_weights.ptr().cast_mut(), &mut session_ptr)?; nonNull(session_ptr)]; + } else { + ortsys![unsafe CreateSession(env.ptr(), model_path.as_ptr(), self.ptr(), &mut session_ptr)?; nonNull(session_ptr)]; + } let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; @@ -104,7 +108,13 @@ impl SessionBuilder { .map(|i| dangerous::extract_output(session_ptr, &allocator, i)) .collect::>>()?; - let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box).collect(); + let mut extras: Vec> = self.operator_domains.drain(..).map(|d| Box::new(d) as Box).collect(); + if let Some(prepacked_weights) = self.prepacked_weights.take() { + extras.push(Box::new(prepacked_weights) as Box); + } + if let Some(thread_manager) = self.thread_manager.take() { + extras.push(Box::new(thread_manager) as Box); + } Ok(Session { inner: Arc::new(SharedSessionInner { @@ -141,16 +151,23 @@ impl SessionBuilder { let env = get_environment()?; apply_execution_providers(&mut self, env.execution_providers.iter().cloned())?; - if env.has_global_threadpool { + if env.has_global_threadpool && !self.no_global_thread_pool { ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?]; } let model_data = model_bytes.as_ptr().cast::(); let model_data_length = model_bytes.len(); - ortsys![ - unsafe CreateSessionFromArray(env.ptr(), model_data, model_data_length, self.ptr(), &mut session_ptr)?; - nonNull(session_ptr) - ]; + if let Some(prepacked_weights) = self.prepacked_weights.as_ref() { + ortsys![ + unsafe CreateSessionFromArrayWithPrepackedWeightsContainer(env.ptr(), model_data, model_data_length, self.ptr(), prepacked_weights.ptr().cast_mut(), &mut session_ptr)?; + nonNull(session_ptr) + ]; + } else { + ortsys![ + unsafe CreateSessionFromArray(env.ptr(), model_data, model_data_length, self.ptr(), &mut session_ptr)?; + nonNull(session_ptr) + ]; + } let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; @@ -173,7 +190,13 @@ impl SessionBuilder { .map(|i| dangerous::extract_output(session_ptr, &allocator, i)) .collect::>>()?; - let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box).collect(); + let mut extras: Vec> = self.operator_domains.drain(..).map(|d| Box::new(d) as Box).collect(); + if let Some(prepacked_weights) = self.prepacked_weights.take() { + extras.push(Box::new(prepacked_weights) as Box); + } + if let Some(thread_manager) = self.thread_manager.take() { + extras.push(Box::new(thread_manager) as Box); + } let session = Session { inner: Arc::new(SharedSessionInner { diff --git a/src/session/builder/impl_options.rs b/src/session/builder/impl_options.rs index 771ad2f5..3f356b9b 100644 --- a/src/session/builder/impl_options.rs +++ b/src/session/builder/impl_options.rs @@ -1,7 +1,9 @@ use std::{ + any::Any, borrow::Cow, ffi::{CString, c_char}, path::Path, + ptr, rc::Rc, sync::Arc }; @@ -9,6 +11,7 @@ use std::{ use super::SessionBuilder; use crate::{ AsPointer, + environment::{self, ThreadManager}, error::Result, execution_providers::{ExecutionProviderDispatch, apply_execution_providers}, memory::MemoryInfo, @@ -168,6 +171,45 @@ impl SessionBuilder { self.external_initializer_buffers.push(buffer); Ok(self) } + + pub fn with_log_id(mut self, id: impl AsRef) -> Result { + let id = CString::new(id.as_ref())?; + ortsys![unsafe SetSessionLogId(self.ptr_mut(), id.as_ptr())?]; + Ok(self) + } + + pub fn with_dimension_override(mut self, name: impl AsRef, size: i64) -> Result { + let name = CString::new(name.as_ref())?; + ortsys![unsafe AddFreeDimensionOverrideByName(self.ptr_mut(), name.as_ptr(), size)?]; + Ok(self) + } + + pub fn with_dimension_override_by_denotation(mut self, denotation: impl AsRef, size: i64) -> Result { + let denotation = CString::new(denotation.as_ref())?; + ortsys![unsafe AddFreeDimensionOverride(self.ptr_mut(), denotation.as_ptr(), size)?]; + Ok(self) + } + + pub fn with_prepacked_weights(mut self, weights: &PrepackedWeights) -> Result { + self.prepacked_weights = Some(weights.clone()); + Ok(self) + } + + /// Configures this environment to use its own thread pool instead of defaulting to the + /// [`Environment`](crate::environment::Environment)'s global thread pool if one was defined. + pub fn with_independent_thread_pool(mut self) -> Result { + self.no_global_thread_pool = true; + Ok(self) + } + + pub fn with_thread_manager(mut self, manager: T) -> Result { + let manager = Rc::new(manager); + ortsys![unsafe SessionOptionsSetCustomThreadCreationOptions(self.ptr_mut(), (&*manager as *const T) as *mut std::ffi::c_void)?]; + ortsys![unsafe SessionOptionsSetCustomCreateThreadFn(self.ptr_mut(), Some(environment::thread_create::))?]; + ortsys![unsafe SessionOptionsSetCustomJoinThreadFn(self.ptr_mut(), Some(environment::thread_join::))?]; + self.thread_manager = Some(manager as Rc); + Ok(self) + } } /// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially @@ -260,3 +302,40 @@ impl From for ort_sys::GraphOptimizationLevel { } } } + +#[derive(Debug)] +struct PrepackedWeightsInner(*mut ort_sys::OrtPrepackedWeightsContainer); + +impl Drop for PrepackedWeightsInner { + fn drop(&mut self) { + ortsys![unsafe ReleasePrepackedWeightsContainer(self.0)]; + } +} + +#[derive(Debug, Clone)] +pub struct PrepackedWeights { + inner: Arc +} + +impl PrepackedWeights { + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + let mut ptr: *mut ort_sys::OrtPrepackedWeightsContainer = ptr::null_mut(); + ortsys![unsafe CreatePrepackedWeightsContainer(&mut ptr).expect("")]; + Self { + inner: Arc::new(PrepackedWeightsInner(ptr)) + } + } +} + +impl AsPointer for PrepackedWeights { + type Sys = ort_sys::OrtPrepackedWeightsContainer; + + fn ptr(&self) -> *const Self::Sys { + self.inner.0 + } + + fn ptr_mut(&mut self) -> *mut Self::Sys { + self.inner.0 + } +} diff --git a/src/session/builder/mod.rs b/src/session/builder/mod.rs index eacfcd56..3ec58b2f 100644 --- a/src/session/builder/mod.rs +++ b/src/session/builder/mod.rs @@ -1,4 +1,5 @@ use std::{ + any::Any, borrow::Cow, ffi::CString, ptr::{self, NonNull}, @@ -19,7 +20,7 @@ mod impl_commit; mod impl_config_keys; mod impl_options; -pub use self::impl_options::GraphOptimizationLevel; +pub use self::impl_options::{GraphOptimizationLevel, PrepackedWeights}; /// Creates a session using the builder pattern. /// @@ -44,7 +45,10 @@ pub struct SessionBuilder { memory_info: Option>, operator_domains: Vec>, external_initializers: Vec>, - external_initializer_buffers: Vec> + external_initializer_buffers: Vec>, + prepacked_weights: Option, + thread_manager: Option>, + no_global_thread_pool: bool } impl Clone for SessionBuilder { @@ -57,7 +61,10 @@ impl Clone for SessionBuilder { memory_info: self.memory_info.clone(), operator_domains: self.operator_domains.clone(), external_initializers: self.external_initializers.clone(), - external_initializer_buffers: self.external_initializer_buffers.clone() + external_initializer_buffers: self.external_initializer_buffers.clone(), + prepacked_weights: self.prepacked_weights.clone(), + thread_manager: self.thread_manager.clone(), + no_global_thread_pool: self.no_global_thread_pool } } } @@ -90,7 +97,10 @@ impl SessionBuilder { memory_info: None, operator_domains: Vec::new(), external_initializers: Vec::new(), - external_initializer_buffers: Vec::new() + external_initializer_buffers: Vec::new(), + prepacked_weights: None, + thread_manager: None, + no_global_thread_pool: false }) } diff --git a/src/session/mod.rs b/src/session/mod.rs index de58b358..f6e9551e 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -472,6 +472,46 @@ impl Session { assert_non_null_pointer(profiling_name, "ProfilingName")?; dangerous::raw_pointer_to_string(&self.inner.allocator, profiling_name) } + + /// Sets this session's [workload type][`WorkloadType`] to instruct execution providers to prioritize performance or + /// efficiency. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{session::{run_options::RunOptions, Session, WorkloadType}, tensor::TensorElementType, value::{Value, ValueType}}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// session.set_workload_type(WorkloadType::Efficient)?; + /// + /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); + /// let outputs = session.run(ort::inputs![input]?)?; + /// # Ok(()) + /// # } + /// ``` + pub fn set_workload_type(&self, workload_type: WorkloadType) -> Result<()> { + static KEY: &[u8] = b"ep.dynamic.workload_type\0"; + match workload_type { + WorkloadType::Default => self.set_dynamic_option(KEY.as_ptr().cast(), b"Default\0".as_ptr().cast()), + WorkloadType::Efficient => self.set_dynamic_option(KEY.as_ptr().cast(), b"Efficient\0".as_ptr().cast()) + } + } + + pub(crate) fn set_dynamic_option(&self, key: *const c_char, value: *const c_char) -> Result<()> { + ortsys![unsafe SetEpDynamicOptions(self.inner.session_ptr.as_ptr(), &key, &value, 1)?]; + Ok(()) + } +} + +/// Workload type, used to signal to execution providers whether to prioritize performance or efficiency. +/// +/// See [`Session::set_workload_type`]. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +pub enum WorkloadType { + /// Prioritize performance. This is the default behavior when the workload type is not overridden. + #[default] + Default, + /// Prioritize efficiency, by i.e. reducing scheduling priority and/or offloading to efficiency cores. + Efficient } // https://github.com/microsoft/onnxruntime/issues/114 diff --git a/src/session/run_options.rs b/src/session/run_options.rs index 0c176b1b..bdd5b02d 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -1,4 +1,10 @@ -use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, sync::Arc}; +use std::{ + collections::HashMap, + ffi::{CStr, CString, c_char}, + marker::PhantomData, + ptr::{self, NonNull}, + sync::Arc +}; use crate::{ AsPointer, @@ -227,6 +233,16 @@ impl RunOptions { Ok(()) } + pub fn tag(&self) -> Result { + let mut tag_ptr: *const c_char = ptr::null(); + ortsys![unsafe RunOptionsGetRunTag(self.run_options_ptr.as_ptr(), &mut tag_ptr)?]; + if tag_ptr.is_null() { + Ok(String::default()) + } else { + Ok(unsafe { CStr::from_ptr(tag_ptr) }.to_string_lossy().into()) + } + } + /// Sets the termination flag for the runs associated with this [`RunOptions`]. /// /// This function returns immediately (it does not wait for the session run to terminate). The run will terminate as diff --git a/src/value/mod.rs b/src/value/mod.rs index e10ab591..8616592b 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -29,7 +29,7 @@ use std::{ mod impl_map; mod impl_sequence; mod impl_tensor; -mod r#type; +pub(crate) mod r#type; pub use self::{ impl_map::{DynMap, DynMapRef, DynMapRefMut, DynMapValueType, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker}, diff --git a/src/value/type.rs b/src/value/type.rs index a25d4b38..27983d68 100644 --- a/src/value/type.rs +++ b/src/value/type.rs @@ -1,5 +1,5 @@ use std::{ - ffi::{CStr, c_char}, + ffi::{CStr, CString, c_char}, fmt, ptr }; @@ -132,6 +132,32 @@ impl ValueType { ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; io_type } + + pub(crate) fn to_tensor_type_info(&self) -> Option<*mut ort_sys::OrtTensorTypeAndShapeInfo> { + match self { + Self::Tensor { ty, dimensions, dimension_symbols } => { + let mut info_ptr = ptr::null_mut(); + ortsys![unsafe CreateTensorTypeAndShapeInfo(&mut info_ptr)]; + ortsys![unsafe SetTensorElementType(info_ptr, (*ty).into())]; + ortsys![unsafe SetDimensions(info_ptr, dimensions.as_ptr(), dimensions.len())]; + let dimension_symbols: Vec<*const c_char> = dimension_symbols + .iter() + .cloned() + .map(|s| CString::new(s.unwrap_or_default())) + .map(|s| s.map_or(ptr::null(), |s| s.into_raw().cast_const())) + .collect(); + ortsys![unsafe SetSymbolicDimensions(info_ptr, dimension_symbols.as_ptr().cast_mut(), dimension_symbols.len())]; + for p in dimension_symbols { + if !p.is_null() { + drop(unsafe { CString::from_raw(p.cast_mut().cast()) }); + } + } + Some(info_ptr) + } + _ => None + } + } + /// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map. /// /// ``` @@ -216,7 +242,7 @@ impl fmt::Display for ValueType { } } -unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo) -> ValueType { +pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo) -> ValueType { let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; ortsys![GetTensorElementType(info_ptr, &mut type_sys)]; assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); @@ -260,3 +286,24 @@ unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeIn value: value_type_sys.into() } } + +#[cfg(test)] +mod tests { + use super::ValueType; + use crate::{ortsys, tensor::TensorElementType}; + + #[test] + fn test_to_from_tensor_info() -> crate::Result<()> { + let ty = ValueType::Tensor { + ty: TensorElementType::Float32, + dimensions: vec![-1, 32, 4, 32], + dimension_symbols: vec![Some("d1".to_string()), None, None, None] + }; + let ty_ptr = ty.to_tensor_type_info().expect(""); + let ty_d = unsafe { super::extract_data_type_from_tensor_info(ty_ptr) }; + ortsys![unsafe ReleaseTensorTypeAndShapeInfo(ty_ptr)]; + assert_eq!(ty, ty_d); + + Ok(()) + } +} diff --git a/tests/thread_manager.rs b/tests/thread_manager.rs new file mode 100644 index 00000000..4e15774d --- /dev/null +++ b/tests/thread_manager.rs @@ -0,0 +1,110 @@ +use std::{ + path::Path, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering} + }, + thread::{self, JoinHandle} +}; + +use image::{ImageBuffer, Luma, Pixel, imageops::FilterType}; +use ort::{ + environment::{GlobalThreadPoolOptions, ThreadManager, ThreadWorker}, + inputs, + session::{Session, builder::GraphOptimizationLevel} +}; +use test_log::test; + +struct ThreadStats { + active_threads: AtomicUsize +} + +struct StdThread { + stats: Arc, + join_handle: JoinHandle<()> +} + +impl StdThread { + pub fn spawn(worker: ThreadWorker, stats: &Arc) -> Self { + let join_handle = thread::spawn(move || worker.work()); + stats.active_threads.fetch_add(1, Ordering::AcqRel); + Self { + stats: Arc::clone(stats), + join_handle + } + } + + pub fn join(self) { + let _ = self.join_handle.join(); + self.stats.active_threads.fetch_sub(1, Ordering::AcqRel); + } +} + +struct StdThreadManager { + stats: Arc +} + +impl ThreadManager for StdThreadManager { + type Thread = StdThread; + + fn create(&mut self, worker: ThreadWorker) -> ort::Result { + Ok(StdThread::spawn(worker, &self.stats)) + } + + fn join(thread: Self::Thread) -> ort::Result<()> { + thread.join(); + Ok(()) + } +} + +#[test] +fn global_thread_manager() -> ort::Result<()> { + let stats = Arc::new(ThreadStats { active_threads: AtomicUsize::new(0) }); + + ort::init() + .with_name("integration_test") + .with_global_thread_pool( + GlobalThreadPoolOptions::default() + .with_inter_threads(4)? + .with_intra_threads(2)? + .with_thread_manager(StdThreadManager { stats: Arc::clone(&stats) })? + ) + .commit()?; + + assert_eq!(stats.active_threads.load(Ordering::Acquire), 4); + + Ok(()) +} + +#[test] +fn session_thread_manager() -> ort::Result<()> { + const IMAGE_TO_LOAD: &str = "mnist_5.jpg"; + + let stats = Arc::new(ThreadStats { active_threads: AtomicUsize::new(0) }); + + let session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level1)? + .with_inter_threads(2)? + .with_intra_threads(2)? + .with_thread_manager(StdThreadManager { stats: Arc::clone(&stats) })? + .commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx") + .expect("Could not download model from file"); + + assert_eq!(stats.active_threads.load(Ordering::Acquire), 1); + + let image_buffer: ImageBuffer, Vec> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD)) + .unwrap() + .resize(28, 28, FilterType::Nearest) + .to_luma8(); + let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); + (channels[c] as f32) / 255.0 + }); + + let _ = session.run(inputs![array]?)?; + + assert_eq!(stats.active_threads.load(Ordering::Acquire), 1); + + Ok(()) +}