diff --git a/examples/async-gpt2-api/examples/async-gpt2-api.rs b/examples/async-gpt2-api/examples/async-gpt2-api.rs index 7aaf2869..a5de60ef 100644 --- a/examples/async-gpt2-api/examples/async-gpt2-api.rs +++ b/examples/async-gpt2-api/examples/async-gpt2-api.rs @@ -12,7 +12,7 @@ use axum::{ use futures::Stream; use ort::{ execution_providers::CUDAExecutionProvider, - session::{Session, builder::GraphOptimizationLevel}, + session::{RunOptions, Session, builder::GraphOptimizationLevel}, value::TensorRef }; use rand::Rng; @@ -74,7 +74,8 @@ fn generate_stream( let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?; let probabilities = { let mut session = session.lock().await; - let outputs = session.run_async(ort::inputs![input])?.await?; + let options = RunOptions::new()?; + let outputs = session.run_async(ort::inputs![input], &options)?.await?; let (dim, probabilities) = outputs["output1"].try_extract_raw_tensor()?; // Collect and sort logits diff --git a/src/adapter.rs b/src/adapter.rs index db6fc421..b2bc4b28 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -164,11 +164,11 @@ mod tests { }; #[test] + #[cfg(feature = "std")] fn test_lora() -> crate::Result<()> { let model = std::fs::read("tests/data/lora_model.onnx").expect(""); let mut session = Session::builder()?.commit_from_memory(&model)?; - let lora = std::fs::read("tests/data/adapter.orl").expect(""); - let lora = Adapter::from_memory(&lora, None)?; + let lora = Adapter::from_file("tests/data/adapter.orl", None)?; let mut run_options = RunOptions::new()?; run_options.add_adapter(&lora)?; diff --git a/src/error.rs b/src/error.rs index 939a7682..8eb2ab4c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -168,12 +168,6 @@ impl From for ort_sys::OrtErrorCode { } } -pub(crate) fn assert_non_null_pointer(ptr: *const T, name: &'static str) -> Result<()> { - (!ptr.is_null()) - .then_some(()) - .ok_or_else(|| Error::new(format!("Expected pointer `{name}` to not be null"))) -} - /// Converts an [`ort_sys::OrtStatus`] to a [`Result`]. /// /// Note that this frees `status`! diff --git a/src/lib.rs b/src/lib.rs index 3a66d5a2..a04544fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,7 @@ pub mod tensor; #[cfg(feature = "training")] #[cfg_attr(docsrs, doc(cfg(feature = "training")))] pub mod training; -pub(crate) mod util; +pub mod util; pub mod value; #[cfg(feature = "load-dynamic")] @@ -218,17 +218,39 @@ macro_rules! ortsys { (unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => { unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }.expect($e) }; - (unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + (unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr); nonNull($($check:ident),+ $(,)?)$(;)?) => {{ + unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }.expect($e); + $( + // TODO: #[cfg(debug_assertions)]? + if ($check).is_null() { + $crate::util::cold(); + panic!(concat!("expected `", stringify!($check), "` to not be null")); + } + )+ + }}; + (unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:ident),+ $(,)?)$(;)?) => {{ let _x = unsafe { ($crate::api().$method)($($n),+) }; - $($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+ + $( + // TODO: #[cfg(debug_assertions)]? + if ($check).is_null() { + $crate::util::cold(); + panic!(concat!("expected `", stringify!($check), "` to not be null")); + } + )+ _x }}; (unsafe $method:ident($($n:expr),+ $(,)?)?) 
=> { unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }?; }; - (unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + (unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:ident),+ $(,)?)$(;)?) => {{ unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }?; - $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + $( + // TODO: #[cfg(debug_assertions)]? + if ($check).is_null() { + $crate::util::cold(); + return Err($crate::Error::new(concat!("expected `", stringify!($check), "` to not be null"))); + } + )+ }}; } diff --git a/src/memory.rs b/src/memory.rs index d4b6aba9..c42b5131 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -512,6 +512,12 @@ impl MemoryInfo { } } +impl Default for MemoryInfo { + fn default() -> Self { + MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default).expect("failed to create default memory info") + } +} + impl Clone for MemoryInfo { fn clone(&self) -> Self { MemoryInfo::new(self.allocation_device(), self.device_id(), self.allocator_type(), self.memory_type()).expect("failed to clone memory info") diff --git a/src/session/async.rs b/src/session/async.rs index 97443800..6a14f1c9 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -4,7 +4,6 @@ use core::{ ffi::{c_char, c_void}, future::Future, marker::PhantomData, - ops::Deref, pin::Pin, ptr::NonNull, task::{Context, Poll, Waker} @@ -13,8 +12,8 @@ use std::sync::Mutex; use crate::{ error::Result, - session::{RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner}, - value::Value + session::{SessionOutputs, SharedSessionInner, run_options::UntypedRunOptions}, + value::{Value, ValueInner} }; #[derive(Debug)] @@ -53,43 +52,17 @@ impl<'r, 's> InferenceFutInner<'r, 's> { unsafe impl Send for InferenceFutInner<'_, '_> {} unsafe impl Sync for InferenceFutInner<'_, '_> {} -pub enum RunOptionsRef<'r, O: SelectedOutputMarker> { - Arc(Arc>), - Ref(&'r RunOptions) -} - -impl From<&Arc>> for RunOptionsRef<'_, O> { - fn from(value: &Arc>) -> Self { - Self::Arc(Arc::clone(value)) - } -} - -impl<'r, O: SelectedOutputMarker> From<&'r RunOptions> for RunOptionsRef<'r, O> { - fn from(value: &'r RunOptions) -> Self { - Self::Ref(value) - } -} - -impl Deref for RunOptionsRef<'_, O> { - type Target = RunOptions; - - fn deref(&self) -> &Self::Target { - match self { - Self::Arc(r) => r, - Self::Ref(r) => r - } - } -} - -pub struct InferenceFut<'s, 'r, 'v, O: SelectedOutputMarker> { +pub struct InferenceFut<'s, 'r, 'v> { inner: Arc>, - run_options: RunOptionsRef<'r, O>, + run_options: &'r UntypedRunOptions, did_receive: bool, _inputs: PhantomData<&'v ()> } -impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, '_, O> { - pub(crate) fn new(inner: Arc>, run_options: RunOptionsRef<'r, O>) -> Self { +unsafe impl Send for InferenceFut<'_, '_, '_> {} + +impl<'s, 'r> InferenceFut<'s, 'r, '_> { + pub(crate) fn new(inner: Arc>, run_options: &'r UntypedRunOptions) -> Self { Self { inner, run_options, @@ -99,7 +72,7 @@ impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, '_, O> { } } -impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, '_, O> { +impl<'s, 'r> Future for InferenceFut<'s, 'r, '_> { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -115,7 +88,7 @@ impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, '_, O> { } } -impl Drop for InferenceFut<'_, '_, '_, O> { +impl Drop for 
InferenceFut<'_, '_, '_> { fn drop(&mut self) { if !self.did_receive { let _ = self.run_options.terminate(); @@ -124,19 +97,19 @@ impl Drop for InferenceFut<'_, '_, '_, O> { } } -pub(crate) struct AsyncInferenceContext<'r, 's, 'v> { +pub(crate) struct AsyncInferenceContext<'r, 's> { pub(crate) inner: Arc>, - pub(crate) _input_values: Vec>, pub(crate) input_ort_values: Vec<*const ort_sys::OrtValue>, + pub(crate) _input_inner_holders: Vec>, pub(crate) input_name_ptrs: Vec<*const c_char>, pub(crate) output_name_ptrs: Vec<*const c_char>, pub(crate) session_inner: &'s Arc, - pub(crate) output_names: Vec<&'s str>, + pub(crate) output_names: Vec<&'r str>, pub(crate) output_value_ptrs: Vec<*mut ort_sys::OrtValue> } pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: ort_sys::OrtStatusPtr) { - let ctx = unsafe { Box::from_raw(user_data.cast::>()) }; + let ctx = unsafe { Box::from_raw(user_data.cast::>()) }; // Reconvert name ptrs to CString so drop impl is called and memory is freed for p in ctx.input_name_ptrs { diff --git a/src/session/builder/impl_commit.rs b/src/session/builder/impl_commit.rs index 139ea2ea..91a1976c 100644 --- a/src/session/builder/impl_commit.rs +++ b/src/session/builder/impl_commit.rs @@ -9,6 +9,8 @@ use core::{ }; #[cfg(feature = "std")] use std::path::Path; +#[cfg(feature = "fetch-models")] +use std::path::PathBuf; use super::SessionBuilder; #[cfg(feature = "std")] @@ -28,6 +30,12 @@ impl SessionBuilder { #[cfg(all(feature = "fetch-models", feature = "std"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "fetch-models", feature = "std"))))] pub fn commit_from_url(self, model_url: impl AsRef) -> Result { + let downloaded_path = SessionBuilder::download(model_url.as_ref())?; + self.commit_from_file(downloaded_path) + } + + #[cfg(all(feature = "fetch-models", feature = "std"))] + fn download(url: &str) -> Result { let mut download_dir = ort_sys::internal::dirs::cache_dir() .expect("could not determine cache directory") .join("models"); @@ -35,15 +43,14 @@ impl SessionBuilder { download_dir = std::env::current_dir().expect("Failed to obtain current working directory"); } - let url = model_url.as_ref(); let model_filename = ::digest(url).into_iter().fold(String::new(), |mut s, b| { let _ = write!(&mut s, "{:02x}", b); s }); let model_filepath = download_dir.join(&model_filename); - let downloaded_path = if model_filepath.exists() { + if model_filepath.exists() { crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download"); - model_filepath + Ok(model_filepath) } else { crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model"); @@ -71,29 +78,31 @@ impl SessionBuilder { drop(writer); match std::fs::rename(&temp_filepath, &model_filepath) { - Ok(()) => model_filepath, + Ok(()) => Ok(model_filepath), Err(e) => { if model_filepath.exists() { let _ = std::fs::remove_file(temp_filepath); - model_filepath + Ok(model_filepath) } else { - return Err(Error::new(format!("Failed to download model: {e}"))); + Err(Error::new(format!("Failed to download model: {e}"))) } } } - }; - - self.commit_from_file(downloaded_path) + } } /// Loads an ONNX model from a file and builds the session. #[cfg(feature = "std")] #[cfg_attr(docsrs, doc(cfg(feature = "std")))] - pub fn commit_from_file

<P>(mut self, model_filepath_ref: P) -> Result<Session>
+ pub fn commit_from_file<P>
(self, model_filepath: P) -> Result where P: AsRef { - let model_filepath = model_filepath_ref.as_ref(); + self.commit_from_file_inner(model_filepath.as_ref()) + } + + #[cfg(feature = "std")] + fn commit_from_file_inner(mut self, model_filepath: &Path) -> Result { if !model_filepath.exists() { return Err(Error::new_with_code(ErrorCode::NoSuchFile, format!("File at `{}` does not exist", model_filepath.display()))); } @@ -166,7 +175,6 @@ impl SessionBuilder { self.add_config_entry("session.use_ort_model_bytes_for_initializers", "1")?; let session = self.commit_from_memory(model_bytes)?; - Ok(InMemorySession { session, phantom: PhantomData }) } diff --git a/src/session/builder/mod.rs b/src/session/builder/mod.rs index d56aa490..da5da40c 100644 --- a/src/session/builder/mod.rs +++ b/src/session/builder/mod.rs @@ -4,14 +4,7 @@ use core::{ ptr::{self, NonNull} }; -use crate::{ - AsPointer, - error::{Result, assert_non_null_pointer}, - memory::MemoryInfo, - operator::OperatorDomain, - ortsys, - value::DynValue -}; +use crate::{AsPointer, error::Result, memory::MemoryInfo, operator::OperatorDomain, ortsys, value::DynValue}; mod impl_commit; mod impl_config_keys; @@ -51,8 +44,11 @@ pub struct SessionBuilder { impl Clone for SessionBuilder { fn clone(&self) -> Self { let mut session_options_ptr = ptr::null_mut(); - ortsys![unsafe CloneSessionOptions(self.ptr(), ptr::addr_of_mut!(session_options_ptr)).expect("error cloning session options")]; - assert_non_null_pointer(session_options_ptr, "OrtSessionOptions").expect("Cloned session option pointer is null"); + ortsys![ + unsafe CloneSessionOptions(self.ptr(), ptr::addr_of_mut!(session_options_ptr)) + .expect("error cloning session options"); + nonNull(session_options_ptr) + ]; Self { session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) }, memory_info: self.memory_info.clone(), diff --git a/src/session/input.rs b/src/session/input.rs index 0ff723e3..87cffac5 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -1,7 +1,7 @@ -use alloc::{borrow::Cow, vec::Vec}; +use alloc::{borrow::Cow, sync::Arc, vec::Vec}; use core::ops::Deref; -use crate::value::{DynValueTypeMarker, Value, ValueRef, ValueRefMut, ValueTypeMarker}; +use crate::value::{DynValueTypeMarker, Value, ValueInner, ValueRef, ValueRefMut, ValueTypeMarker}; pub enum SessionInputValue<'v> { ViewMut(ValueRefMut<'v, DynValueTypeMarker>), @@ -9,6 +9,16 @@ pub enum SessionInputValue<'v> { Owned(Value) } +impl SessionInputValue<'_> { + pub(crate) fn inner(&self) -> &Arc { + match self { + Self::ViewMut(v) => v.inner(), + Self::View(v) => v.inner(), + Self::Owned(v) => v.inner() + } + } +} + impl Deref for SessionInputValue<'_> { type Target = Value; diff --git a/src/session/mod.rs b/src/session/mod.rs index 3a78dd7b..99737813 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -23,7 +23,7 @@ use core::{ use crate::{ AsPointer, char_p_to_string, - error::{Error, ErrorCode, Result, assert_non_null_pointer, status_to_result}, + error::{Error, ErrorCode, Result, status_to_result}, io_binding::IoBinding, memory::Allocator, metadata::ModelMetadata, @@ -40,8 +40,8 @@ pub mod run_options; #[cfg(feature = "std")] pub use self::r#async::InferenceFut; #[cfg(feature = "std")] -use self::r#async::{AsyncInferenceContext, InferenceFutInner, RunOptionsRef}; -use self::builder::SessionBuilder; +use self::r#async::{AsyncInferenceContext, InferenceFutInner}; +use self::{builder::SessionBuilder, run_options::UntypedRunOptions}; pub use self::{ input::{SessionInputValue, 
SessionInputs}, output::SessionOutputs, @@ -202,17 +202,11 @@ impl Session { /// ``` pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s mut self, input_values: impl Into>) -> Result> { match input_values.into() { - SessionInputs::ValueSlice(input_values) => { - self.run_inner::(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) - } - SessionInputs::ValueArray(input_values) => { - self.run_inner::(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) + SessionInputs::ValueSlice(input_values) => self.run_inner(&mut self.inputs.iter().map(|input| input.name.as_str()), &mut input_values.iter(), None), + SessionInputs::ValueArray(input_values) => self.run_inner(&mut self.inputs.iter().map(|input| input.name.as_str()), &mut input_values.iter(), None), + SessionInputs::ValueMap(input_values) => { + self.run_inner(&mut input_values.iter().map(|(k, _)| k.as_ref()), &mut input_values.iter().map(|(_, v)| v), None) } - SessionInputs::ValueMap(input_values) => self.run_inner::( - &input_values.iter().map(|(k, _)| k.as_ref()).collect::>(), - input_values.iter().map(|(_, v)| v), - None - ) } } @@ -249,28 +243,40 @@ impl Session { ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { - self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), Some(run_options)) + self.run_inner(&mut self.inputs.iter().map(|input| input.name.as_str()), &mut input_values.iter(), Some(&run_options.inner)) } SessionInputs::ValueArray(input_values) => { - self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), Some(run_options)) + self.run_inner(&mut self.inputs.iter().map(|input| input.name.as_str()), &mut input_values.iter(), Some(&run_options.inner)) } SessionInputs::ValueMap(input_values) => { - self.run_inner(&input_values.iter().map(|(k, _)| k.as_ref()).collect::>(), input_values.iter().map(|(_, v)| v), Some(run_options)) + self.run_inner(&mut input_values.iter().map(|(k, _)| k.as_ref()), &mut input_values.iter().map(|(_, v)| v), Some(&run_options.inner)) } } } - fn run_inner<'i, 'r, 's: 'r, 'v: 'i, O: SelectedOutputMarker>( + fn run_inner<'i, 'r, 's: 'r, 'v: 'i>( &'s self, - input_names: &[&str], - input_values: impl Iterator>, - run_options: Option<&'r RunOptions> + input_names: &mut dyn ExactIterator<&str>, + input_values: &mut dyn ExactIterator<&'i SessionInputValue<'v>>, + run_options: Option<&'r UntypedRunOptions> ) -> Result> { - let input_names_ptr: Vec<*const c_char> = input_names - .iter() - .map(|n| CString::new(n.as_bytes()).unwrap_or_else(|_| unreachable!())) - .map(|n| n.into_raw().cast_const()) - .collect(); + if input_values.len() > input_names.len() { + // If we provide more inputs than the model expects with `ort::inputs![a, b, c]`, then we get an `input_names` shorter + // than `inputs`. ONNX Runtime will attempt to look up the name of all inputs before doing any checks, thus going out of + // bounds of `input_names` and triggering a segfault, so we check that condition here. This will never trip for + // `ValueMap` inputs since the number of names & values are always equal as its a vec of tuples. 
+ return Err(Error::new_with_code( + ErrorCode::InvalidArgument, + format!("{} inputs were provided, but the model only accepts {}.", input_values.len(), input_names.len()) + )); + } + + let mut input_name_ptrs = Vec::with_capacity(input_names.len()); + #[allow(clippy::while_let_on_iterator)] + while let Some(name) = input_names.next() { + let name = CString::new(name.as_bytes())?; + input_name_ptrs.push(name.into_raw().cast_const()); + } let (output_names, mut output_tensors) = match run_options { Some(r) => r.outputs.resolve_outputs(&self.outputs), @@ -289,26 +295,19 @@ impl Session { }) .collect(); - // The C API expects pointers for the arrays (pointers to C-arrays) - let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr()).collect(); - if input_ort_values.len() > input_names.len() { - // If we provide more inputs than the model expects with `ort::inputs![a, b, c]`, then we get an `input_names` shorter - // than `inputs`. ONNX Runtime will attempt to look up the name of all inputs before doing any checks, thus going out of - // bounds of `input_names` and triggering a segfault, so we check that condition here. This will never trip for - // `ValueMap` inputs since the number of names & values are always equal as its a vec of tuples. - return Err(Error::new_with_code( - ErrorCode::InvalidArgument, - format!("{} inputs were provided, but the model only accepts {}.", input_ort_values.len(), input_names.len()) - )); + let mut input_ort_values = Vec::with_capacity(input_values.len()); + #[allow(clippy::while_let_on_iterator)] + while let Some(input) = input_values.next() { + input_ort_values.push(input.ptr()); } - let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr() } else { ptr::null() }; + let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr.as_ptr() } else { ptr::null() }; ortsys![ unsafe Run( self.inner.session_ptr.as_ptr(), run_options_ptr, - input_names_ptr.as_ptr(), + input_name_ptrs.as_ptr(), input_ort_values.as_ptr(), input_ort_values.len(), output_names_ptr.as_ptr(), @@ -332,7 +331,7 @@ impl Session { .collect(); // Reconvert name ptrs to CString so drop impl is called and memory is freed - for p in input_names_ptr.into_iter().chain(output_names_ptr.into_iter()) { + for p in input_name_ptrs.into_iter().chain(output_names_ptr.into_iter()) { drop(unsafe { CString::from_raw(p.cast_mut().cast()) }); } @@ -359,8 +358,6 @@ impl Session { let run_options_ptr = if let Some(run_options) = run_options { run_options.ptr() } else { ptr::null() }; ortsys![unsafe RunWithBinding(self.inner.ptr().cast_mut(), run_options_ptr, binding.ptr())?]; - // let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Value> = self.output_values.values().map(|c| (c.ptr().cast_mut(), - // c)).collect(); let mut count = binding.output_values.len(); if count > 0 { let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut(); @@ -406,94 +403,75 @@ impl Session { /// # fn main() -> ort::Result<()> { tokio_test::block_on(async { /// let mut session = Session::builder()?.with_intra_threads(2)?.commit_from_file("tests/data/upsample.onnx")?; /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); - /// let outputs = session.run_async(ort::inputs![TensorRef::from_array_view(&input)?])?.await?; + /// let options = RunOptions::new()?; + /// let outputs = session.run_async(ort::inputs![TensorRef::from_array_view(&input)?], &options)?.await?; /// # Ok(()) /// # }) } /// ``` #[cfg(feature = "std")] 
#[cfg_attr(docsrs, doc(cfg(feature = "std")))] - pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>( - &'s mut self, - input_values: impl Into> - ) -> Result> { - match input_values.into() { - SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), - SessionInputs::ValueArray(input_values) => { - self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::>(), input_values.into_iter(), None) - } - SessionInputs::ValueMap(input_values) => { - self.run_inner_async(&input_values.iter().map(|(k, _)| k.to_string()).collect::>(), input_values.into_iter().map(|(_, v)| v), None) - } - } - } - - /// Asynchronously run input data through the ONNX graph, performing inference, with the given [`RunOptions`]. - /// See [`Session::run_with_options`] and [`Session::run_async`] for more details. - #[cfg(feature = "std")] - #[cfg_attr(docsrs, doc(cfg(feature = "std")))] - pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, O: SelectedOutputMarker, const N: usize>( + pub fn run_async<'r, 's: 'r, 'i, 'v: 'i + 's, O: SelectedOutputMarker, const N: usize>( &'s mut self, input_values: impl Into>, run_options: &'r RunOptions - ) -> Result> { + ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), SessionInputs::ValueArray(input_values) => { - self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::>(), input_values.into_iter(), Some(run_options)) + self.run_inner_async(&mut self.inputs.iter().map(|input| input.name.as_str()), &mut input_values.iter(), &run_options.inner) + } + SessionInputs::ValueMap(input_values) => { + self.run_inner_async(&mut input_values.iter().map(|(k, _)| k.as_ref()), &mut input_values.iter().map(|(_, v)| v), &run_options.inner) } - SessionInputs::ValueMap(input_values) => self.run_inner_async( - &input_values.iter().map(|(k, _)| k.to_string()).collect::>(), - input_values.into_iter().map(|(_, v)| v), - Some(run_options) - ) } } #[cfg(feature = "std")] - fn run_inner_async<'s, 'v: 's, 'r, O: SelectedOutputMarker>( + fn run_inner_async<'r, 's: 'r, 'v: 's>( &'s self, - input_names: &[String], - input_values: impl Iterator>, - run_options: Option<&'r RunOptions> - ) -> Result> { - let run_options = match run_options { - Some(r) => RunOptionsRef::Ref(r), - // create a `RunOptions` to pass to the future so that when it drops, it terminates inference - crucial - // (performance-wise) for routines involving `tokio::select!` or timeouts - None => RunOptionsRef::Arc(Arc::new(unsafe { - // SAFETY: transmuting from `RunOptions` to `RunOptions`; safe because its just a marker - core::mem::transmute::, RunOptions>(RunOptions::new()?) 
- })) - }; + input_names: &mut dyn ExactIterator<&str>, + input_values: &mut dyn ExactIterator<&SessionInputValue<'v>>, + run_options: &'r UntypedRunOptions + ) -> Result> { + let mut input_name_ptrs = Vec::with_capacity(input_names.len()); + #[allow(clippy::while_let_on_iterator)] + while let Some(name) = input_names.next() { + let name = CString::new(name.as_bytes())?; + input_name_ptrs.push(name.into_raw().cast_const()); + } + let mut input_inner_holders = Vec::with_capacity(input_values.len()); + let mut input_ort_values = Vec::with_capacity(input_values.len()); + #[allow(clippy::while_let_on_iterator)] + while let Some(input) = input_values.next() { + input_ort_values.push(input.ptr()); + input_inner_holders.push(Arc::clone(input.inner())); + } - let input_name_ptrs: Vec<*const c_char> = input_names + let (output_names, mut output_tensors) = run_options.outputs.resolve_outputs(&self.outputs); + let output_name_ptrs: Vec<*const c_char> = output_names .iter() - .map(|n| CString::new(n.as_bytes()).unwrap_or_else(|_| unreachable!())) + .map(|n| CString::new(*n).unwrap_or_else(|_| unreachable!())) .map(|n| n.into_raw().cast_const()) .collect(); - let output_name_ptrs: Vec<*const c_char> = self - .outputs - .iter() - .map(|output| CString::new(output.name.as_str()).unwrap_or_else(|_| unreachable!())) - .map(|n| n.into_raw().cast_const()) + let output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = output_tensors + .iter_mut() + .map(|c| match c { + Some(v) => v.ptr_mut(), + None => ptr::null_mut() + }) .collect(); - let output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![ptr::null_mut(); self.outputs.len()]; - - let input_values: Vec<_> = input_values.collect(); - let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.iter().map(|input_array_ort| input_array_ort.ptr()).collect(); - let async_inner = Arc::new(InferenceFutInner::new()); let ctx = Box::leak(Box::new(AsyncInferenceContext { inner: Arc::clone(&async_inner), - _input_values: input_values, // everything allocated within `run_inner_async` needs to be kept alive until we are certain inference has completed and ONNX Runtime no longer // needs the data - i.e. when `async_callback` is called. 
`async_callback` will free all of this data just like we do in `run_inner` input_ort_values, + _input_inner_holders: input_inner_holders, input_name_ptrs, output_name_ptrs, - output_names: self.outputs.iter().map(|o| o.name.as_str()).collect::>(), + output_names, output_value_ptrs: output_tensor_ptrs, session_inner: &self.inner })); @@ -501,7 +479,7 @@ impl Session { ortsys![ unsafe RunAsync( self.inner.session_ptr.as_ptr(), - run_options.ptr(), + run_options.ptr.as_ptr(), ctx.input_name_ptrs.as_ptr(), ctx.input_ort_values.as_ptr(), ctx.input_ort_values.len(), @@ -536,8 +514,8 @@ impl Session { pub fn end_profiling(&mut self) -> Result { let mut profiling_name: *mut c_char = ptr::null_mut(); - ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), self.inner.allocator.ptr().cast_mut(), &mut profiling_name)?]; - assert_non_null_pointer(profiling_name, "ProfilingName")?; + ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), self.inner.allocator.ptr().cast_mut(), &mut profiling_name)?; nonNull(profiling_name)]; + dangerous::raw_pointer_to_string(&self.inner.allocator, profiling_name) } @@ -570,6 +548,10 @@ impl Session { } } +trait ExactIterator: Iterator + ExactSizeIterator {} + +impl ExactIterator for I where I: Iterator + ExactSizeIterator {} + /// Workload type, used to signal to execution providers whether to prioritize performance or efficiency. /// /// See [`Session::set_workload_type`]. @@ -663,13 +645,16 @@ mod dangerous { allocator: &Allocator, i: usize ) -> Result { - let mut name_bytes: *mut c_char = ptr::null_mut(); + let mut name_ptr: *mut c_char = ptr::null_mut(); - let status = unsafe { f(session_ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes) }; + let status = unsafe { f(session_ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_ptr) }; unsafe { status_to_result(status) }?; - assert_non_null_pointer(name_bytes, "InputName")?; + if name_ptr.is_null() { + crate::util::cold(); + return Err(crate::Error::new(concat!("expected `name_ptr` to not be null"))); + } - raw_pointer_to_string(allocator, name_bytes) + raw_pointer_to_string(allocator, name_ptr) } pub(super) fn extract_input(session_ptr: NonNull, allocator: &Allocator, i: usize) -> Result { @@ -695,7 +680,10 @@ mod dangerous { let status = unsafe { f(session_ptr.as_ptr(), i, &mut typeinfo_ptr) }; unsafe { status_to_result(status) }?; - assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?; + if typeinfo_ptr.is_null() { + crate::util::cold(); + return Err(crate::Error::new(concat!("expected `typeinfo_ptr` to not be null"))); + } Ok(ValueType::from_type_info(typeinfo_ptr)) } diff --git a/src/session/run_options.rs b/src/session/run_options.rs index ca722487..34ae2b46 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -137,6 +137,23 @@ impl SelectedOutputMarker for NoSelectedOutputs {} pub struct HasSelectedOutputs; impl SelectedOutputMarker for HasSelectedOutputs {} +#[derive(Debug)] +pub(crate) struct UntypedRunOptions { + pub(crate) ptr: NonNull, + pub(crate) outputs: OutputSelector, + adapters: Vec> +} + +impl UntypedRunOptions { + pub fn terminate(&self) -> Result<()> { + ortsys![unsafe RunOptionsSetTerminate(self.ptr.as_ptr())?]; + Ok(()) + } +} + +// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 +unsafe impl Send for UntypedRunOptions {} + /// Allows for finer control over session inference. 
/// /// [`RunOptions`] provides three main features: @@ -163,14 +180,10 @@ impl SelectedOutputMarker for HasSelectedOutputs {} /// [`IoBinding::run_with_options`]: crate::io_binding::IoBinding::run_with_options #[derive(Debug)] pub struct RunOptions { - run_options_ptr: NonNull, - pub(crate) outputs: OutputSelector, - adapters: Vec>, + pub(crate) inner: UntypedRunOptions, _marker: PhantomData } -// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 -unsafe impl Send for RunOptions {} // Only allow `Sync` if we don't have (potentially pre-allocated) outputs selected. // Allowing `Sync` here would mean a single pre-allocated `Value` could be mutated simultaneously in different threads - // a brazen crime against crabkind. @@ -182,9 +195,11 @@ impl RunOptions { let mut run_options_ptr: *mut ort_sys::OrtRunOptions = ptr::null_mut(); ortsys![unsafe CreateRunOptions(&mut run_options_ptr)?; nonNull(run_options_ptr)]; Ok(RunOptions { - run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) }, - outputs: OutputSelector::default(), - adapters: Vec::new(), + inner: UntypedRunOptions { + ptr: unsafe { NonNull::new_unchecked(run_options_ptr) }, + outputs: OutputSelector::default(), + adapters: Vec::new() + }, _marker: PhantomData }) } @@ -218,7 +233,7 @@ impl RunOptions { /// # } /// ``` pub fn with_outputs(mut self, outputs: OutputSelector) -> RunOptions { - self.outputs = outputs; + self.inner.outputs = outputs; unsafe { mem::transmute(self) } } @@ -230,13 +245,13 @@ impl RunOptions { /// Sets a tag to identify this run in logs. pub fn set_tag(&mut self, tag: impl AsRef) -> Result<()> { let tag = CString::new(tag.as_ref())?; - ortsys![unsafe RunOptionsSetRunTag(self.run_options_ptr.as_ptr(), tag.as_ptr())?]; + ortsys![unsafe RunOptionsSetRunTag(self.inner.ptr.as_ptr(), tag.as_ptr())?]; Ok(()) } pub fn tag(&self) -> Result { let mut tag_ptr: *const c_char = ptr::null(); - ortsys![unsafe RunOptionsGetRunTag(self.run_options_ptr.as_ptr(), &mut tag_ptr)?]; + ortsys![unsafe RunOptionsGetRunTag(self.inner.ptr.as_ptr(), &mut tag_ptr)?]; if tag_ptr.is_null() { Ok(String::default()) } else { @@ -273,8 +288,7 @@ impl RunOptions { /// # } /// ``` pub fn terminate(&self) -> Result<()> { - ortsys![unsafe RunOptionsSetTerminate(self.run_options_ptr.as_ptr())?]; - Ok(()) + self.inner.terminate() } /// Resets the termination flag for the runs associated with [`RunOptions`]. 
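Since `run_async` now borrows a caller-supplied `RunOptions` instead of creating one internally, the cancel-on-drop behaviour flows through `UntypedRunOptions::terminate`. A minimal sketch of the `tokio::select!` timeout pattern the removed `run_inner_async` comment alluded to (function name and timeout are illustrative; assumes a tokio runtime and the `ndarray` feature):

```rust
use std::time::Duration;

use ort::{
	session::{RunOptions, Session},
	value::TensorRef
};

async fn infer_with_timeout(session: &mut Session, input: &ndarray::Array4<f32>) -> ort::Result<()> {
	let options = RunOptions::new()?;
	let fut = session.run_async(ort::inputs![TensorRef::from_array_view(input)?], &options)?;
	tokio::select! {
		outputs = fut => {
			let _outputs = outputs?;
		}
		// If the sleep wins, `fut` is dropped with `did_receive == false`;
		// `InferenceFut::drop` then calls `run_options.terminate()` so the
		// in-flight run is cancelled rather than left running in the background.
		_ = tokio::time::sleep(Duration::from_secs(5)) => {}
	}
	Ok(())
}
```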
@@ -300,7 +314,7 @@ impl RunOptions { /// # } /// ``` pub fn unterminate(&self) -> Result<()> { - ortsys![unsafe RunOptionsUnsetTerminate(self.run_options_ptr.as_ptr())?]; + ortsys![unsafe RunOptionsUnsetTerminate(self.inner.ptr.as_ptr())?]; Ok(()) } @@ -320,13 +334,13 @@ impl RunOptions { pub fn add_config_entry(&mut self, key: impl AsRef, value: impl AsRef) -> Result<()> { let key = CString::new(key.as_ref())?; let value = CString::new(value.as_ref())?; - ortsys![unsafe AddRunConfigEntry(self.run_options_ptr.as_ptr(), key.as_ptr(), value.as_ptr())?]; + ortsys![unsafe AddRunConfigEntry(self.inner.ptr.as_ptr(), key.as_ptr(), value.as_ptr())?]; Ok(()) } pub fn add_adapter(&mut self, adapter: &Adapter) -> Result<()> { - ortsys![unsafe RunOptionsAddActiveLoraAdapter(self.run_options_ptr.as_ptr(), adapter.ptr())?]; - self.adapters.push(Arc::clone(&adapter.inner)); + ortsys![unsafe RunOptionsAddActiveLoraAdapter(self.inner.ptr.as_ptr(), adapter.ptr())?]; + self.inner.adapters.push(Arc::clone(&adapter.inner)); Ok(()) } } @@ -335,12 +349,12 @@ impl AsPointer for RunOptions { type Sys = ort_sys::OrtRunOptions; fn ptr(&self) -> *const Self::Sys { - self.run_options_ptr.as_ptr() + self.inner.ptr.as_ptr() } } impl Drop for RunOptions { fn drop(&mut self) { - ortsys![unsafe ReleaseRunOptions(self.run_options_ptr.as_ptr())]; + ortsys![unsafe ReleaseRunOptions(self.inner.ptr.as_ptr())]; } } diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index fcf2a0e9..e179f51a 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -7,5 +7,3 @@ mod types; #[cfg(feature = "ndarray")] pub use self::ndarray::ArrayExtensions; pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; -#[cfg(feature = "ndarray")] -pub(crate) use self::types::{extract_primitive_array, extract_primitive_array_mut}; diff --git a/src/tensor/types.rs b/src/tensor/types.rs index 73b8e6bc..f562a105 100644 --- a/src/tensor/types.rs +++ b/src/tensor/types.rs @@ -1,10 +1,5 @@ use alloc::string::String; use core::fmt; -#[cfg(feature = "ndarray")] -use core::{ffi::c_void, ptr}; - -#[cfg(feature = "ndarray")] -use crate::{error::Result, ortsys}; /// Enum mapping ONNX Runtime's supported tensor data types. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -239,35 +234,3 @@ impl Utf8Data for &str { self.as_bytes() } } - -/// Construct an [`ndarray::ArrayView`] for an ORT tensor. -/// -/// Only to be used on types whose Rust in-memory representation matches ONNX Runtime's (e.g. primitive numeric types -/// like u32) -#[cfg(feature = "ndarray")] -pub(crate) fn extract_primitive_array<'t, T>(shape: ndarray::IxDyn, tensor: *const ort_sys::OrtValue) -> Result> { - // Get pointer to output tensor values - let mut output_array_ptr: *mut T = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(tensor.cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - - let array_view = unsafe { ndarray::ArrayView::from_shape_ptr(shape, output_array_ptr) }; - Ok(array_view) -} - -/// Construct an [`ndarray::ArrayViewMut`] for an ORT tensor. -/// -/// Only to be used on types whose Rust in-memory representation matches ONNX Runtime's (e.g. 
primitive numeric types -/// like u32) -#[cfg(feature = "ndarray")] -pub(crate) fn extract_primitive_array_mut<'t, T>(shape: ndarray::IxDyn, tensor: *mut ort_sys::OrtValue) -> Result> { - // Get pointer to output tensor values - let mut output_array_ptr: *mut T = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(tensor, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - - let array_view = unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape, output_array_ptr) }; - Ok(array_view) -} diff --git a/src/training/mod.rs b/src/training/mod.rs index 4eccca5f..5505e09b 100644 --- a/src/training/mod.rs +++ b/src/training/mod.rs @@ -74,7 +74,13 @@ macro_rules! trainsys { }; (unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{ let _x = unsafe { ($crate::training::training_api().unwrap().$method)($($n),+) }; - $($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+ + $( + // TODO: #[cfg(debug_assertions)]? + if ($check).is_null() { + $crate::util::cold(); + panic!(concat!("expected `", stringify!($check), "` to not be null")); + } + )+ _x }}; (unsafe $method:ident($($n:expr),+ $(,)?)?) => { @@ -82,7 +88,13 @@ macro_rules! trainsys { }; (unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {{ unsafe { $crate::error::status_to_result(($crate::training::training_api()?.$method)($($n),+)) }?; - $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + $( + // TODO: #[cfg(debug_assertions)]? + if ($check).is_null() { + $crate::util::cold(); + return Err($crate::Error::new(concat!("expected `", stringify!($check), "` to not be null"))); + } + )+ }}; } pub(crate) use trainsys; diff --git a/src/training/trainer.rs b/src/training/trainer.rs index 2ee37ba1..06e425b6 100644 --- a/src/training/trainer.rs +++ b/src/training/trainer.rs @@ -10,7 +10,7 @@ use ort_sys::c_char; use super::{Checkpoint, Optimizer, trainsys}; use crate::{ AsPointer, char_p_to_string, - error::{Result, assert_non_null_pointer, status_to_result}, + error::{Result, status_to_result}, memory::Allocator, session::{RunOptions, SessionInputValue, SessionInputs, SessionOutputs, builder::SessionBuilder}, tensor::IntoTensorElementType, @@ -263,11 +263,9 @@ impl Trainer { drop( output_names_ptr .into_iter() - .map(|p| { - assert_non_null_pointer(p, "c_char for CString")?; - unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } - }) - .collect::>>()? + // SAFETY: `str` will never have a null pointer + .map(|p| unsafe { CString::from_raw(p as *mut _) }) + .collect::>() ); unsafe { status_to_result(res) }?; diff --git a/src/util.rs b/src/util.rs index a4da59c4..cfcad522 100644 --- a/src/util.rs +++ b/src/util.rs @@ -15,7 +15,7 @@ type OsCharArray = Vec; type OsCharArray = Vec; #[cfg(feature = "std")] -pub fn path_to_os_char(path: impl AsRef) -> OsCharArray { +pub(crate) fn path_to_os_char(path: impl AsRef) -> OsCharArray { #[cfg(not(target_family = "windows"))] use core::ffi::c_char; #[cfg(unix)] @@ -38,7 +38,7 @@ pub fn path_to_os_char(path: impl AsRef) -> OsCharArray { // generally as performant or faster than HashMap for <50 items. 
good enough for #[no_std] #[derive(Clone, PartialEq, Eq)] -pub struct MiniMap { +pub(crate) struct MiniMap { values: Vec<(K, V)> } @@ -100,7 +100,7 @@ impl fmt::Debug for MiniMap { } } -pub struct OnceLock { +pub(crate) struct OnceLock { data: UnsafeCell>, #[cfg(not(feature = "std"))] status: core::sync::atomic::AtomicU8, @@ -275,3 +275,19 @@ impl Drop for OnceLock { } } } + +#[cold] +#[inline] +#[doc(hidden)] +pub fn cold() {} + +pub fn element_count(shape: &[i64]) -> usize { + let mut size = 1usize; + for dim in shape { + if *dim < 0 { + return 0; + } + size *= *dim as usize; + } + size +} diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index e0a3d8eb..1dccab2e 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -1,14 +1,7 @@ -use alloc::{ - boxed::Box, - format, - string::{String, ToString}, - sync::Arc, - vec, - vec::Vec -}; +use alloc::{boxed::Box, format, string::String, sync::Arc, vec, vec::Vec}; use core::{ ffi::c_void, - fmt::Debug, + fmt::{self, Debug}, hash::Hash, marker::PhantomData, mem, @@ -20,14 +13,15 @@ use std::collections::HashMap; use super::{ DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker, - impl_tensor::{DynTensor, Tensor, calculate_tensor_size} + impl_tensor::{DynTensor, Tensor} }; use crate::{ AsPointer, ErrorCode, error::{Error, Result}, memory::Allocator, ortsys, - tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType} + tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType}, + util::element_count }; pub trait MapValueTypeMarker: ValueTypeMarker { @@ -37,8 +31,8 @@ pub trait MapValueTypeMarker: ValueTypeMarker { #[derive(Debug)] pub struct DynMapValueType; impl ValueTypeMarker for DynMapValueType { - fn format() -> String { - "DynMap".to_string() + fn fmt(f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("DynMap") } private_impl!(); @@ -58,8 +52,12 @@ impl DowncastableTarget for DynMapValueType { #[derive(Debug)] pub struct MapValueType(PhantomData<(K, V)>); impl ValueTypeMarker for MapValueType { - fn format() -> String { - format!("Map<{}, {}>", K::into_tensor_element_type(), V::into_tensor_element_type()) + fn fmt(f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("Map<")?; + ::fmt(&K::into_tensor_element_type(), f)?; + f.write_str(", ")?; + ::fmt(&V::into_tensor_element_type(), f)?; + f.write_str(">") } private_impl!(); @@ -126,7 +124,7 @@ impl Value { let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast(); ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - let len = calculate_tensor_size(dimensions); + let len = element_count(dimensions); (dimensions, unsafe { slice::from_raw_parts(output_array_ptr, len) }) } else { return Err(Error::new_with_code( @@ -299,7 +297,7 @@ impl`] to a type-erased [`DynMap`]. #[inline] pub fn upcast(self) -> DynMap { - unsafe { mem::transmute(self) } + unsafe { self.transmute_type() } } /// Converts from a strongly-typed [`Map`] to a reference to a type-erased [`DynMap`]. 
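For reference, the behaviour of the new `util::element_count` that replaces `calculate_tensor_size` throughout these value hunks, as implied by its body above (values illustrative):

```rust
// Dimensions multiply out; any negative (symbolic) dimension collapses the
// count to 0 so no slice is ever constructed over an unknown extent.
assert_eq!(element_count(&[1, 64, 64, 3]), 12_288);
assert_eq!(element_count(&[2, -1, 4]), 0); // symbolic dimension
assert_eq!(element_count(&[]), 1); // a rank-0 (scalar) tensor holds one element
```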
diff --git a/src/value/impl_sequence.rs b/src/value/impl_sequence.rs index badd1a0c..03a7651a 100644 --- a/src/value/impl_sequence.rs +++ b/src/value/impl_sequence.rs @@ -1,18 +1,11 @@ -use alloc::{ - boxed::Box, - format, - string::{String, ToString}, - sync::Arc, - vec::Vec -}; +use alloc::{boxed::Box, format, sync::Arc, vec::Vec}; use core::{ - fmt::Debug, + fmt::{self, Debug, Display}, marker::PhantomData, - mem, ptr::{self, NonNull} }; -use super::{DowncastableTarget, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker}; +use super::{DowncastableTarget, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker, format_value_type}; use crate::{ AsPointer, ErrorCode, error::{Error, Result}, @@ -27,8 +20,8 @@ pub trait SequenceValueTypeMarker: ValueTypeMarker { #[derive(Debug)] pub struct DynSequenceValueType; impl ValueTypeMarker for DynSequenceValueType { - fn format() -> String { - "DynSequence".to_string() + fn fmt(f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("DynSequence") } private_impl!(); @@ -48,8 +41,10 @@ impl DowncastableTarget for DynSequenceValueType { #[derive(Debug)] pub struct SequenceValueType(PhantomData); impl ValueTypeMarker for SequenceValueType { - fn format() -> String { - format!("Sequence<{}>", T::format()) + fn fmt(f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("Sequence<")?; + format_value_type::().fmt(f)?; + f.write_str(">") } private_impl!(); @@ -99,7 +94,7 @@ impl Value { if !OtherType::can_downcast(value.dtype()) { return Err(Error::new_with_code( ErrorCode::InvalidArgument, - format!("Cannot extract Sequence<{}> from {value_type:?}", OtherType::format()) + format!("Cannot extract Sequence<{}> from {value_type:?}", format_value_type::()) )); } @@ -107,7 +102,7 @@ impl Value { } Ok(vec) } - t => Err(Error::new(format!("Cannot extract Sequence<{}> from {t}", OtherType::format()))) + t => Err(Error::new(format!("Cannot extract Sequence<{}> from {t}", format_value_type::()))) } } } @@ -165,7 +160,7 @@ impl Value`] to a type-erased [`DynSequence`]. #[inline] pub fn upcast(self) -> DynSequence { - unsafe { mem::transmute(self) } + unsafe { self.transmute_type() } } /// Converts from a strongly-typed [`Sequence`] to a reference to a type-erased [`DynSequence`]. diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index 3af448e7..67c4b146 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -11,13 +11,14 @@ use core::{ #[cfg(feature = "ndarray")] use ndarray::{ArcArray, Array, ArrayView, ArrayViewMut, CowArray, Dimension}; -use super::{DynTensor, Tensor, TensorRef, TensorRefMut, calculate_tensor_size}; +use super::{DynTensor, Tensor, TensorRef, TensorRefMut}; use crate::{ AsPointer, - error::{Error, ErrorCode, Result, assert_non_null_pointer}, - memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType}, + error::{Error, ErrorCode, Result}, + memory::{Allocator, MemoryInfo}, ortsys, tensor::{PrimitiveTensorElementType, TensorElementType, Utf8Data}, + util::element_count, value::{Value, ValueInner, ValueType} }; @@ -110,7 +111,7 @@ impl Tensor { /// ``` pub fn new(allocator: &Allocator, shape: impl ToDimensions) -> Result> { let tensor = DynTensor::new(allocator, T::into_tensor_element_type(), shape)?; - Ok(unsafe { core::mem::transmute::>(tensor) }) + Ok(unsafe { tensor.transmute_type() }) } /// Construct an owned tensor from an array of data. 
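The `impl_sequence.rs` hunk above imports a `format_value_type` helper that this diff does not define. A hypothetical reconstruction, assuming it is simply a `Display` adapter over the marker's new allocation-free `fmt`:

```rust
use core::{fmt, marker::PhantomData};

// Hypothetical sketch -- the real helper may differ. The point is that
// `ValueTypeMarker::fmt` can be surfaced through `Display` without the
// intermediate `String` the old `format()` method allocated.
pub(crate) fn format_value_type<T: ValueTypeMarker>() -> impl fmt::Display {
	struct Adapter<T>(PhantomData<T>);
	impl<T: ValueTypeMarker> fmt::Display for Adapter<T> {
		fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
			T::fmt(f)
		}
	}
	Adapter::<T>(PhantomData)
}
```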
@@ -140,48 +141,52 @@ impl Tensor { /// /// Creating string tensors requires a separate method; see [`Tensor::from_string_array`]. pub fn from_array(input: impl OwnedTensorArrayData) -> Result> { - let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?; - - let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); - - // f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime let TensorArrayDataParts { shape, ptr, num_elements, guard } = input.into_parts()?; - let shape_ptr: *const i64 = shape.as_ptr(); - let shape_len = shape.len(); - - let tensor_values_ptr: *mut c_void = ptr.cast(); - assert_non_null_pointer(tensor_values_ptr, "TensorValues")?; - - ortsys![ - unsafe CreateTensorWithDataAsOrtValue( - memory_info.ptr(), - tensor_values_ptr, - num_elements * size_of::(), - shape_ptr, - shape_len, - T::into_tensor_element_type().into(), - &mut value_ptr - )?; - nonNull(value_ptr) - ]; - - Ok(Value { - inner: Arc::new(ValueInner { - ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - dtype: ValueType::Tensor { - ty: T::into_tensor_element_type(), - dimensions: shape, - dimension_symbols: vec![None; shape_len] - }, - drop: true, - memory_info: Some(memory_info), - _backing: Some(guard) - }), - _markers: PhantomData - }) + tensor_from_array(MemoryInfo::default(), shape, ptr.as_ptr().cast(), num_elements, size_of::(), T::into_tensor_element_type(), guard) + .map(|tensor| unsafe { tensor.transmute_type() }) } } +fn tensor_from_array( + memory_info: MemoryInfo, + shape: Vec, + data: *mut c_void, + num_elements: usize, + element_size: usize, + element_type: TensorElementType, + guard: Option> +) -> Result { + let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); + + ortsys![ + unsafe CreateTensorWithDataAsOrtValue( + memory_info.ptr(), + data, + num_elements * element_size, + shape.as_ptr(), + shape.len(), + element_type.into(), + &mut value_ptr + )?; + nonNull(value_ptr) + ]; + + Ok(DynTensor { + inner: Arc::new(ValueInner { + ptr: unsafe { NonNull::new_unchecked(value_ptr) }, + dtype: ValueType::Tensor { + ty: element_type, + dimension_symbols: vec![None; shape.len()], + dimensions: shape + }, + drop: true, + memory_info: Some(memory_info), + _backing: guard + }), + _markers: PhantomData + }) +} + impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> { /// Construct a tensor from borrowed data. /// @@ -213,48 +218,16 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> { /// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be /// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout. 
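/// A minimal usage sketch in the shape-and-slice form used elsewhere in this
/// diff (the `f32` buffer is illustrative):
/// ```
/// # use ort::value::TensorRef;
/// # fn main() -> ort::Result<()> {
/// let data = vec![0.0_f32; 1 * 64 * 64 * 3];
/// let _view = TensorRef::from_array_view((vec![1_i64, 64, 64, 3], data.as_slice()))?;
/// # Ok(())
/// # }
/// ```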
pub fn from_array_view(input: impl TensorArrayData + 'a) -> Result> { - let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?; - - let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); - - // f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime let (shape, data, guard) = input.ref_parts()?; - let num_elements = calculate_tensor_size(&shape); - let shape_ptr: *const i64 = shape.as_ptr(); - let shape_len = shape.len(); - - let tensor_values_ptr: *mut c_void = data.as_ptr() as *mut _; - assert_non_null_pointer(tensor_values_ptr, "TensorValues")?; - - ortsys![ - unsafe CreateTensorWithDataAsOrtValue( - memory_info.ptr(), - tensor_values_ptr, - num_elements * size_of::(), - shape_ptr, - shape_len, - T::into_tensor_element_type().into(), - &mut value_ptr - )?; - nonNull(value_ptr) - ]; + let num_elements = element_count(&shape); - let mut tensor = TensorRef::new(Tensor { - inner: Arc::new(ValueInner { - ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - dtype: ValueType::Tensor { - ty: T::into_tensor_element_type(), - dimensions: shape, - dimension_symbols: vec![None; shape_len] - }, - drop: true, - memory_info: Some(memory_info), - _backing: guard - }), - _markers: PhantomData - }); - tensor.upgradable = false; - Ok(tensor) + tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, num_elements, size_of::(), T::into_tensor_element_type(), guard).map( + |tensor| { + let mut tensor: TensorRef<'_, T> = TensorRef::new(unsafe { tensor.transmute_type() }); + tensor.upgradable = false; + tensor + } + ) } } @@ -288,48 +261,16 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { /// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be /// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout. 
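/// A corresponding sketch for the mutable variant; the view exclusively
/// borrows `data` until it is dropped (buffer contents illustrative):
/// ```
/// # use ort::value::TensorRefMut;
/// # fn main() -> ort::Result<()> {
/// let mut data = vec![0.0_f32; 16];
/// let view = TensorRefMut::from_array_view_mut((vec![4_i64, 4], data.as_mut_slice()))?;
/// drop(view); // release the exclusive borrow before touching `data` again
/// data[0] = 1.0;
/// # Ok(())
/// # }
/// ```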
pub fn from_array_view_mut(mut input: impl TensorArrayDataMut) -> Result> { - let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?; - - let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); - - // f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime let (shape, data, guard) = input.ref_parts_mut()?; - let num_elements = calculate_tensor_size(&shape); - let shape_ptr: *const i64 = shape.as_ptr(); - let shape_len = shape.len(); - - let tensor_values_ptr: *mut c_void = data.as_ptr() as *mut _; - assert_non_null_pointer(tensor_values_ptr, "TensorValues")?; + let num_elements = element_count(&shape); - ortsys![ - unsafe CreateTensorWithDataAsOrtValue( - memory_info.ptr(), - tensor_values_ptr, - num_elements * size_of::(), - shape_ptr, - shape_len, - T::into_tensor_element_type().into(), - &mut value_ptr - )?; - nonNull(value_ptr) - ]; - - let mut tensor = TensorRefMut::new(Tensor { - inner: Arc::new(ValueInner { - ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - dtype: ValueType::Tensor { - ty: T::into_tensor_element_type(), - dimensions: shape, - dimension_symbols: vec![None; shape_len] - }, - drop: true, - memory_info: Some(memory_info), - _backing: guard - }), - _markers: PhantomData - }); - tensor.upgradable = false; - Ok(tensor) + tensor_from_array(MemoryInfo::default(), shape, data.as_ptr() as *mut _, num_elements, size_of::(), T::into_tensor_element_type(), guard).map( + |tensor| { + let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() }); + tensor.upgradable = false; + tensor + } + ) } /// Create a mutable tensor view from a raw pointer and shape. @@ -356,43 +297,12 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { /// - The pointer must be valid for the device description provided by `MemoryInfo`. /// - The returned tensor must outlive the data described by the data pointer. pub unsafe fn from_raw(info: MemoryInfo, data: *mut ort_sys::c_void, shape: Vec) -> Result> { - let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); - - // f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime - let shape_ptr: *const i64 = shape.as_ptr(); - let shape_len = shape.len(); - - let data_len = calculate_tensor_size(&shape) * size_of::(); - - ortsys![ - unsafe CreateTensorWithDataAsOrtValue( - info.ptr(), - data, - data_len, - shape_ptr, - shape_len, - T::into_tensor_element_type().into(), - &mut value_ptr - )?; - nonNull(value_ptr) - ]; - - let mut tensor = TensorRefMut::new(Value { - inner: Arc::new(ValueInner { - ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - dtype: ValueType::Tensor { - ty: T::into_tensor_element_type(), - dimensions: shape, - dimension_symbols: vec![None; shape_len] - }, - drop: true, - memory_info: Some(info), - _backing: None - }), - _markers: PhantomData - }); - tensor.upgradable = false; - Ok(tensor) + let num_elements = element_count(&shape); + tensor_from_array(info, shape, data, num_elements, size_of::(), T::into_tensor_element_type(), None).map(|tensor| { + let mut tensor: TensorRefMut<'_, T> = TensorRefMut::new(unsafe { tensor.transmute_type() }); + tensor.upgradable = false; + tensor + }) } } @@ -418,9 +328,9 @@ pub trait OwnedTensorArrayData { pub struct TensorArrayDataParts { pub shape: Vec, - pub ptr: *mut I, + pub ptr: NonNull, pub num_elements: usize, - pub guard: Box + pub guard: Option> } pub trait ToDimensions { @@ -439,17 +349,17 @@ macro_rules! 
impl_to_dimensions { } else { Err(Error::new_with_code( ErrorCode::InvalidArgument, - format!("Invalid dimension at {}; all dimensions must be >= 1 when creating a tensor from raw data", i) + format!("Invalid dimension #{}; all dimensions must be >= 1 when creating a tensor from raw data", i + 1) )) } }) .collect::>()?; - let sum = calculate_tensor_size(&v); + let sum = element_count(&v); if let Some(expected_size) = expected_size { if sum != expected_size { Err(Error::new_with_code( ErrorCode::InvalidArgument, - format!("Cannot create a tensor from raw data; shape {:?} ({}) is larger than the length of the data provided ({})", v, sum, expected_size) + format!("Cannot create a tensor from raw data; shape {:?} ({} elements) is larger than the length of the data provided ({} elements)", v, sum, expected_size) )) } else { Ok(v) @@ -544,19 +454,30 @@ impl OwnedTensorArrayData for Arr fn into_parts(self) -> Result> { if self.is_standard_layout() { // We can avoid the copy here and use the data as is - let mut guard = Box::new(self); - let shape: Vec = guard.shape().iter().map(|d| *d as i64).collect(); - let ptr = guard.as_mut_ptr(); - let num_elements = guard.len(); - Ok(TensorArrayDataParts { shape, ptr, num_elements, guard }) + let mut this = Box::new(self); + let shape: Vec = this.shape().iter().map(|d| *d as i64).collect(); + // SAFETY: ndarrays internally store their pointer as NonNull + let ptr = unsafe { NonNull::new_unchecked(this.as_mut_ptr()) }; + let num_elements = this.len(); + Ok(TensorArrayDataParts { + shape, + ptr, + num_elements, + guard: Some(this) + }) } else { // Need to do a copy here to get data in to standard layout let mut contiguous_array = self.as_standard_layout().into_owned(); let shape: Vec = contiguous_array.shape().iter().map(|d| *d as i64).collect(); - let ptr = contiguous_array.as_mut_ptr(); + // SAFETY: ndarrays internally store their pointer as NonNull + let ptr = unsafe { NonNull::new_unchecked(contiguous_array.as_mut_ptr()) }; let num_elements: usize = contiguous_array.len(); - let guard = Box::new(contiguous_array); - Ok(TensorArrayDataParts { shape, ptr, num_elements, guard }) + Ok(TensorArrayDataParts { + shape, + ptr, + num_elements, + guard: Some(Box::new(contiguous_array)) + }) } } @@ -649,13 +570,14 @@ impl TensorArrayDataMut for (D, &mut [T] impl OwnedTensorArrayData for (D, Vec) { fn into_parts(mut self) -> Result> { let shape = self.0.to_dimensions(Some(self.1.len()))?; - let ptr = self.1.as_mut_ptr(); + // SAFETY: A `Vec` always has a non-null pointer. + let ptr = unsafe { NonNull::new_unchecked(self.1.as_mut_ptr()) }; let num_elements: usize = self.1.len(); Ok(TensorArrayDataParts { shape, ptr, num_elements, - guard: Box::new(self.1) + guard: Some(Box::new(self.1)) }) } @@ -665,13 +587,14 @@ impl OwnedTensorArrayData for (D, Vec impl OwnedTensorArrayData for (D, Box<[T]>) { fn into_parts(mut self) -> Result> { let shape = self.0.to_dimensions(Some(self.1.len()))?; - let ptr = self.1.as_mut_ptr(); + // SAFETY: A `Box` always has a non-null pointer. 
+ let ptr = unsafe { NonNull::new_unchecked(self.1.as_mut_ptr()) }; let num_elements: usize = self.1.len(); Ok(TensorArrayDataParts { shape, ptr, num_elements, - guard: Box::new(self.1) + guard: Some(Box::new(self.1)) }) } diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index 3b14ee4b..b08b8b55 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -9,14 +9,14 @@ use core::{ffi::c_void, fmt::Debug, ptr, slice}; #[cfg(feature = "ndarray")] use ndarray::IxDyn; -use super::{Tensor, TensorValueTypeMarker, calculate_tensor_size}; -#[cfg(feature = "ndarray")] -use crate::tensor::{extract_primitive_array, extract_primitive_array_mut}; +use super::{Tensor, TensorValueTypeMarker}; use crate::{ AsPointer, error::{Error, ErrorCode, Result}, + memory::MemoryInfo, ortsys, tensor::{PrimitiveTensorElementType, TensorElementType}, + util::element_count, value::{Value, ValueType} }; @@ -52,24 +52,10 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_tensor(&self) -> Result> { - match self.dtype() { - ValueType::Tensor { ty, dimensions, .. } => { - let mem = self.memory_info(); - if !mem.is_cpu_accessible() { - return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); - } - - if *ty == T::into_tensor_element_type() { - Ok(extract_primitive_array(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr())?) - } else { - Err(Error::new_with_code( - ErrorCode::InvalidArgument, - format!("Cannot extract Tensor<{}> from Tensor<{}>", T::into_tensor_element_type(), ty) - )) - } - } - t => Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract a Tensor<{}> from {t}", T::into_tensor_element_type()))) - } + extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| { + let shape = IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()); + Ok(unsafe { ndarray::ArrayView::from_shape_ptr(shape, data_ptr(ptr)?.cast::()) }) + }) } /// Attempt to extract the scalar from a tensor of type `T`. @@ -96,36 +82,16 @@ impl Value { /// /// [`DynValue`]: crate::value::DynValue pub fn try_extract_scalar(&self) -> Result { - match self.dtype() { - ValueType::Tensor { ty, dimensions, .. 
- let mem = self.memory_info();
- if !mem.is_cpu_accessible() {
- return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str())));
- }
-
- if *ty == T::into_tensor_element_type() {
- if !dimensions.is_empty() {
- return Err(Error::new_with_code(
- ErrorCode::InvalidArgument,
- format!("Cannot extract scalar {} from a tensor of dimensionality {}", T::into_tensor_element_type(), dimensions.len())
- ));
- }
-
- let mut output_array_ptr: *mut T = ptr::null_mut();
- let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
- let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
- ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
-
- Ok(unsafe { *output_array_ptr })
- } else {
- Err(Error::new_with_code(
- ErrorCode::InvalidArgument,
- format!("Cannot extract scalar {} from Tensor<{}>", T::into_tensor_element_type(), ty)
- ))
- }
+ extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| {
+ if !dimensions.is_empty() {
+ return Err(Error::new_with_code(
+ ErrorCode::InvalidArgument,
+ format!("Cannot extract scalar {} from a tensor of dimensionality {}", T::into_tensor_element_type(), dimensions.len())
+ ));
+ }
- t => Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<{}> from {t}", T::into_tensor_element_type())))
- }
+
+ Ok(unsafe { *data_ptr(ptr)?.cast::<T>() })
+ })
 }

 /// Attempt to extract the underlying data of type `T` into a mutable [`ndarray::ArrayViewMut`].
@@ -158,24 +124,10 @@ impl Value {
 #[cfg(feature = "ndarray")]
 #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
 pub fn try_extract_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
- match self.dtype() {
- ValueType::Tensor { ty, dimensions, .. } => {
- let mem = self.memory_info();
- if !mem.is_cpu_accessible() {
- return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str())));
- }
-
- if *ty == T::into_tensor_element_type() {
- Ok(extract_primitive_array_mut(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<usize>>()), self.ptr_mut())?)
- } else {
- Err(Error::new_with_code(
- ErrorCode::InvalidArgument,
- format!("Cannot extract Tensor<{}> from Tensor<{}>", T::into_tensor_element_type(), ty)
- ))
- }
- }
- t => Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<{}> from {t}", T::into_tensor_element_type())))
- }
+ extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| {
+ let shape = IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<usize>>());
+ Ok(unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape, data_ptr(ptr)?.cast::<T>()) })
+ })
 }

 /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an
@@ -207,30 +159,8 @@ impl Value {
 ///
 /// [`DynValue`]: crate::value::DynValue
 pub fn try_extract_raw_tensor<T: PrimitiveTensorElementType>(&self) -> Result<(&[i64], &[T])> {
- match self.dtype() {
- ValueType::Tensor { ty, dimensions, .. } => {
- let mem = self.memory_info();
- if !mem.is_cpu_accessible() {
- return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str())));
- }
-
- if *ty == T::into_tensor_element_type() {
- let mut output_array_ptr: *mut T = ptr::null_mut();
- let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
- let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
- ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
-
- let len = calculate_tensor_size(dimensions);
- Ok((dimensions, unsafe { slice::from_raw_parts(output_array_ptr, len) }))
- } else {
- Err(Error::new_with_code(
- ErrorCode::InvalidArgument,
- format!("Cannot extract Tensor<{}> from Tensor<{}>", T::into_tensor_element_type(), ty)
- ))
- }
- }
- t => Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<{}> from {t}", T::into_tensor_element_type())))
- }
+ extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type())
+ .and_then(|(ptr, dimensions)| Ok((dimensions.as_slice(), unsafe { slice::from_raw_parts(data_ptr(ptr)?.cast::<T>(), element_count(dimensions)) })))
 }

 /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a
@@ -259,31 +189,9 @@ impl Value {
 ///
 /// [`DynValue`]: crate::value::DynValue
 pub fn try_extract_raw_tensor_mut<T: PrimitiveTensorElementType>(&mut self) -> Result<(&[i64], &mut [T])> {
- let dtype = self.dtype();
- match dtype {
- ValueType::Tensor { ty, dimensions, .. } => {
- let mem = self.memory_info();
- if !mem.is_cpu_accessible() {
- return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str())));
- }
-
- if *ty == T::into_tensor_element_type() {
- let mut output_array_ptr: *mut T = ptr::null_mut();
- let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
- let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
- ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
-
- let len = calculate_tensor_size(dimensions);
- Ok((dimensions, unsafe { slice::from_raw_parts_mut(output_array_ptr, len) }))
- } else {
- Err(Error::new_with_code(
- ErrorCode::InvalidArgument,
- format!("Cannot extract Tensor<{}> from Tensor<{}>", T::into_tensor_element_type(), ty)
- ))
- }
- }
- t => Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<{}> from {t:?}", T::into_tensor_element_type())))
- }
+ extract_tensor(self.ptr_mut(), self.dtype(), self.memory_info(), T::into_tensor_element_type()).and_then(|(ptr, dimensions)| {
+ Ok((dimensions.as_slice(), unsafe { slice::from_raw_parts_mut(data_ptr(ptr)?.cast::<T>(), element_count(dimensions)) }))
+ })
 }

 /// Attempt to extract the underlying data into a Rust `ndarray`.
@@ -302,54 +210,11 @@ impl Value {
 #[cfg(feature = "ndarray")]
 #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
 pub fn try_extract_string_tensor(&self) -> Result<ndarray::ArrayD<String>> {
- match self.dtype() {
- ValueType::Tensor { ty, dimensions, .. } => {
- let mem = self.memory_info();
- if !mem.is_cpu_accessible() {
- return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str())));
- }
-
- if *ty == TensorElementType::String {
- let len = calculate_tensor_size(dimensions);
-
- // Total length of string data, not including \0 suffix
- let mut total_length = 0;
- ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length)?];
-
- // In the JNI impl of this, tensor_element_len was included in addition to total_length,
- // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
- // don't seem to be written to in practice either.
- // If the string data actually did go farther, it would panic below when using the offset
- // data to get slices for each string.
- let mut string_contents = vec![0u8; total_length];
- // one extra slot so that the total length can go in the last one, making all per-string
- // length calculations easy
- let mut offsets = vec![0; len + 1];
-
- ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len)?];
-
- // final offset = overall length so that per-string length calculations work for the last string
- debug_assert_eq!(0, offsets[len]);
- offsets[len] = total_length;
-
- let strings = offsets
- // offsets has 1 extra offset past the end so that all windows work
- .windows(2)
- .map(|w| {
- let slice = &string_contents[w[0]..w[1]];
- String::from_utf8(slice.into())
- })
- .collect::<Result<Vec<String>, FromUtf8Error>>()
- .map_err(Error::wrap)?;
-
- Ok(ndarray::Array::from_shape_vec(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<usize>>()), strings)
- .expect("Shape extracted from tensor didn't match tensor contents"))
- } else {
- Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<String> from Tensor<{ty}>")))
- }
- }
- t => Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<String> from {t}")))
- }
+ extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, dimensions)| {
+ let strings = extract_strings(ptr, dimensions)?;
+ Ok(ndarray::Array::from_shape_vec(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::<Vec<usize>>()), strings)
+ .expect("Shape extracted from tensor didn't match tensor contents"))
+ })
 }

 /// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's dimensions and
@@ -368,53 +233,10 @@ impl Value {
 /// # }
 /// ```
 pub fn try_extract_raw_string_tensor(&self) -> Result<(&[i64], Vec<String>)> {
- match self.dtype() {
- ValueType::Tensor { ty, dimensions, .. } => {
- let mem = self.memory_info();
- if !mem.is_cpu_accessible() {
- return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str())));
- }
-
- if *ty == TensorElementType::String {
- let len = calculate_tensor_size(dimensions);
-
- // Total length of string data, not including \0 suffix
- let mut total_length = 0;
- ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length)?];
-
- // In the JNI impl of this, tensor_element_len was included in addition to total_length,
- // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
- // don't seem to be written to in practice either.
- // If the string data actually did go farther, it would panic below when using the offset
- // data to get slices for each string.
- let mut string_contents = vec![0u8; total_length];
- // one extra slot so that the total length can go in the last one, making all per-string
- // length calculations easy
- let mut offsets = vec![0; len + 1];
-
- ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len)?];
-
- // final offset = overall length so that per-string length calculations work for the last string
- debug_assert_eq!(0, offsets[len]);
- offsets[len] = total_length;
-
- let strings = offsets
- // offsets has 1 extra offset past the end so that all windows work
- .windows(2)
- .map(|w| {
- let slice = &string_contents[w[0]..w[1]];
- String::from_utf8(slice.into())
- })
- .collect::<Result<Vec<String>, FromUtf8Error>>()
- .map_err(Error::wrap)?;
-
- Ok((dimensions, strings))
- } else {
- Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<String> from Tensor<{ty}>")))
- }
- }
- t => Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<String> from {t}")))
- }
+ extract_tensor(self.ptr().cast_mut(), self.dtype(), self.memory_info(), TensorElementType::String).and_then(|(ptr, dimensions)| {
+ let strings = extract_strings(ptr, dimensions)?;
+ Ok((dimensions.as_slice(), strings))
+ })
 }

 /// Returns the shape of the tensor.
@@ -447,6 +269,72 @@ impl Value {
 }
 }

+fn extract_tensor<'t>(
+ ptr: *mut ort_sys::OrtValue,
+ dtype: &'t ValueType,
+ memory_info: &MemoryInfo,
+ expected_ty: TensorElementType
+) -> Result<(*mut ort_sys::OrtValue, &'t Vec<i64>)> {
+ match dtype {
+ ValueType::Tensor { ty, dimensions, .. } => {
+ if !memory_info.is_cpu_accessible() {
+ return Err(Error::new(format!(
+ "Cannot extract from value on device `{}`, which is not CPU accessible",
+ memory_info.allocation_device().as_str()
+ )));
+ }
+
+ if *ty == expected_ty {
+ Ok((ptr, dimensions))
+ } else {
+ Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Tensor<{}> from Tensor<{}>", expected_ty, ty)))
+ }
+ }
+ t => Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract a Tensor<{}> from {t}", expected_ty)))
+ }
+}
+
+unsafe fn data_ptr(ptr: *mut ort_sys::OrtValue) -> Result<*mut c_void> {
+ let mut output_array_ptr: *mut c_void = ptr::null_mut();
+ ortsys![unsafe GetTensorMutableData(ptr, &mut output_array_ptr)?; nonNull(output_array_ptr)];
+ Ok(output_array_ptr)
+}
+
+fn extract_strings(ptr: *mut ort_sys::OrtValue, dimensions: &[i64]) -> Result<Vec<String>> {
+ let len = element_count(dimensions);
+
+ // Total length of string data, not including \0 suffix
+ let mut total_length = 0;
+ ortsys![unsafe GetStringTensorDataLength(ptr, &mut total_length)?];
+
+ // In the JNI impl of this, tensor_element_len was included in addition to total_length,
+ // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
+ // don't seem to be written to in practice either.
+ // If the string data actually did go farther, it would panic below when using the offset
+ // data to get slices for each string.
+ let mut string_contents = vec![0u8; total_length];
+ // one extra slot so that the total length can go in the last one, making all per-string
+ // length calculations easy
+ let mut offsets = vec![0; len + 1];
+
+ ortsys![unsafe GetStringTensorContent(ptr, string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len)?];
+
+ // final offset = overall length so that per-string length calculations work for the last string
+ debug_assert_eq!(0, offsets[len]);
+ offsets[len] = total_length;
+
+ let strings = offsets
+ // offsets has 1 extra offset past the end so that all windows work
+ .windows(2)
+ .map(|w| {
+ let slice = &string_contents[w[0]..w[1]];
+ String::from_utf8(slice.into())
+ })
+ .collect::<Result<Vec<String>, FromUtf8Error>>()
+ .map_err(Error::wrap)?;
+ Ok(strings)
+}
+
 impl<T: PrimitiveTensorElementType> Tensor<T> {
 /// Extracts the underlying data into a read-only [`ndarray::ArrayView`].
 ///
diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs
index 12f4a5cb..8d81bfe7 100644
--- a/src/value/impl_tensor/mod.rs
+++ b/src/value/impl_tensor/mod.rs
@@ -1,16 +1,10 @@
 mod create;
 mod extract;

-use alloc::{
- format,
- string::{String, ToString},
- sync::Arc,
- vec
-};
+use alloc::{sync::Arc, vec};
 use core::{
- fmt::Debug,
+ fmt::{self, Debug},
 marker::PhantomData,
- mem,
 ops::{Index, IndexMut},
 ptr::{self, NonNull}
 };
@@ -22,7 +16,8 @@ use crate::{
 error::Result,
 memory::{Allocator, MemoryInfo},
 ortsys,
- tensor::{IntoTensorElementType, TensorElementType}
+ tensor::{IntoTensorElementType, TensorElementType},
+ util::element_count
 };

 pub trait TensorValueTypeMarker: ValueTypeMarker {
@@ -32,8 +27,8 @@ pub trait TensorValueTypeMarker: ValueTypeMarker {
 #[derive(Debug)]
 pub struct DynTensorValueType;
 impl ValueTypeMarker for DynTensorValueType {
- fn format() -> String {
- "DynTensor".to_string()
+ fn fmt(f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("DynTensor")
 }

 private_impl!();
@@ -45,8 +40,10 @@ impl TensorValueTypeMarker for DynTensorValueType {
 #[derive(Debug)]
 pub struct TensorValueType<T>(PhantomData<T>);
 impl<T: IntoTensorElementType> ValueTypeMarker for TensorValueType<T> {
- fn format() -> String {
- format!("Tensor<{}>", T::into_tensor_element_type())
+ fn fmt(f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("Tensor<")?;
+ <TensorElementType as fmt::Display>::fmt(&T::into_tensor_element_type(), f)?;
+ f.write_str(">")
 }

 private_impl!();
@@ -121,7 +118,7 @@ impl DynTensor {
 let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
 ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];

- unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(calculate_tensor_size(&shape))) };
+ unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(element_count(&shape))) };
 }

 Ok(Value {
@@ -234,7 +231,7 @@ impl Tensor {
 /// ```
 #[inline]
 pub fn upcast(self) -> DynTensor {
- unsafe { mem::transmute(self) }
+ unsafe { self.transmute_type() }
 }

 /// Creates a type-erased [`DynTensorRef`] from a strongly-typed [`Tensor`].
@@ -331,17 +328,6 @@ impl IndexMut<[i64; N]
 }
 }

-pub(crate) fn calculate_tensor_size(shape: &[i64]) -> usize {
- let mut size = 1usize;
- for dim in shape {
- if *dim < 0 {
- return 0;
- }
- size *= *dim as usize;
- }
- size
-}
-
 #[cfg(test)]
 mod tests {
 use alloc::sync::Arc;
diff --git a/src/value/mod.rs b/src/value/mod.rs
index 0cfe6995..25428a3e 100644
--- a/src/value/mod.rs
+++ b/src/value/mod.rs
@@ -16,15 +16,10 @@
 //!
 //! ONNX Runtime also supports [`Sequence`]s and [`Map`]s, though they are less commonly used.
-use alloc::{
- boxed::Box,
- format,
- string::{String, ToString},
- sync::Arc
-};
+use alloc::{boxed::Box, format, sync::Arc};
 use core::{
 any::Any,
- fmt::Debug,
+ fmt::{self, Debug},
 marker::PhantomData,
 mem::transmute,
 ops::{Deref, DerefMut},
@@ -101,6 +96,10 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> {
 }
 }

+ pub(crate) fn inner(&self) -> &Arc<ValueInner> {
+ &self.inner.inner
+ }
+
 /// Attempts to downcast a temporary dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed
 /// variant, like [`TensorRef`].
 #[inline]
@@ -109,7 +108,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> {
 if OtherType::can_downcast(dt) {
 Ok(unsafe { transmute::<ValueRef<'v, Type>, ValueRef<'v, OtherType>>(self) })
 } else {
- Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", OtherType::format())))
+ Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", format_value_type::<OtherType>())))
 }
 }

@@ -154,6 +153,10 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> {
 }
 }

+ pub(crate) fn inner(&self) -> &Arc<ValueInner> {
+ &self.inner.inner
+ }
+
 /// Attempts to downcast a temporary mutable dynamic value (like [`DynValue`] or [`DynTensor`]) to a more
 /// strongly typed variant, like [`TensorRefMut`].
 #[inline]
@@ -162,7 +165,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> {
 if OtherType::can_downcast(dt) {
 Ok(unsafe { transmute::<ValueRefMut<'v, Type>, ValueRefMut<'v, OtherType>>(self) })
 } else {
- Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", OtherType::format())))
+ Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", format_value_type::<OtherType>())))
 }
 }

@@ -251,11 +254,24 @@ pub type DynValue = Value;
 /// inherits this trait), i.e. [`Tensor`]s, [`DynTensor`]s, and [`DynValue`]s.
 pub trait ValueTypeMarker {
 #[doc(hidden)]
- fn format() -> String;
+ fn fmt(f: &mut fmt::Formatter) -> fmt::Result;

 private_trait!();
 }

+pub(crate) struct ValueTypeFormatter<T: ?Sized>(PhantomData<T>);
+
+#[inline]
+pub(crate) fn format_value_type<T: ValueTypeMarker + ?Sized>() -> ValueTypeFormatter<T> {
+ ValueTypeFormatter(PhantomData)
+}
+
+impl<T: ValueTypeMarker + ?Sized> fmt::Display for ValueTypeFormatter<T> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ <T as ValueTypeMarker>::fmt(f)
+ }
+}
+
 /// Represents a type that a [`DynValue`] can be downcast to.
 pub trait DowncastableTarget: ValueTypeMarker {
 fn can_downcast(dtype: &ValueType) -> bool;
@@ -276,8 +292,8 @@ impl DowncastableTarget for DynValueTypeMarker {
 #[derive(Debug)]
 pub struct DynValueTypeMarker;
 impl ValueTypeMarker for DynValueTypeMarker {
- fn format() -> String {
- "DynValue".to_string()
+ fn fmt(f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("DynValue")
 }

 private_impl!();
@@ -296,6 +312,10 @@ unsafe impl Send for Value {}
 unsafe impl<Type: ValueTypeMarker + ?Sized> Sync for Value<Type> {}

 impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
+ pub(crate) fn inner(&self) -> &Arc<ValueInner> {
+ &self.inner
+ }
+
 /// Returns the data type of this [`Value`].
 pub fn dtype(&self) -> &ValueType {
 &self.inner.dtype
 }
@@ -358,7 +378,12 @@ impl Value {
 /// Converts this value into a type-erased [`DynValue`].
 pub fn into_dyn(self) -> DynValue {
- unsafe { transmute(self) }
+ unsafe { self.transmute_type() }
+ }
+
+ #[inline(always)]
+ pub(crate) unsafe fn transmute_type<OtherType: ValueTypeMarker + ?Sized>(self) -> Value<OtherType> {
+ unsafe { transmute::<Value<Type>, Value<OtherType>>(self) }
 }

 pub(crate) fn clone_of(value: &Self) -> Self {
@@ -395,7 +420,7 @@ impl Value {
 if OtherType::can_downcast(dt) {
 Ok(unsafe { transmute::<Value<Type>, Value<OtherType>>(self) })
 } else {
- Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast {dt} to {}", OtherType::format())))
+ Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast {dt} to {}", format_value_type::<OtherType>())))
 }
 }

@@ -407,7 +432,7 @@ impl Value {
 if OtherType::can_downcast(dt) {
 Ok(ValueRef::new(unsafe { transmute::<Value<Type>, Value<OtherType>>(Value::clone_of(self)) }))
 } else {
- Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", OtherType::format())))
+ Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", format_value_type::<OtherType>())))
 }
 }

@@ -419,7 +444,7 @@ impl Value {
 if OtherType::can_downcast(dt) {
 Ok(ValueRefMut::new(unsafe { transmute::<Value<Type>, Value<OtherType>>(Value::clone_of(self)) }))
 } else {
- Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", OtherType::format())))
+ Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", format_value_type::<OtherType>())))
 }
 }
 }
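
The hunks above repeatedly swap `calculate_tensor_size` for `crate::util::element_count`. Assuming the helper keeps the removed function's body unchanged (the diff deletes it from impl_tensor/mod.rs without showing its new home in util.rs), the semantics are: the element count is the product of all dimensions, and any negative (symbolic) dimension collapses the count to 0. A minimal standalone sketch of that contract:

    fn element_count(shape: &[i64]) -> usize {
        let mut size = 1usize;
        for dim in shape {
            if *dim < 0 {
                // A negative ("symbolic") dimension has no concrete extent,
                // so report zero elements rather than guessing.
                return 0;
            }
            size *= *dim as usize;
        }
        size
    }

    fn main() {
        assert_eq!(element_count(&[2, 3, 4]), 24);
        assert_eq!(element_count(&[2, -1, 4]), 0);
        // An empty shape is a scalar: exactly one element, which is why
        // try_extract_scalar can read a single value after checking
        // dimensions.is_empty().
        assert_eq!(element_count(&[]), 1);
    }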
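
For a caller's-eye view of the consolidated extraction path, a short sketch follows. It assumes the crate's public `Tensor::from_array` and `extract_raw_tensor` APIs, which are not shown in this diff, so treat the exact signatures as illustrative rather than confirmed by these hunks:

    use ort::value::Tensor;

    fn main() -> ort::Result<()> {
        // A 2x3 tensor backed by a Vec<f32>. Per the create.rs hunks, the Vec
        // becomes the Option-wrapped `guard` in TensorArrayDataParts and its
        // data pointer is wrapped in NonNull at construction time.
        let tensor = Tensor::from_array(([2usize, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]))?;

        // Extraction funnels through the extract_tensor + data_ptr helpers and
        // yields the shape alongside a flat slice of element_count(&[2, 3]) = 6 items.
        let (shape, data) = tensor.extract_raw_tensor();
        assert_eq!(shape, &[2, 3]);
        assert_eq!(data.len(), 6);
        Ok(())
    }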