From 1dbad54248cae53e7d4ded8c769c798aaff41c58 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Fri, 15 Nov 2024 18:08:42 -0600 Subject: [PATCH] refactor!: precompute value dtype/memory info Breaking because `extract_tensor_*` now returns `&[i64]` for dimensions, and `dtype()` and `memory_info()` also return references. Each tensor extract call not only had multiple FFI calls to determine the `ValueType`, but also had to determine `MemoryInfo` to ensure the data was CPU-accessible. Since neither the data type or memory location can *change* for a given value, it doesn't make sense to compute this on each extract call; it's better to compute it once, when we create the `Value` (and we often already have the types created by this time, so little FFI is actually required). This should make `extract_tensor_raw` zero-alloc, most benefitting usages of `IoBinding`/`OutputSelector`. This does mean usages of `Value` without ever extracting said value (like HF Transformers hidden state outputs which go ignored) incur slightly more overhead, but the tradeoff of having less overhead at extraction time seems worth it. --- examples/custom-ops/examples/custom-ops.rs | 8 +- examples/model-info/examples/model-info.rs | 44 +-- src/io_binding.rs | 10 +- src/memory.rs | 13 + src/operator/kernel.rs | 8 +- src/operator/tests.rs | 8 +- src/session/output.rs | 9 +- src/session/run_options.rs | 10 +- src/value/impl_map.rs | 21 +- src/value/impl_sequence.rs | 12 +- src/value/impl_tensor/create.rs | 48 ++- src/value/impl_tensor/extract.rs | 66 ++-- src/value/impl_tensor/mod.rs | 14 +- src/value/mod.rs | 382 ++++----------------- src/value/type.rs | 254 ++++++++++++++ 15 files changed, 447 insertions(+), 460 deletions(-) create mode 100644 src/value/type.rs diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index 1afdc73d..0b8c986a 100644 --- a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -39,9 +39,9 @@ impl Kernel for CustomOpOneKernel { let (x_shape, x) = x.try_extract_raw_tensor::()?; let (y_shape, y) = y.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape)?.unwrap(); + let mut z = ctx.output(0, x_shape.to_vec())?.unwrap(); let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..y_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize { + for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize { if i % 2 == 0 { z_ref[i] = x[i]; } else { @@ -79,9 +79,9 @@ impl Kernel for CustomOpTwoKernel { fn compute(&mut self, ctx: &KernelContext) -> ort::Result<()> { let x = ctx.input(0)?.unwrap(); let (x_shape, x) = x.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape.clone())?.unwrap(); + let mut z = ctx.output(0, x_shape.to_vec())?.unwrap(); let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..x_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize { + for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize { z_ref[i] = (x[i] * i as f32) as i32; } Ok(()) diff --git a/examples/model-info/examples/model-info.rs b/examples/model-info/examples/model-info.rs index 08c068db..8bceac49 100644 --- a/examples/model-info/examples/model-info.rs +++ b/examples/model-info/examples/model-info.rs @@ -1,44 +1,6 @@ use std::{env, process}; -use ort::{session::Session, tensor::TensorElementType, value::ValueType}; - -fn display_element_type(t: TensorElementType) -> &'static str { - match t { - TensorElementType::Bfloat16 => "bf16", - TensorElementType::Bool => "bool", - TensorElementType::Float16 => "f16", - TensorElementType::Float32 => "f32", - TensorElementType::Float64 => "f64", - TensorElementType::Int16 => "i16", - TensorElementType::Int32 => "i32", - TensorElementType::Int64 => "i64", - TensorElementType::Int8 => "i8", - TensorElementType::String => "str", - TensorElementType::Uint16 => "u16", - TensorElementType::Uint32 => "u32", - TensorElementType::Uint64 => "u64", - TensorElementType::Uint8 => "u8" - } -} - -fn display_value_type(value: &ValueType) -> String { - match value { - ValueType::Tensor { ty, dimensions } => { - format!( - "Tensor<{}>({})", - display_element_type(*ty), - dimensions - .iter() - .map(|c| if *c == -1 { "dyn".to_string() } else { c.to_string() }) - .collect::>() - .join(", ") - ) - } - ValueType::Map { key, value } => format!("Map<{}, {}>", display_element_type(*key), display_element_type(*value)), - ValueType::Sequence(inner) => format!("Sequence<{}>", display_value_type(inner)), - ValueType::Optional(inner) => format!("Option<{}>", display_value_type(inner)) - } -} +use ort::session::Session; fn main() -> ort::Result<()> { let Some(path) = env::args().nth(1) else { @@ -61,11 +23,11 @@ fn main() -> ort::Result<()> { println!("Inputs:"); for (i, input) in session.inputs.iter().enumerate() { - println!(" {i} {}: {}", input.name, display_value_type(&input.input_type)); + println!(" {i} {}: {}", input.name, input.input_type); } println!("Outputs:"); for (i, output) in session.outputs.iter().enumerate() { - println!(" {i} {}: {}", output.name, display_value_type(&output.output_type)); + println!(" {i} {}: {}", output.name, output.output_type); } Ok(()) diff --git a/src/io_binding.rs b/src/io_binding.rs index 96995cec..c4a627f1 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -4,7 +4,6 @@ use std::{ collections::HashMap, ffi::CString, fmt::Debug, - marker::PhantomData, ptr::{self, NonNull}, sync::Arc }; @@ -214,7 +213,7 @@ impl IoBinding { let run_options_ptr = if let Some(run_options) = run_options { run_options.ptr() } else { std::ptr::null() }; ortsys![unsafe RunWithBinding(self.session.ptr().cast_mut(), run_options_ptr, self.ptr())?]; - let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc> = self.output_values.values().map(|c| (c.ptr().cast_mut(), &c.inner)).collect(); + let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Value> = self.output_values.values().map(|c| (c.ptr().cast_mut(), c)).collect(); let mut count = self.output_names.len(); if count > 0 { let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut(); @@ -223,11 +222,8 @@ impl IoBinding { let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count).to_vec() } .into_iter() .map(|v| unsafe { - if let Some(inner) = owned_ptrs.get(&v) { - DynValue { - inner: Arc::clone(*inner), - _markers: PhantomData - } + if let Some(value) = owned_ptrs.get(&v) { + DynValue::clone_of(value) } else { DynValue::from_ptr( NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), diff --git a/src/memory.rs b/src/memory.rs index 809a3e27..984abcb3 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -400,6 +400,19 @@ impl MemoryInfo { }) } + pub(crate) fn from_value(value_ptr: *mut ort_sys::OrtValue) -> Option { + let mut is_tensor = 0; + ortsys![unsafe IsTensor(value_ptr, &mut is_tensor)]; // infallible + if is_tensor != 0 { + let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut(); + // infallible, and `memory_info_ptr` will never be null + ortsys![unsafe GetTensorMemoryInfo(value_ptr, &mut memory_info_ptr)]; + Some(Self::from_raw(unsafe { NonNull::new_unchecked(memory_info_ptr.cast_mut()) }, false)) + } else { + None + } + } + pub(crate) fn from_raw(ptr: NonNull, should_release: bool) -> Self { MemoryInfo { ptr, should_release } } diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 8b5a665e..c2cbdd70 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ AsPointer, error::{Error, Result, status_to_result}, - memory::{Allocator, MemoryInfo}, + memory::{Allocator, MemoryInfo, MemoryType}, ortsys, session::{Input, Output}, value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType} @@ -89,6 +89,12 @@ impl KernelAttributes { ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), name.as_mut_ptr().cast::(), &mut name_len)?]; CString::from_vec_with_nul(name).map_err(Error::wrap)?.into_string().map_err(Error::wrap) } + + pub fn allocator(&self, mem_type: MemoryType) -> Result { + let mut ptr: *mut ort_sys::OrtAllocator = ptr::null_mut(); + ortsys![unsafe KernelInfoGetAllocator(self.0.as_ptr(), mem_type.into(), &mut ptr)?]; + Ok(unsafe { Allocator::from_raw_unchecked(ptr) }) + } } impl AsPointer for KernelAttributes { diff --git a/src/operator/tests.rs b/src/operator/tests.rs index 23a29258..4c10cd73 100644 --- a/src/operator/tests.rs +++ b/src/operator/tests.rs @@ -41,9 +41,9 @@ impl Kernel for CustomOpOneKernel { let (x_shape, x) = x.try_extract_raw_tensor::()?; let (y_shape, y) = y.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape)?.ok_or_else(|| crate::Error::new("missing input"))?; + let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..y_shape.into_iter().reduce(|acc, e| acc * e).unwrap_or(0) as usize { + for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { if i % 2 == 0 { z_ref[i] = x[i]; } else { @@ -81,9 +81,9 @@ impl Kernel for CustomOpTwoKernel { fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()> { let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?; let (x_shape, x) = x.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape.clone())?.ok_or_else(|| crate::Error::new("missing input"))?; + let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..x_shape.into_iter().reduce(|acc, e| acc * e).unwrap_or(0) as usize { + for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { z_ref[i] = (x[i] * i as f32) as i32; } Ok(()) diff --git a/src/session/output.rs b/src/session/output.rs index 41714585..6dc9784b 100644 --- a/src/session/output.rs +++ b/src/session/output.rs @@ -1,11 +1,9 @@ use std::{ ffi::c_void, iter::FusedIterator, - marker::PhantomData, mem::ManuallyDrop, ops::{Index, IndexMut}, - ptr, - sync::Arc + ptr }; use crate::{ @@ -113,10 +111,7 @@ impl<'r, 's> SessionOutputs<'r, 's> { if &key == k { *k = ""; self.effective_len -= 1; - return Some(DynValue { - inner: Arc::clone(&self.values[i].inner), - _markers: PhantomData - }); + return Some(DynValue::clone_of(&self.values[i])); } } None diff --git a/src/session/run_options.rs b/src/session/run_options.rs index 93569aab..0c176b1b 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -116,15 +116,7 @@ impl OutputSelector { .map(|o| &o.name) .filter(|n| !self.default_blocklist.contains(n)) .chain(self.allowlist.iter()) - .map(|n| { - ( - n.as_str(), - self.preallocated_outputs.get(n).map(|v| DynValue { - inner: Arc::clone(&v.inner), - _markers: PhantomData - }) - ) - }) + .map(|n| (n.as_str(), self.preallocated_outputs.get(n).map(DynValue::clone_of))) .unzip() } } diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index 82f3d506..5507430a 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -81,11 +81,11 @@ impl Value { match self.dtype() { ValueType::Map { key, value } => { let k_type = K::into_tensor_element_type(); - if k_type != key { + if k_type != *key { return Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Map<{:?}, _> (value has K type {:?})", k_type, key))); } let v_type = V::into_tensor_element_type(); - if v_type != value { + if v_type != *value { return Err(Error::new_with_code( ErrorCode::InvalidArgument, format!("Cannot extract Map<{}, {}> from Map<{}, {}>", K::into_tensor_element_type(), V::into_tensor_element_type(), k_type, v_type) @@ -100,7 +100,7 @@ impl Value { if K::into_tensor_element_type() != TensorElementType::String { let dtype = key_value.dtype(); let (key_tensor_shape, key_tensor) = match dtype { - ValueType::Tensor { ty, dimensions } => { + ValueType::Tensor { ty, dimensions, .. } => { let mem = key_value.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!( @@ -109,13 +109,13 @@ impl Value { ))); } - if ty == K::into_tensor_element_type() { + if *ty == K::into_tensor_element_type() { let mut output_array_ptr: *mut K = ptr::null_mut(); let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr; let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - let len = calculate_tensor_size(&dimensions); + let len = calculate_tensor_size(dimensions); (dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) }) } else { return Err(Error::new_with_code( @@ -251,10 +251,15 @@ impl Value { let value = unsafe { Value::from_ptr(NonNull::new_unchecked(value_ptr), None) }; let value_type = value.dtype(); - if !OtherType::can_downcast(&value.dtype()) { + if !OtherType::can_downcast(value.dtype()) { return Err(Error::new_with_code( ErrorCode::InvalidArgument, format!("Cannot extract Sequence<{}> from {value_type:?}", OtherType::format()) @@ -134,10 +134,14 @@ impl Value { @@ -76,10 +76,16 @@ impl Tensor { ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len())?]; Ok(Value { - inner: Arc::new(ValueInner::RustOwned { + inner: Arc::new(ValueInner { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - _array: Box::new(()), - _memory_info: None + dtype: ValueType::Tensor { + ty: TensorElementType::String, + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, + memory_info: MemoryInfo::from_value(value_ptr), + drop: true, + _backing: None }), _markers: PhantomData }) @@ -124,10 +130,16 @@ impl Tensor { ]; Ok(Value { - inner: Arc::new(ValueInner::RustOwned { + inner: Arc::new(ValueInner { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - _array: Box::new(()), - _memory_info: None + dtype: ValueType::Tensor { + ty: T::into_tensor_element_type(), + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, + drop: true, + memory_info: MemoryInfo::from_value(value_ptr), + _backing: None }), _markers: PhantomData }) @@ -195,10 +207,16 @@ impl Tensor { ]; Ok(Value { - inner: Arc::new(ValueInner::RustOwned { + inner: Arc::new(ValueInner { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - _array: guard, - _memory_info: Some(memory_info) + dtype: ValueType::Tensor { + ty: T::into_tensor_element_type(), + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, + drop: true, + memory_info: Some(memory_info), + _backing: Some(guard) }), _markers: PhantomData }) @@ -252,10 +270,16 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { ]; Ok(TensorRefMut::new(Value { - inner: Arc::new(ValueInner::CppOwned { + inner: Arc::new(ValueInner { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, + dtype: ValueType::Tensor { + ty: T::into_tensor_element_type(), + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, drop: true, - _session: None + memory_info: Some(info), + _backing: None }), _markers: PhantomData })) diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index 02106a39..3026f2e6 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -48,15 +48,14 @@ impl Value { pub fn try_extract_tensor(&self) -> Result> { use crate::AsPointer; - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { Ok(extract_primitive_array(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr())?) } else { Err(Error::new_with_code( @@ -93,15 +92,14 @@ impl Value { /// /// [`DynValue`]: crate::value::DynValue pub fn try_extract_scalar(&self) -> Result { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { if !dimensions.is_empty() { return Err(Error::new_with_code( ErrorCode::InvalidArgument, @@ -158,15 +156,14 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_tensor_mut(&mut self) -> Result> { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { Ok(extract_primitive_array_mut(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr_mut())?) } else { Err(Error::new_with_code( @@ -207,22 +204,21 @@ impl Value { /// - The provided type `T` does not match the tensor's element type. /// /// [`DynValue`]: crate::value::DynValue - pub fn try_extract_raw_tensor(&self) -> Result<(Vec, &[T])> { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + pub fn try_extract_raw_tensor(&self) -> Result<(&[i64], &[T])> { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { let mut output_array_ptr: *mut T = ptr::null_mut(); let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - let len = calculate_tensor_size(&dimensions); + let len = calculate_tensor_size(dimensions); Ok((dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) })) } else { Err(Error::new_with_code( @@ -260,22 +256,22 @@ impl Value { /// - The provided type `T` does not match the tensor's element type. /// /// [`DynValue`]: crate::value::DynValue - pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { + pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(&[i64], &mut [T])> { let dtype = self.dtype(); match dtype { - ValueType::Tensor { ty, dimensions } => { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { let mut output_array_ptr: *mut T = ptr::null_mut(); let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; + ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - let len = calculate_tensor_size(&dimensions); + let len = calculate_tensor_size(dimensions); Ok((dimensions, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len) })) } else { Err(Error::new_with_code( @@ -304,16 +300,15 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_string_tensor(&self) -> Result> { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == TensorElementType::String { - let len = calculate_tensor_size(&dimensions); + if *ty == TensorElementType::String { + let len = calculate_tensor_size(dimensions); // Total length of string data, not including \0 suffix let mut total_length = 0; @@ -370,17 +365,16 @@ impl Value { /// # Ok(()) /// # } /// ``` - pub fn try_extract_raw_string_tensor(&self) -> Result<(Vec, Vec)> { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + pub fn try_extract_raw_string_tensor(&self) -> Result<(&[i64], Vec)> { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == TensorElementType::String { - let len = calculate_tensor_size(&dimensions); + if *ty == TensorElementType::String { + let len = calculate_tensor_size(dimensions); // Total length of string data, not including \0 suffix let mut total_length = 0; @@ -512,7 +506,7 @@ impl Tensor { /// # Ok(()) /// # } /// ``` - pub fn extract_raw_tensor(&self) -> (Vec, &[T]) { + pub fn extract_raw_tensor(&self) -> (&[i64], &[T]) { self.try_extract_raw_tensor().expect("Failed to extract tensor") } @@ -531,7 +525,7 @@ impl Tensor { /// # Ok(()) /// # } /// ``` - pub fn extract_raw_tensor_mut(&mut self) -> (Vec, &mut [T]) { + pub fn extract_raw_tensor_mut(&mut self) -> (&[i64], &mut [T]) { self.try_extract_raw_tensor_mut().expect("Failed to extract tensor") } } diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index fb2be71f..34375b0d 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -5,7 +5,6 @@ use std::{ fmt::Debug, marker::PhantomData, ops::{Index, IndexMut}, - ptr::NonNull, sync::Arc }; @@ -137,11 +136,8 @@ impl Value { /// # Ok(()) /// # } /// ``` - pub fn memory_info(&self) -> MemoryInfo { - let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut(); - // infallible, and `memory_info_ptr` will never be null - ortsys![unsafe GetTensorMemoryInfo(self.ptr(), &mut memory_info_ptr)]; - MemoryInfo::from_raw(unsafe { NonNull::new_unchecked(memory_info_ptr.cast_mut()) }, false) + pub fn memory_info(&self) -> &MemoryInfo { + unsafe { self.inner.memory_info.as_ref().unwrap_unchecked() } } } @@ -282,11 +278,11 @@ mod tests { fn test_tensor_value() -> crate::Result<()> { let v: Vec = vec![1., 2., 3., 4., 5.]; let value = Tensor::from_array(Array1::from_vec(v.clone()))?; - assert!(value.is_tensor()?); assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Float32)); - assert_eq!(value.dtype(), ValueType::Tensor { + assert_eq!(value.dtype(), &ValueType::Tensor { ty: TensorElementType::Float32, - dimensions: vec![v.len() as i64] + dimensions: vec![v.len() as i64], + dimension_symbols: vec![None] }); let (shape, data) = value.extract_raw_tensor(); diff --git a/src/value/mod.rs b/src/value/mod.rs index 6be11f9a..e10ab591 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -18,8 +18,9 @@ use std::{ any::Any, - fmt::{self, Debug}, + fmt::Debug, marker::PhantomData, + mem::transmute, ops::{Deref, DerefMut}, ptr::NonNull, sync::Arc @@ -28,240 +29,46 @@ use std::{ mod impl_map; mod impl_sequence; mod impl_tensor; +mod r#type; pub use self::{ impl_map::{DynMap, DynMapRef, DynMapRefMut, DynMapValueType, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker}, impl_sequence::{ DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker }, - impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker} + impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker}, + r#type::ValueType }; use crate::{ AsPointer, error::{Error, ErrorCode, Result}, memory::MemoryInfo, ortsys, - session::SharedSessionInner, - tensor::TensorElementType + session::SharedSessionInner }; -/// The type of a [`Value`], or a session input/output. -/// -/// ``` -/// # use std::sync::Arc; -/// # use ort::{session::Session, value::{ValueType, Tensor}, tensor::TensorElementType}; -/// # fn main() -> ort::Result<()> { -/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; -/// // `ValueType`s can be obtained from session inputs/outputs: -/// let input = &session.inputs[0]; -/// assert_eq!(input.input_type, ValueType::Tensor { -/// ty: TensorElementType::Float32, -/// // Our model has 3 dynamic dimensions, represented by -1 -/// dimensions: vec![-1, -1, -1, 3] -/// }); -/// -/// // Or by `Value`s created in Rust or output by a session. -/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; -/// assert_eq!(value.dtype(), ValueType::Tensor { -/// ty: TensorElementType::Int64, -/// dimensions: vec![5] -/// }); -/// # Ok(()) -/// # } -/// ``` -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum ValueType { - /// Value is a tensor/multi-dimensional array. - Tensor { - /// Element type of the tensor. - ty: TensorElementType, - /// Dimensions of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an - /// [`Input`]/[`Output`]), the dimension will be `-1`. - /// - /// Actual tensor values, which have a known dimension, will always have positive (>1) dimensions. - /// - /// [`Input`]: crate::session::Input - /// [`Output`]: crate::session::Output - dimensions: Vec - }, - /// A sequence (vector) of other `Value`s. - /// - /// [Per ONNX spec](https://onnx.ai/onnx/intro/concepts.html#other-types), only sequences of tensors and maps are allowed. - Sequence(Box), - /// A map/dictionary from one element type to another. - Map { - /// The map key type. Allowed types are: - /// - [`TensorElementType::Int8`] - /// - [`TensorElementType::Int16`] - /// - [`TensorElementType::Int32`] - /// - [`TensorElementType::Int64`] - /// - [`TensorElementType::Uint8`] - /// - [`TensorElementType::Uint16`] - /// - [`TensorElementType::Uint32`] - /// - [`TensorElementType::Uint64`] - /// - [`TensorElementType::String`] - key: TensorElementType, - /// The map value type. - value: TensorElementType - }, - /// An optional value, which may or may not contain a [`Value`]. - Optional(Box) -} - -impl ValueType { - pub(crate) fn from_type_info(typeinfo_ptr: *mut ort_sys::OrtTypeInfo) -> Self { - let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty)]; // infallible - let io_type = match ty { - ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { - let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr)]; // infallible - unsafe { extract_data_type_from_tensor_info(info_ptr) } - } - ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { - let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible - unsafe { extract_data_type_from_sequence_info(info_ptr) } - } - ort_sys::ONNXType::ONNX_TYPE_MAP => { - let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible - unsafe { extract_data_type_from_map_info(info_ptr) } - } - ort_sys::ONNXType::ONNX_TYPE_OPTIONAL => { - let mut info_ptr: *const ort_sys::OrtOptionalTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToOptionalTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible - - let mut contained_type: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); - ortsys![unsafe GetOptionalContainedTypeInfo(info_ptr, &mut contained_type)]; // infallible - - ValueType::Optional(Box::new(ValueType::from_type_info(contained_type))) - } - _ => unreachable!() - }; - ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; - io_type - } - /// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map. - /// - /// ``` - /// # use ort::value::Tensor; - /// # fn main() -> ort::Result<()> { - /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; - /// assert_eq!(value.dtype().tensor_dimensions(), Some(&vec![5])); - /// # Ok(()) - /// # } - /// ``` - #[must_use] - pub fn tensor_dimensions(&self) -> Option<&Vec> { - match self { - ValueType::Tensor { dimensions, .. } => Some(dimensions), - _ => None - } - } - - /// Returns the element type of this value type if it is a tensor, or `None` if it is a sequence or map. - /// - /// ``` - /// # use ort::{tensor::TensorElementType, value::Tensor}; - /// # fn main() -> ort::Result<()> { - /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; - /// assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Int64)); - /// # Ok(()) - /// # } - /// ``` - #[must_use] - pub fn tensor_type(&self) -> Option { - match self { - ValueType::Tensor { ty, .. } => Some(*ty), - _ => None - } - } - - /// Returns `true` if this value type is a tensor. - #[inline] - #[must_use] - pub fn is_tensor(&self) -> bool { - matches!(self, ValueType::Tensor { .. }) - } - - /// Returns `true` if this value type is a sequence. - #[inline] - #[must_use] - pub fn is_sequence(&self) -> bool { - matches!(self, ValueType::Sequence { .. }) - } - - /// Returns `true` if this value type is a map. - #[inline] - #[must_use] - pub fn is_map(&self) -> bool { - matches!(self, ValueType::Map { .. }) - } -} - -impl fmt::Display for ValueType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ValueType::Tensor { ty, dimensions } => { - write!( - f, - "Tensor<{ty}>({})", - dimensions - .iter() - .map(|c| if *c == -1 { "dyn".to_string() } else { c.to_string() }) - .collect::>() - .join(", ") - ) - } - ValueType::Map { key, value } => write!(f, "Map<{key}, {value}>"), - ValueType::Sequence(inner) => write!(f, "Sequence<{inner}>"), - ValueType::Optional(inner) => write!(f, "Option<{inner}>") - } - } -} - #[derive(Debug)] -pub(crate) enum ValueInner { - RustOwned { - ptr: NonNull, - _array: Box, - /// Hold onto the `MemoryInfo` that we create in `Value::from_array`. - _memory_info: Option - }, - CppOwned { - ptr: NonNull, - /// Whether to release the value pointer on drop. - drop: bool, - /// Hold [`SharedSessionInner`] to ensure that the value can stay alive after the main session is dropped. - /// - /// This may be `None` if the value is created outside of a session or if the value does not need to hold onto - /// the session reference. In the case of sequence/map values, we forego this because: - /// - a map value can be created independently of a session, and thus we wouldn't have anything to hold on to; - /// - this is only ever used by `ValueRef`s, whos owner value (which *is* holding the session Arc) will outlive - /// it. - _session: Option> - } +pub(crate) struct ValueInner { + pub(crate) ptr: NonNull, + pub(crate) dtype: ValueType, + pub(crate) memory_info: Option, + pub(crate) drop: bool, + pub(crate) _backing: Option> } impl AsPointer for ValueInner { type Sys = ort_sys::OrtValue; fn ptr(&self) -> *const Self::Sys { - match self { - ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr() - } + self.ptr.as_ptr() } } impl Drop for ValueInner { fn drop(&mut self) { let ptr = self.ptr_mut(); - tracing::trace!("dropping {} value at {ptr:p}", match self { - ValueInner::RustOwned { .. } => "rust-owned", - ValueInner::CppOwned { .. } => "cpp-owned" - }); - if !matches!(self, ValueInner::CppOwned { drop: false, .. }) { + tracing::trace!("dropping value at {ptr:p}"); + if self.drop { ortsys![unsafe ReleaseValue(ptr)]; } } @@ -284,7 +91,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> { #[inline] pub fn downcast(self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { + if OtherType::can_downcast(dt) { Ok(unsafe { std::mem::transmute::, ValueRef<'v, OtherType>>(self) }) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", OtherType::format()))) @@ -295,10 +102,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> { pub fn try_upgrade(self) -> Result, Self> { // We cannot upgade a value which we cannot drop, i.e. `ValueRef`s used in operator kernels. Those only last for the // duration of the kernel, allowing an upgrade would allow a UAF. - if match &*self.inner.inner { - ValueInner::CppOwned { drop, .. } => !drop, - _ => false - } { + if !self.inner.inner.drop { return Err(self); } @@ -335,7 +139,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> { #[inline] pub fn downcast(self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { + if OtherType::can_downcast(dt) { Ok(unsafe { std::mem::transmute::, ValueRefMut<'v, OtherType>>(self) }) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", OtherType::format()))) @@ -346,10 +150,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> { pub fn try_upgrade(self) -> Result, Self> { // We cannot upgade a value which we cannot drop, i.e. `ValueRef`s used in operator kernels. Those only last for the // duration of the kernel, allowing an upgrade would allow a UAF. - if match &*self.inner.inner { - ValueInner::CppOwned { drop, .. } => !drop, - _ => false - } { + if !self.inner.inner.drop { return Err(self); } @@ -478,14 +279,8 @@ unsafe impl Sync for Value {} impl Value { /// Returns the data type of this [`Value`]. - pub fn dtype(&self) -> ValueType { - let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTypeInfo(self.ptr(), &mut typeinfo_ptr)]; // infallible - // `typeinfo_ptr` may be null in exceptionally rare cases - if typeinfo_ptr.is_null() { - panic!("unexpected UNKNOWN value type info"); - } - ValueType::from_type_info(typeinfo_ptr) + pub fn dtype(&self) -> &ValueType { + &self.inner.dtype } /// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer. @@ -501,8 +296,16 @@ impl Value { /// - `session` must be `Some` for values returned from a session. #[must_use] pub unsafe fn from_ptr(ptr: NonNull, session: Option>) -> Value { + let mut typeinfo_ptr = std::ptr::null_mut(); + ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)]; Value { - inner: Arc::new(ValueInner::CppOwned { ptr, drop: true, _session: session }), + inner: Arc::new(ValueInner { + ptr, + memory_info: MemoryInfo::from_value(ptr.as_ptr()), + dtype: ValueType::from_type_info(typeinfo_ptr), + drop: true, + _backing: session.map(|v| Box::new(v) as Box) + }), _markers: PhantomData } } @@ -511,58 +314,67 @@ impl Value { /// contexts. #[must_use] pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull, session: Option>) -> Value { + let mut typeinfo_ptr = std::ptr::null_mut(); + ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)]; Value { - inner: Arc::new(ValueInner::CppOwned { ptr, drop: false, _session: session }), + inner: Arc::new(ValueInner { + ptr, + memory_info: MemoryInfo::from_value(ptr.as_ptr()), + dtype: ValueType::from_type_info(typeinfo_ptr), + drop: false, + _backing: session.map(|v| Box::new(v) as Box) + }), _markers: PhantomData } } /// Create a view of this value's data. pub fn view(&self) -> ValueRef<'_, Type> { - ValueRef::new(Value { - inner: Arc::clone(&self.inner), - _markers: PhantomData - }) + ValueRef::new(Value::clone_of(self)) } /// Create a mutable view of this value's data. pub fn view_mut(&mut self) -> ValueRefMut<'_, Type> { - ValueRefMut::new(Value { - inner: Arc::clone(&self.inner), + ValueRefMut::new(Value::clone_of(self)) + } + + /// Converts this value into a type-erased [`DynValue`]. + pub fn into_dyn(self) -> DynValue { + unsafe { std::mem::transmute(self) } + } + + pub(crate) fn clone_of(value: &Self) -> Self { + Self { + inner: Arc::clone(&value.inner), _markers: PhantomData - }) + } } +} +impl Value { /// Returns `true` if this value is a tensor, or `false` if it is another type (sequence, map). /// /// ``` /// # use ort::value::Tensor; /// # fn main() -> ort::Result<()> { - /// // Create a tensor from a raw data vector /// let tensor_value = Tensor::from_array(([3usize], vec![1.0_f32, 2.0, 3.0].into_boxed_slice()))?; - /// assert!(tensor_value.is_tensor()?); + /// let dyn_value = tensor_value.into_dyn(); + /// assert!(dyn_value.is_tensor()); /// # Ok(()) /// # } /// ``` - pub fn is_tensor(&self) -> Result { + pub fn is_tensor(&self) -> bool { let mut result = 0; - ortsys![unsafe IsTensor(self.ptr(), &mut result)?]; - Ok(result == 1) + ortsys![unsafe IsTensor(self.ptr(), &mut result)]; // infallible + result == 1 } - /// Converts this value into a type-erased [`DynValue`]. - pub fn into_dyn(self) -> DynValue { - unsafe { std::mem::transmute(self) } - } -} - -impl Value { /// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed variant, /// like [`Tensor`]. #[inline] pub fn downcast(self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { + if OtherType::can_downcast(dt) { Ok(unsafe { std::mem::transmute::, Value>(self) }) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast {dt} to {}", OtherType::format()))) @@ -574,11 +386,8 @@ impl Value { #[inline] pub fn downcast_ref(&self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { - Ok(ValueRef::new(Value { - inner: Arc::clone(&self.inner), - _markers: PhantomData - })) + if OtherType::can_downcast(dt) { + Ok(ValueRef::new(unsafe { transmute::>(Value::clone_of(self)) })) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", OtherType::format()))) } @@ -589,11 +398,8 @@ impl Value { #[inline] pub fn downcast_mut(&mut self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { - Ok(ValueRefMut::new(Value { - inner: Arc::clone(&self.inner), - _markers: PhantomData - })) + if OtherType::can_downcast(dt) { + Ok(ValueRefMut::new(unsafe { transmute::>(Value::clone_of(self)) })) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", OtherType::format()))) } @@ -608,66 +414,6 @@ impl AsPointer for Value { } } -pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo) -> ValueType { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![GetTensorElementType(info_ptr, &mut type_sys)]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - // This transmute should be safe since its value is read from GetTensorElementType, which we must trust - let mut num_dims = 0; - ortsys![GetDimensionsCount(info_ptr, &mut num_dims)]; - - let mut node_dims: Vec = vec![0; num_dims]; - ortsys![GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims)]; - - ValueType::Tensor { - ty: type_sys.into(), - dimensions: node_dims - } -} - -pub(crate) unsafe fn extract_data_type_from_sequence_info(info_ptr: *const ort_sys::OrtSequenceTypeInfo) -> ValueType { - let mut element_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); - ortsys![GetSequenceElementType(info_ptr, &mut element_type_info)]; // infallible - - let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - ortsys![GetOnnxTypeFromTypeInfo(element_type_info, &mut ty)]; // infallible - - match ty { - ort_sys::ONNXType::ONNX_TYPE_TENSOR => { - let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![CastTypeInfoToTensorInfo(element_type_info, &mut info_ptr)]; // infallible - let ty = extract_data_type_from_tensor_info(info_ptr); - ValueType::Sequence(Box::new(ty)) - } - ort_sys::ONNXType::ONNX_TYPE_MAP => { - let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); - ortsys![CastTypeInfoToMapTypeInfo(element_type_info, &mut info_ptr)]; // infallible - let ty = extract_data_type_from_map_info(info_ptr); - ValueType::Sequence(Box::new(ty)) - } - _ => unreachable!() - } -} - -pub(crate) unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeInfo) -> ValueType { - let mut key_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![GetMapKeyType(info_ptr, &mut key_type_sys)]; // infallible - assert_ne!(key_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - - let mut value_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); - ortsys![GetMapValueType(info_ptr, &mut value_type_info)]; // infallible - let mut value_info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(value_type_info, &mut value_info_ptr)]; // infallible - let mut value_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![GetTensorElementType(value_info_ptr, &mut value_type_sys)]; // infallible - assert_ne!(value_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - - ValueType::Map { - key: key_type_sys.into(), - value: value_type_sys.into() - } -} - #[cfg(test)] mod tests { use super::{DynTensorValueType, Map, Sequence, Tensor, TensorRef, TensorRefMut, TensorValueType}; diff --git a/src/value/type.rs b/src/value/type.rs new file mode 100644 index 00000000..ef79b16c --- /dev/null +++ b/src/value/type.rs @@ -0,0 +1,254 @@ +use std::{ + ffi::{CStr, c_char}, + fmt, ptr +}; + +use crate::{ortsys, tensor::TensorElementType}; + +/// The type of a [`Value`], or a session input/output. +/// +/// ``` +/// # use std::sync::Arc; +/// # use ort::{session::Session, value::{ValueType, Tensor}, tensor::TensorElementType}; +/// # fn main() -> ort::Result<()> { +/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// // `ValueType`s can be obtained from session inputs/outputs: +/// let input = &session.inputs[0]; +/// assert_eq!(input.input_type, ValueType::Tensor { +/// ty: TensorElementType::Float32, +/// // Our model has 3 dynamic dimensions, represented by -1 +/// dimensions: vec![-1, -1, -1, 3] +/// }); +/// +/// // Or by `Value`s created in Rust or output by a session. +/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; +/// assert_eq!(value.dtype(), ValueType::Tensor { +/// ty: TensorElementType::Int64, +/// dimensions: vec![5] +/// }); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum ValueType { + /// Value is a tensor/multi-dimensional array. + Tensor { + /// Element type of the tensor. + ty: TensorElementType, + /// Dimensions of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an + /// [`Input`]/[`Output`]), the dimension will be `-1`. + /// + /// Actual tensor values, which have a known dimension, will always have positive (>1) dimensions. + /// + /// [`Input`]: crate::session::Input + /// [`Output`]: crate::session::Output + dimensions: Vec, + dimension_symbols: Vec> + }, + /// A sequence (vector) of other `Value`s. + /// + /// [Per ONNX spec](https://onnx.ai/onnx/intro/concepts.html#other-types), only sequences of tensors and maps are allowed. + Sequence(Box), + /// A map/dictionary from one element type to another. + Map { + /// The map key type. Allowed types are: + /// - [`TensorElementType::Int8`] + /// - [`TensorElementType::Int16`] + /// - [`TensorElementType::Int32`] + /// - [`TensorElementType::Int64`] + /// - [`TensorElementType::Uint8`] + /// - [`TensorElementType::Uint16`] + /// - [`TensorElementType::Uint32`] + /// - [`TensorElementType::Uint64`] + /// - [`TensorElementType::String`] + key: TensorElementType, + /// The map value type. + value: TensorElementType + }, + /// An optional value, which may or may not contain a [`Value`]. + Optional(Box) +} + +impl ValueType { + pub(crate) fn from_type_info(typeinfo_ptr: *mut ort_sys::OrtTypeInfo) -> Self { + let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; + ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty)]; // infallible + let io_type = match ty { + ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { + let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + unsafe { extract_data_type_from_tensor_info(info_ptr) } + } + ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { + let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + + let mut element_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); + ortsys![unsafe GetSequenceElementType(info_ptr, &mut element_type_info)]; // infallible + + let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; + ortsys![unsafe GetOnnxTypeFromTypeInfo(element_type_info, &mut ty)]; // infallible + + match ty { + ort_sys::ONNXType::ONNX_TYPE_TENSOR => { + let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToTensorInfo(element_type_info, &mut info_ptr)]; // infallible + let ty = unsafe { extract_data_type_from_tensor_info(info_ptr) }; + ValueType::Sequence(Box::new(ty)) + } + ort_sys::ONNXType::ONNX_TYPE_MAP => { + let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToMapTypeInfo(element_type_info, &mut info_ptr)]; // infallible + let ty = unsafe { extract_data_type_from_map_info(info_ptr) }; + ValueType::Sequence(Box::new(ty)) + } + _ => unreachable!() + } + } + ort_sys::ONNXType::ONNX_TYPE_MAP => { + let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + unsafe { extract_data_type_from_map_info(info_ptr) } + } + ort_sys::ONNXType::ONNX_TYPE_OPTIONAL => { + let mut info_ptr: *const ort_sys::OrtOptionalTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToOptionalTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + + let mut contained_type: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); + ortsys![unsafe GetOptionalContainedTypeInfo(info_ptr, &mut contained_type)]; // infallible + + ValueType::Optional(Box::new(ValueType::from_type_info(contained_type))) + } + _ => unreachable!() + }; + ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; + io_type + } + /// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map. + /// + /// ``` + /// # use ort::value::Tensor; + /// # fn main() -> ort::Result<()> { + /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; + /// assert_eq!(value.dtype().tensor_dimensions(), Some(&vec![5])); + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub fn tensor_dimensions(&self) -> Option<&Vec> { + match self { + ValueType::Tensor { dimensions, .. } => Some(dimensions), + _ => None + } + } + + /// Returns the element type of this value type if it is a tensor, or `None` if it is a sequence or map. + /// + /// ``` + /// # use ort::{tensor::TensorElementType, value::Tensor}; + /// # fn main() -> ort::Result<()> { + /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; + /// assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Int64)); + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub fn tensor_type(&self) -> Option { + match self { + ValueType::Tensor { ty, .. } => Some(*ty), + _ => None + } + } + + /// Returns `true` if this value type is a tensor. + #[inline] + #[must_use] + pub fn is_tensor(&self) -> bool { + matches!(self, ValueType::Tensor { .. }) + } + + /// Returns `true` if this value type is a sequence. + #[inline] + #[must_use] + pub fn is_sequence(&self) -> bool { + matches!(self, ValueType::Sequence { .. }) + } + + /// Returns `true` if this value type is a map. + #[inline] + #[must_use] + pub fn is_map(&self) -> bool { + matches!(self, ValueType::Map { .. }) + } +} + +impl fmt::Display for ValueType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ValueType::Tensor { ty, dimensions, dimension_symbols } => { + write!( + f, + "Tensor<{ty}>({})", + dimensions + .iter() + .enumerate() + .map(|(i, c)| if *c == -1 { + dimension_symbols[i].clone().unwrap_or_else(|| String::from("dyn")) + } else { + c.to_string() + }) + .collect::>() + .join(", ") + ) + } + ValueType::Map { key, value } => write!(f, "Map<{key}, {value}>"), + ValueType::Sequence(inner) => write!(f, "Sequence<{inner}>"), + ValueType::Optional(inner) => write!(f, "Option<{inner}>") + } + } +} + +unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo) -> ValueType { + let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ortsys![GetTensorElementType(info_ptr, &mut type_sys)]; + assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + // This transmute should be safe since its value is read from GetTensorElementType, which we must trust + let mut num_dims = 0; + ortsys![GetDimensionsCount(info_ptr, &mut num_dims)]; + + let mut node_dims: Vec = vec![0; num_dims]; + ortsys![GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims)]; + + let mut symbolic_dims: Vec<*const c_char> = vec![ptr::null(); num_dims]; + ortsys![GetSymbolicDimensions(info_ptr, symbolic_dims.as_mut_ptr(), num_dims)]; + + let dimension_symbols = symbolic_dims + .into_iter() + .map(|c| if !c.is_null() { CStr::from_ptr(c).to_str().ok().map(str::to_string) } else { None }) + .collect(); + + ValueType::Tensor { + ty: type_sys.into(), + dimensions: node_dims, + dimension_symbols + } +} + +unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeInfo) -> ValueType { + let mut key_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ortsys![GetMapKeyType(info_ptr, &mut key_type_sys)]; // infallible + assert_ne!(key_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + + let mut value_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); + ortsys![GetMapValueType(info_ptr, &mut value_type_info)]; // infallible + let mut value_info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToTensorInfo(value_type_info, &mut value_info_ptr)]; // infallible + let mut value_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ortsys![GetTensorElementType(value_info_ptr, &mut value_type_sys)]; // infallible + assert_ne!(value_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + + ValueType::Map { + key: key_type_sys.into(), + value: value_type_sys.into() + } +}