diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index 1afdc73d..0b8c986a 100644 --- a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -39,9 +39,9 @@ impl Kernel for CustomOpOneKernel { let (x_shape, x) = x.try_extract_raw_tensor::()?; let (y_shape, y) = y.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape)?.unwrap(); + let mut z = ctx.output(0, x_shape.to_vec())?.unwrap(); let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..y_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize { + for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize { if i % 2 == 0 { z_ref[i] = x[i]; } else { @@ -79,9 +79,9 @@ impl Kernel for CustomOpTwoKernel { fn compute(&mut self, ctx: &KernelContext) -> ort::Result<()> { let x = ctx.input(0)?.unwrap(); let (x_shape, x) = x.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape.clone())?.unwrap(); + let mut z = ctx.output(0, x_shape.to_vec())?.unwrap(); let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..x_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize { + for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize { z_ref[i] = (x[i] * i as f32) as i32; } Ok(()) diff --git a/examples/model-info/examples/model-info.rs b/examples/model-info/examples/model-info.rs index 08c068db..8bceac49 100644 --- a/examples/model-info/examples/model-info.rs +++ b/examples/model-info/examples/model-info.rs @@ -1,44 +1,6 @@ use std::{env, process}; -use ort::{session::Session, tensor::TensorElementType, value::ValueType}; - -fn display_element_type(t: TensorElementType) -> &'static str { - match t { - TensorElementType::Bfloat16 => "bf16", - TensorElementType::Bool => "bool", - TensorElementType::Float16 => "f16", - TensorElementType::Float32 => "f32", - TensorElementType::Float64 => "f64", - TensorElementType::Int16 => "i16", - TensorElementType::Int32 => "i32", - TensorElementType::Int64 => "i64", - TensorElementType::Int8 => "i8", - TensorElementType::String => "str", - TensorElementType::Uint16 => "u16", - TensorElementType::Uint32 => "u32", - TensorElementType::Uint64 => "u64", - TensorElementType::Uint8 => "u8" - } -} - -fn display_value_type(value: &ValueType) -> String { - match value { - ValueType::Tensor { ty, dimensions } => { - format!( - "Tensor<{}>({})", - display_element_type(*ty), - dimensions - .iter() - .map(|c| if *c == -1 { "dyn".to_string() } else { c.to_string() }) - .collect::>() - .join(", ") - ) - } - ValueType::Map { key, value } => format!("Map<{}, {}>", display_element_type(*key), display_element_type(*value)), - ValueType::Sequence(inner) => format!("Sequence<{}>", display_value_type(inner)), - ValueType::Optional(inner) => format!("Option<{}>", display_value_type(inner)) - } -} +use ort::session::Session; fn main() -> ort::Result<()> { let Some(path) = env::args().nth(1) else { @@ -61,11 +23,11 @@ fn main() -> ort::Result<()> { println!("Inputs:"); for (i, input) in session.inputs.iter().enumerate() { - println!(" {i} {}: {}", input.name, display_value_type(&input.input_type)); + println!(" {i} {}: {}", input.name, input.input_type); } println!("Outputs:"); for (i, output) in session.outputs.iter().enumerate() { - println!(" {i} {}: {}", output.name, display_value_type(&output.output_type)); + println!(" {i} {}: {}", output.name, output.output_type); } Ok(()) diff --git a/src/io_binding.rs b/src/io_binding.rs index 96995cec..c4a627f1 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -4,7 +4,6 @@ use std::{ collections::HashMap, ffi::CString, fmt::Debug, - marker::PhantomData, ptr::{self, NonNull}, sync::Arc }; @@ -214,7 +213,7 @@ impl IoBinding { let run_options_ptr = if let Some(run_options) = run_options { run_options.ptr() } else { std::ptr::null() }; ortsys![unsafe RunWithBinding(self.session.ptr().cast_mut(), run_options_ptr, self.ptr())?]; - let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc> = self.output_values.values().map(|c| (c.ptr().cast_mut(), &c.inner)).collect(); + let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Value> = self.output_values.values().map(|c| (c.ptr().cast_mut(), c)).collect(); let mut count = self.output_names.len(); if count > 0 { let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut(); @@ -223,11 +222,8 @@ impl IoBinding { let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count).to_vec() } .into_iter() .map(|v| unsafe { - if let Some(inner) = owned_ptrs.get(&v) { - DynValue { - inner: Arc::clone(*inner), - _markers: PhantomData - } + if let Some(value) = owned_ptrs.get(&v) { + DynValue::clone_of(value) } else { DynValue::from_ptr( NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), diff --git a/src/memory.rs b/src/memory.rs index 809a3e27..984abcb3 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -400,6 +400,19 @@ impl MemoryInfo { }) } + pub(crate) fn from_value(value_ptr: *mut ort_sys::OrtValue) -> Option { + let mut is_tensor = 0; + ortsys![unsafe IsTensor(value_ptr, &mut is_tensor)]; // infallible + if is_tensor != 0 { + let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut(); + // infallible, and `memory_info_ptr` will never be null + ortsys![unsafe GetTensorMemoryInfo(value_ptr, &mut memory_info_ptr)]; + Some(Self::from_raw(unsafe { NonNull::new_unchecked(memory_info_ptr.cast_mut()) }, false)) + } else { + None + } + } + pub(crate) fn from_raw(ptr: NonNull, should_release: bool) -> Self { MemoryInfo { ptr, should_release } } diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 8b5a665e..c2cbdd70 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ AsPointer, error::{Error, Result, status_to_result}, - memory::{Allocator, MemoryInfo}, + memory::{Allocator, MemoryInfo, MemoryType}, ortsys, session::{Input, Output}, value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType} @@ -89,6 +89,12 @@ impl KernelAttributes { ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), name.as_mut_ptr().cast::(), &mut name_len)?]; CString::from_vec_with_nul(name).map_err(Error::wrap)?.into_string().map_err(Error::wrap) } + + pub fn allocator(&self, mem_type: MemoryType) -> Result { + let mut ptr: *mut ort_sys::OrtAllocator = ptr::null_mut(); + ortsys![unsafe KernelInfoGetAllocator(self.0.as_ptr(), mem_type.into(), &mut ptr)?]; + Ok(unsafe { Allocator::from_raw_unchecked(ptr) }) + } } impl AsPointer for KernelAttributes { diff --git a/src/operator/tests.rs b/src/operator/tests.rs index 23a29258..4c10cd73 100644 --- a/src/operator/tests.rs +++ b/src/operator/tests.rs @@ -41,9 +41,9 @@ impl Kernel for CustomOpOneKernel { let (x_shape, x) = x.try_extract_raw_tensor::()?; let (y_shape, y) = y.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape)?.ok_or_else(|| crate::Error::new("missing input"))?; + let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..y_shape.into_iter().reduce(|acc, e| acc * e).unwrap_or(0) as usize { + for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { if i % 2 == 0 { z_ref[i] = x[i]; } else { @@ -81,9 +81,9 @@ impl Kernel for CustomOpTwoKernel { fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()> { let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?; let (x_shape, x) = x.try_extract_raw_tensor::()?; - let mut z = ctx.output(0, x_shape.clone())?.ok_or_else(|| crate::Error::new("missing input"))?; + let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; - for i in 0..x_shape.into_iter().reduce(|acc, e| acc * e).unwrap_or(0) as usize { + for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { z_ref[i] = (x[i] * i as f32) as i32; } Ok(()) diff --git a/src/session/output.rs b/src/session/output.rs index 41714585..6dc9784b 100644 --- a/src/session/output.rs +++ b/src/session/output.rs @@ -1,11 +1,9 @@ use std::{ ffi::c_void, iter::FusedIterator, - marker::PhantomData, mem::ManuallyDrop, ops::{Index, IndexMut}, - ptr, - sync::Arc + ptr }; use crate::{ @@ -113,10 +111,7 @@ impl<'r, 's> SessionOutputs<'r, 's> { if &key == k { *k = ""; self.effective_len -= 1; - return Some(DynValue { - inner: Arc::clone(&self.values[i].inner), - _markers: PhantomData - }); + return Some(DynValue::clone_of(&self.values[i])); } } None diff --git a/src/session/run_options.rs b/src/session/run_options.rs index 93569aab..0c176b1b 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -116,15 +116,7 @@ impl OutputSelector { .map(|o| &o.name) .filter(|n| !self.default_blocklist.contains(n)) .chain(self.allowlist.iter()) - .map(|n| { - ( - n.as_str(), - self.preallocated_outputs.get(n).map(|v| DynValue { - inner: Arc::clone(&v.inner), - _markers: PhantomData - }) - ) - }) + .map(|n| (n.as_str(), self.preallocated_outputs.get(n).map(DynValue::clone_of))) .unzip() } } diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index 82f3d506..5507430a 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -81,11 +81,11 @@ impl Value { match self.dtype() { ValueType::Map { key, value } => { let k_type = K::into_tensor_element_type(); - if k_type != key { + if k_type != *key { return Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Map<{:?}, _> (value has K type {:?})", k_type, key))); } let v_type = V::into_tensor_element_type(); - if v_type != value { + if v_type != *value { return Err(Error::new_with_code( ErrorCode::InvalidArgument, format!("Cannot extract Map<{}, {}> from Map<{}, {}>", K::into_tensor_element_type(), V::into_tensor_element_type(), k_type, v_type) @@ -100,7 +100,7 @@ impl Value { if K::into_tensor_element_type() != TensorElementType::String { let dtype = key_value.dtype(); let (key_tensor_shape, key_tensor) = match dtype { - ValueType::Tensor { ty, dimensions } => { + ValueType::Tensor { ty, dimensions, .. } => { let mem = key_value.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!( @@ -109,13 +109,13 @@ impl Value { ))); } - if ty == K::into_tensor_element_type() { + if *ty == K::into_tensor_element_type() { let mut output_array_ptr: *mut K = ptr::null_mut(); let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr; let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - let len = calculate_tensor_size(&dimensions); + let len = calculate_tensor_size(dimensions); (dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) }) } else { return Err(Error::new_with_code( @@ -251,10 +251,15 @@ impl Value { let value = unsafe { Value::from_ptr(NonNull::new_unchecked(value_ptr), None) }; let value_type = value.dtype(); - if !OtherType::can_downcast(&value.dtype()) { + if !OtherType::can_downcast(value.dtype()) { return Err(Error::new_with_code( ErrorCode::InvalidArgument, format!("Cannot extract Sequence<{}> from {value_type:?}", OtherType::format()) @@ -134,10 +134,14 @@ impl Value { @@ -76,10 +76,16 @@ impl Tensor { ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len())?]; Ok(Value { - inner: Arc::new(ValueInner::RustOwned { + inner: Arc::new(ValueInner { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - _array: Box::new(()), - _memory_info: None + dtype: ValueType::Tensor { + ty: TensorElementType::String, + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, + memory_info: MemoryInfo::from_value(value_ptr), + drop: true, + _backing: None }), _markers: PhantomData }) @@ -124,10 +130,16 @@ impl Tensor { ]; Ok(Value { - inner: Arc::new(ValueInner::RustOwned { + inner: Arc::new(ValueInner { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - _array: Box::new(()), - _memory_info: None + dtype: ValueType::Tensor { + ty: T::into_tensor_element_type(), + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, + drop: true, + memory_info: MemoryInfo::from_value(value_ptr), + _backing: None }), _markers: PhantomData }) @@ -195,10 +207,16 @@ impl Tensor { ]; Ok(Value { - inner: Arc::new(ValueInner::RustOwned { + inner: Arc::new(ValueInner { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - _array: guard, - _memory_info: Some(memory_info) + dtype: ValueType::Tensor { + ty: T::into_tensor_element_type(), + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, + drop: true, + memory_info: Some(memory_info), + _backing: Some(guard) }), _markers: PhantomData }) @@ -252,10 +270,16 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { ]; Ok(TensorRefMut::new(Value { - inner: Arc::new(ValueInner::CppOwned { + inner: Arc::new(ValueInner { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, + dtype: ValueType::Tensor { + ty: T::into_tensor_element_type(), + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, drop: true, - _session: None + memory_info: Some(info), + _backing: None }), _markers: PhantomData })) diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index 02106a39..3026f2e6 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -48,15 +48,14 @@ impl Value { pub fn try_extract_tensor(&self) -> Result> { use crate::AsPointer; - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { Ok(extract_primitive_array(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr())?) } else { Err(Error::new_with_code( @@ -93,15 +92,14 @@ impl Value { /// /// [`DynValue`]: crate::value::DynValue pub fn try_extract_scalar(&self) -> Result { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { if !dimensions.is_empty() { return Err(Error::new_with_code( ErrorCode::InvalidArgument, @@ -158,15 +156,14 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_tensor_mut(&mut self) -> Result> { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { Ok(extract_primitive_array_mut(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr_mut())?) } else { Err(Error::new_with_code( @@ -207,22 +204,21 @@ impl Value { /// - The provided type `T` does not match the tensor's element type. /// /// [`DynValue`]: crate::value::DynValue - pub fn try_extract_raw_tensor(&self) -> Result<(Vec, &[T])> { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + pub fn try_extract_raw_tensor(&self) -> Result<(&[i64], &[T])> { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { let mut output_array_ptr: *mut T = ptr::null_mut(); let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - let len = calculate_tensor_size(&dimensions); + let len = calculate_tensor_size(dimensions); Ok((dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) })) } else { Err(Error::new_with_code( @@ -260,22 +256,22 @@ impl Value { /// - The provided type `T` does not match the tensor's element type. /// /// [`DynValue`]: crate::value::DynValue - pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { + pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(&[i64], &mut [T])> { let dtype = self.dtype(); match dtype { - ValueType::Tensor { ty, dimensions } => { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == T::into_tensor_element_type() { + if *ty == T::into_tensor_element_type() { let mut output_array_ptr: *mut T = ptr::null_mut(); let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; + ortsys![unsafe GetTensorMutableData(self.ptr().cast_mut(), output_array_ptr_ptr_void)?; nonNull(output_array_ptr)]; - let len = calculate_tensor_size(&dimensions); + let len = calculate_tensor_size(dimensions); Ok((dimensions, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len) })) } else { Err(Error::new_with_code( @@ -304,16 +300,15 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_string_tensor(&self) -> Result> { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == TensorElementType::String { - let len = calculate_tensor_size(&dimensions); + if *ty == TensorElementType::String { + let len = calculate_tensor_size(dimensions); // Total length of string data, not including \0 suffix let mut total_length = 0; @@ -370,17 +365,16 @@ impl Value { /// # Ok(()) /// # } /// ``` - pub fn try_extract_raw_string_tensor(&self) -> Result<(Vec, Vec)> { - let dtype = self.dtype(); - match dtype { - ValueType::Tensor { ty, dimensions } => { + pub fn try_extract_raw_string_tensor(&self) -> Result<(&[i64], Vec)> { + match self.dtype() { + ValueType::Tensor { ty, dimensions, .. } => { let mem = self.memory_info(); if !mem.is_cpu_accessible() { return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", mem.allocation_device().as_str()))); } - if ty == TensorElementType::String { - let len = calculate_tensor_size(&dimensions); + if *ty == TensorElementType::String { + let len = calculate_tensor_size(dimensions); // Total length of string data, not including \0 suffix let mut total_length = 0; @@ -512,7 +506,7 @@ impl Tensor { /// # Ok(()) /// # } /// ``` - pub fn extract_raw_tensor(&self) -> (Vec, &[T]) { + pub fn extract_raw_tensor(&self) -> (&[i64], &[T]) { self.try_extract_raw_tensor().expect("Failed to extract tensor") } @@ -531,7 +525,7 @@ impl Tensor { /// # Ok(()) /// # } /// ``` - pub fn extract_raw_tensor_mut(&mut self) -> (Vec, &mut [T]) { + pub fn extract_raw_tensor_mut(&mut self) -> (&[i64], &mut [T]) { self.try_extract_raw_tensor_mut().expect("Failed to extract tensor") } } diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index fb2be71f..34375b0d 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -5,7 +5,6 @@ use std::{ fmt::Debug, marker::PhantomData, ops::{Index, IndexMut}, - ptr::NonNull, sync::Arc }; @@ -137,11 +136,8 @@ impl Value { /// # Ok(()) /// # } /// ``` - pub fn memory_info(&self) -> MemoryInfo { - let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut(); - // infallible, and `memory_info_ptr` will never be null - ortsys![unsafe GetTensorMemoryInfo(self.ptr(), &mut memory_info_ptr)]; - MemoryInfo::from_raw(unsafe { NonNull::new_unchecked(memory_info_ptr.cast_mut()) }, false) + pub fn memory_info(&self) -> &MemoryInfo { + unsafe { self.inner.memory_info.as_ref().unwrap_unchecked() } } } @@ -282,11 +278,11 @@ mod tests { fn test_tensor_value() -> crate::Result<()> { let v: Vec = vec![1., 2., 3., 4., 5.]; let value = Tensor::from_array(Array1::from_vec(v.clone()))?; - assert!(value.is_tensor()?); assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Float32)); - assert_eq!(value.dtype(), ValueType::Tensor { + assert_eq!(value.dtype(), &ValueType::Tensor { ty: TensorElementType::Float32, - dimensions: vec![v.len() as i64] + dimensions: vec![v.len() as i64], + dimension_symbols: vec![None] }); let (shape, data) = value.extract_raw_tensor(); diff --git a/src/value/mod.rs b/src/value/mod.rs index 6be11f9a..e10ab591 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -18,8 +18,9 @@ use std::{ any::Any, - fmt::{self, Debug}, + fmt::Debug, marker::PhantomData, + mem::transmute, ops::{Deref, DerefMut}, ptr::NonNull, sync::Arc @@ -28,240 +29,46 @@ use std::{ mod impl_map; mod impl_sequence; mod impl_tensor; +mod r#type; pub use self::{ impl_map::{DynMap, DynMapRef, DynMapRefMut, DynMapValueType, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker}, impl_sequence::{ DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker }, - impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker} + impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker}, + r#type::ValueType }; use crate::{ AsPointer, error::{Error, ErrorCode, Result}, memory::MemoryInfo, ortsys, - session::SharedSessionInner, - tensor::TensorElementType + session::SharedSessionInner }; -/// The type of a [`Value`], or a session input/output. -/// -/// ``` -/// # use std::sync::Arc; -/// # use ort::{session::Session, value::{ValueType, Tensor}, tensor::TensorElementType}; -/// # fn main() -> ort::Result<()> { -/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; -/// // `ValueType`s can be obtained from session inputs/outputs: -/// let input = &session.inputs[0]; -/// assert_eq!(input.input_type, ValueType::Tensor { -/// ty: TensorElementType::Float32, -/// // Our model has 3 dynamic dimensions, represented by -1 -/// dimensions: vec![-1, -1, -1, 3] -/// }); -/// -/// // Or by `Value`s created in Rust or output by a session. -/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; -/// assert_eq!(value.dtype(), ValueType::Tensor { -/// ty: TensorElementType::Int64, -/// dimensions: vec![5] -/// }); -/// # Ok(()) -/// # } -/// ``` -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum ValueType { - /// Value is a tensor/multi-dimensional array. - Tensor { - /// Element type of the tensor. - ty: TensorElementType, - /// Dimensions of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an - /// [`Input`]/[`Output`]), the dimension will be `-1`. - /// - /// Actual tensor values, which have a known dimension, will always have positive (>1) dimensions. - /// - /// [`Input`]: crate::session::Input - /// [`Output`]: crate::session::Output - dimensions: Vec - }, - /// A sequence (vector) of other `Value`s. - /// - /// [Per ONNX spec](https://onnx.ai/onnx/intro/concepts.html#other-types), only sequences of tensors and maps are allowed. - Sequence(Box), - /// A map/dictionary from one element type to another. - Map { - /// The map key type. Allowed types are: - /// - [`TensorElementType::Int8`] - /// - [`TensorElementType::Int16`] - /// - [`TensorElementType::Int32`] - /// - [`TensorElementType::Int64`] - /// - [`TensorElementType::Uint8`] - /// - [`TensorElementType::Uint16`] - /// - [`TensorElementType::Uint32`] - /// - [`TensorElementType::Uint64`] - /// - [`TensorElementType::String`] - key: TensorElementType, - /// The map value type. - value: TensorElementType - }, - /// An optional value, which may or may not contain a [`Value`]. - Optional(Box) -} - -impl ValueType { - pub(crate) fn from_type_info(typeinfo_ptr: *mut ort_sys::OrtTypeInfo) -> Self { - let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty)]; // infallible - let io_type = match ty { - ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { - let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr)]; // infallible - unsafe { extract_data_type_from_tensor_info(info_ptr) } - } - ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { - let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible - unsafe { extract_data_type_from_sequence_info(info_ptr) } - } - ort_sys::ONNXType::ONNX_TYPE_MAP => { - let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible - unsafe { extract_data_type_from_map_info(info_ptr) } - } - ort_sys::ONNXType::ONNX_TYPE_OPTIONAL => { - let mut info_ptr: *const ort_sys::OrtOptionalTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToOptionalTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible - - let mut contained_type: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); - ortsys![unsafe GetOptionalContainedTypeInfo(info_ptr, &mut contained_type)]; // infallible - - ValueType::Optional(Box::new(ValueType::from_type_info(contained_type))) - } - _ => unreachable!() - }; - ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; - io_type - } - /// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map. - /// - /// ``` - /// # use ort::value::Tensor; - /// # fn main() -> ort::Result<()> { - /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; - /// assert_eq!(value.dtype().tensor_dimensions(), Some(&vec![5])); - /// # Ok(()) - /// # } - /// ``` - #[must_use] - pub fn tensor_dimensions(&self) -> Option<&Vec> { - match self { - ValueType::Tensor { dimensions, .. } => Some(dimensions), - _ => None - } - } - - /// Returns the element type of this value type if it is a tensor, or `None` if it is a sequence or map. - /// - /// ``` - /// # use ort::{tensor::TensorElementType, value::Tensor}; - /// # fn main() -> ort::Result<()> { - /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; - /// assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Int64)); - /// # Ok(()) - /// # } - /// ``` - #[must_use] - pub fn tensor_type(&self) -> Option { - match self { - ValueType::Tensor { ty, .. } => Some(*ty), - _ => None - } - } - - /// Returns `true` if this value type is a tensor. - #[inline] - #[must_use] - pub fn is_tensor(&self) -> bool { - matches!(self, ValueType::Tensor { .. }) - } - - /// Returns `true` if this value type is a sequence. - #[inline] - #[must_use] - pub fn is_sequence(&self) -> bool { - matches!(self, ValueType::Sequence { .. }) - } - - /// Returns `true` if this value type is a map. - #[inline] - #[must_use] - pub fn is_map(&self) -> bool { - matches!(self, ValueType::Map { .. }) - } -} - -impl fmt::Display for ValueType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ValueType::Tensor { ty, dimensions } => { - write!( - f, - "Tensor<{ty}>({})", - dimensions - .iter() - .map(|c| if *c == -1 { "dyn".to_string() } else { c.to_string() }) - .collect::>() - .join(", ") - ) - } - ValueType::Map { key, value } => write!(f, "Map<{key}, {value}>"), - ValueType::Sequence(inner) => write!(f, "Sequence<{inner}>"), - ValueType::Optional(inner) => write!(f, "Option<{inner}>") - } - } -} - #[derive(Debug)] -pub(crate) enum ValueInner { - RustOwned { - ptr: NonNull, - _array: Box, - /// Hold onto the `MemoryInfo` that we create in `Value::from_array`. - _memory_info: Option - }, - CppOwned { - ptr: NonNull, - /// Whether to release the value pointer on drop. - drop: bool, - /// Hold [`SharedSessionInner`] to ensure that the value can stay alive after the main session is dropped. - /// - /// This may be `None` if the value is created outside of a session or if the value does not need to hold onto - /// the session reference. In the case of sequence/map values, we forego this because: - /// - a map value can be created independently of a session, and thus we wouldn't have anything to hold on to; - /// - this is only ever used by `ValueRef`s, whos owner value (which *is* holding the session Arc) will outlive - /// it. - _session: Option> - } +pub(crate) struct ValueInner { + pub(crate) ptr: NonNull, + pub(crate) dtype: ValueType, + pub(crate) memory_info: Option, + pub(crate) drop: bool, + pub(crate) _backing: Option> } impl AsPointer for ValueInner { type Sys = ort_sys::OrtValue; fn ptr(&self) -> *const Self::Sys { - match self { - ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr() - } + self.ptr.as_ptr() } } impl Drop for ValueInner { fn drop(&mut self) { let ptr = self.ptr_mut(); - tracing::trace!("dropping {} value at {ptr:p}", match self { - ValueInner::RustOwned { .. } => "rust-owned", - ValueInner::CppOwned { .. } => "cpp-owned" - }); - if !matches!(self, ValueInner::CppOwned { drop: false, .. }) { + tracing::trace!("dropping value at {ptr:p}"); + if self.drop { ortsys![unsafe ReleaseValue(ptr)]; } } @@ -284,7 +91,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> { #[inline] pub fn downcast(self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { + if OtherType::can_downcast(dt) { Ok(unsafe { std::mem::transmute::, ValueRef<'v, OtherType>>(self) }) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", OtherType::format()))) @@ -295,10 +102,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> { pub fn try_upgrade(self) -> Result, Self> { // We cannot upgade a value which we cannot drop, i.e. `ValueRef`s used in operator kernels. Those only last for the // duration of the kernel, allowing an upgrade would allow a UAF. - if match &*self.inner.inner { - ValueInner::CppOwned { drop, .. } => !drop, - _ => false - } { + if !self.inner.inner.drop { return Err(self); } @@ -335,7 +139,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> { #[inline] pub fn downcast(self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { + if OtherType::can_downcast(dt) { Ok(unsafe { std::mem::transmute::, ValueRefMut<'v, OtherType>>(self) }) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", OtherType::format()))) @@ -346,10 +150,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> { pub fn try_upgrade(self) -> Result, Self> { // We cannot upgade a value which we cannot drop, i.e. `ValueRef`s used in operator kernels. Those only last for the // duration of the kernel, allowing an upgrade would allow a UAF. - if match &*self.inner.inner { - ValueInner::CppOwned { drop, .. } => !drop, - _ => false - } { + if !self.inner.inner.drop { return Err(self); } @@ -478,14 +279,8 @@ unsafe impl Sync for Value {} impl Value { /// Returns the data type of this [`Value`]. - pub fn dtype(&self) -> ValueType { - let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTypeInfo(self.ptr(), &mut typeinfo_ptr)]; // infallible - // `typeinfo_ptr` may be null in exceptionally rare cases - if typeinfo_ptr.is_null() { - panic!("unexpected UNKNOWN value type info"); - } - ValueType::from_type_info(typeinfo_ptr) + pub fn dtype(&self) -> &ValueType { + &self.inner.dtype } /// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer. @@ -501,8 +296,16 @@ impl Value { /// - `session` must be `Some` for values returned from a session. #[must_use] pub unsafe fn from_ptr(ptr: NonNull, session: Option>) -> Value { + let mut typeinfo_ptr = std::ptr::null_mut(); + ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)]; Value { - inner: Arc::new(ValueInner::CppOwned { ptr, drop: true, _session: session }), + inner: Arc::new(ValueInner { + ptr, + memory_info: MemoryInfo::from_value(ptr.as_ptr()), + dtype: ValueType::from_type_info(typeinfo_ptr), + drop: true, + _backing: session.map(|v| Box::new(v) as Box) + }), _markers: PhantomData } } @@ -511,58 +314,67 @@ impl Value { /// contexts. #[must_use] pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull, session: Option>) -> Value { + let mut typeinfo_ptr = std::ptr::null_mut(); + ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)]; Value { - inner: Arc::new(ValueInner::CppOwned { ptr, drop: false, _session: session }), + inner: Arc::new(ValueInner { + ptr, + memory_info: MemoryInfo::from_value(ptr.as_ptr()), + dtype: ValueType::from_type_info(typeinfo_ptr), + drop: false, + _backing: session.map(|v| Box::new(v) as Box) + }), _markers: PhantomData } } /// Create a view of this value's data. pub fn view(&self) -> ValueRef<'_, Type> { - ValueRef::new(Value { - inner: Arc::clone(&self.inner), - _markers: PhantomData - }) + ValueRef::new(Value::clone_of(self)) } /// Create a mutable view of this value's data. pub fn view_mut(&mut self) -> ValueRefMut<'_, Type> { - ValueRefMut::new(Value { - inner: Arc::clone(&self.inner), + ValueRefMut::new(Value::clone_of(self)) + } + + /// Converts this value into a type-erased [`DynValue`]. + pub fn into_dyn(self) -> DynValue { + unsafe { std::mem::transmute(self) } + } + + pub(crate) fn clone_of(value: &Self) -> Self { + Self { + inner: Arc::clone(&value.inner), _markers: PhantomData - }) + } } +} +impl Value { /// Returns `true` if this value is a tensor, or `false` if it is another type (sequence, map). /// /// ``` /// # use ort::value::Tensor; /// # fn main() -> ort::Result<()> { - /// // Create a tensor from a raw data vector /// let tensor_value = Tensor::from_array(([3usize], vec![1.0_f32, 2.0, 3.0].into_boxed_slice()))?; - /// assert!(tensor_value.is_tensor()?); + /// let dyn_value = tensor_value.into_dyn(); + /// assert!(dyn_value.is_tensor()); /// # Ok(()) /// # } /// ``` - pub fn is_tensor(&self) -> Result { + pub fn is_tensor(&self) -> bool { let mut result = 0; - ortsys![unsafe IsTensor(self.ptr(), &mut result)?]; - Ok(result == 1) + ortsys![unsafe IsTensor(self.ptr(), &mut result)]; // infallible + result == 1 } - /// Converts this value into a type-erased [`DynValue`]. - pub fn into_dyn(self) -> DynValue { - unsafe { std::mem::transmute(self) } - } -} - -impl Value { /// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed variant, /// like [`Tensor`]. #[inline] pub fn downcast(self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { + if OtherType::can_downcast(dt) { Ok(unsafe { std::mem::transmute::, Value>(self) }) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast {dt} to {}", OtherType::format()))) @@ -574,11 +386,8 @@ impl Value { #[inline] pub fn downcast_ref(&self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { - Ok(ValueRef::new(Value { - inner: Arc::clone(&self.inner), - _markers: PhantomData - })) + if OtherType::can_downcast(dt) { + Ok(ValueRef::new(unsafe { transmute::>(Value::clone_of(self)) })) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", OtherType::format()))) } @@ -589,11 +398,8 @@ impl Value { #[inline] pub fn downcast_mut(&mut self) -> Result> { let dt = self.dtype(); - if OtherType::can_downcast(&dt) { - Ok(ValueRefMut::new(Value { - inner: Arc::clone(&self.inner), - _markers: PhantomData - })) + if OtherType::can_downcast(dt) { + Ok(ValueRefMut::new(unsafe { transmute::>(Value::clone_of(self)) })) } else { Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", OtherType::format()))) } @@ -608,66 +414,6 @@ impl AsPointer for Value { } } -pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo) -> ValueType { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![GetTensorElementType(info_ptr, &mut type_sys)]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - // This transmute should be safe since its value is read from GetTensorElementType, which we must trust - let mut num_dims = 0; - ortsys![GetDimensionsCount(info_ptr, &mut num_dims)]; - - let mut node_dims: Vec = vec![0; num_dims]; - ortsys![GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims)]; - - ValueType::Tensor { - ty: type_sys.into(), - dimensions: node_dims - } -} - -pub(crate) unsafe fn extract_data_type_from_sequence_info(info_ptr: *const ort_sys::OrtSequenceTypeInfo) -> ValueType { - let mut element_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); - ortsys![GetSequenceElementType(info_ptr, &mut element_type_info)]; // infallible - - let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - ortsys![GetOnnxTypeFromTypeInfo(element_type_info, &mut ty)]; // infallible - - match ty { - ort_sys::ONNXType::ONNX_TYPE_TENSOR => { - let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![CastTypeInfoToTensorInfo(element_type_info, &mut info_ptr)]; // infallible - let ty = extract_data_type_from_tensor_info(info_ptr); - ValueType::Sequence(Box::new(ty)) - } - ort_sys::ONNXType::ONNX_TYPE_MAP => { - let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); - ortsys![CastTypeInfoToMapTypeInfo(element_type_info, &mut info_ptr)]; // infallible - let ty = extract_data_type_from_map_info(info_ptr); - ValueType::Sequence(Box::new(ty)) - } - _ => unreachable!() - } -} - -pub(crate) unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeInfo) -> ValueType { - let mut key_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![GetMapKeyType(info_ptr, &mut key_type_sys)]; // infallible - assert_ne!(key_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - - let mut value_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); - ortsys![GetMapValueType(info_ptr, &mut value_type_info)]; // infallible - let mut value_info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(value_type_info, &mut value_info_ptr)]; // infallible - let mut value_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![GetTensorElementType(value_info_ptr, &mut value_type_sys)]; // infallible - assert_ne!(value_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - - ValueType::Map { - key: key_type_sys.into(), - value: value_type_sys.into() - } -} - #[cfg(test)] mod tests { use super::{DynTensorValueType, Map, Sequence, Tensor, TensorRef, TensorRefMut, TensorValueType}; diff --git a/src/value/type.rs b/src/value/type.rs new file mode 100644 index 00000000..ef79b16c --- /dev/null +++ b/src/value/type.rs @@ -0,0 +1,254 @@ +use std::{ + ffi::{CStr, c_char}, + fmt, ptr +}; + +use crate::{ortsys, tensor::TensorElementType}; + +/// The type of a [`Value`], or a session input/output. +/// +/// ``` +/// # use std::sync::Arc; +/// # use ort::{session::Session, value::{ValueType, Tensor}, tensor::TensorElementType}; +/// # fn main() -> ort::Result<()> { +/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// // `ValueType`s can be obtained from session inputs/outputs: +/// let input = &session.inputs[0]; +/// assert_eq!(input.input_type, ValueType::Tensor { +/// ty: TensorElementType::Float32, +/// // Our model has 3 dynamic dimensions, represented by -1 +/// dimensions: vec![-1, -1, -1, 3] +/// }); +/// +/// // Or by `Value`s created in Rust or output by a session. +/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; +/// assert_eq!(value.dtype(), ValueType::Tensor { +/// ty: TensorElementType::Int64, +/// dimensions: vec![5] +/// }); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum ValueType { + /// Value is a tensor/multi-dimensional array. + Tensor { + /// Element type of the tensor. + ty: TensorElementType, + /// Dimensions of the tensor. If an exact dimension is not known (i.e. a dynamic dimension as part of an + /// [`Input`]/[`Output`]), the dimension will be `-1`. + /// + /// Actual tensor values, which have a known dimension, will always have positive (>1) dimensions. + /// + /// [`Input`]: crate::session::Input + /// [`Output`]: crate::session::Output + dimensions: Vec, + dimension_symbols: Vec> + }, + /// A sequence (vector) of other `Value`s. + /// + /// [Per ONNX spec](https://onnx.ai/onnx/intro/concepts.html#other-types), only sequences of tensors and maps are allowed. + Sequence(Box), + /// A map/dictionary from one element type to another. + Map { + /// The map key type. Allowed types are: + /// - [`TensorElementType::Int8`] + /// - [`TensorElementType::Int16`] + /// - [`TensorElementType::Int32`] + /// - [`TensorElementType::Int64`] + /// - [`TensorElementType::Uint8`] + /// - [`TensorElementType::Uint16`] + /// - [`TensorElementType::Uint32`] + /// - [`TensorElementType::Uint64`] + /// - [`TensorElementType::String`] + key: TensorElementType, + /// The map value type. + value: TensorElementType + }, + /// An optional value, which may or may not contain a [`Value`]. + Optional(Box) +} + +impl ValueType { + pub(crate) fn from_type_info(typeinfo_ptr: *mut ort_sys::OrtTypeInfo) -> Self { + let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; + ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty)]; // infallible + let io_type = match ty { + ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { + let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + unsafe { extract_data_type_from_tensor_info(info_ptr) } + } + ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { + let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + + let mut element_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); + ortsys![unsafe GetSequenceElementType(info_ptr, &mut element_type_info)]; // infallible + + let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; + ortsys![unsafe GetOnnxTypeFromTypeInfo(element_type_info, &mut ty)]; // infallible + + match ty { + ort_sys::ONNXType::ONNX_TYPE_TENSOR => { + let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToTensorInfo(element_type_info, &mut info_ptr)]; // infallible + let ty = unsafe { extract_data_type_from_tensor_info(info_ptr) }; + ValueType::Sequence(Box::new(ty)) + } + ort_sys::ONNXType::ONNX_TYPE_MAP => { + let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToMapTypeInfo(element_type_info, &mut info_ptr)]; // infallible + let ty = unsafe { extract_data_type_from_map_info(info_ptr) }; + ValueType::Sequence(Box::new(ty)) + } + _ => unreachable!() + } + } + ort_sys::ONNXType::ONNX_TYPE_MAP => { + let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + unsafe { extract_data_type_from_map_info(info_ptr) } + } + ort_sys::ONNXType::ONNX_TYPE_OPTIONAL => { + let mut info_ptr: *const ort_sys::OrtOptionalTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToOptionalTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + + let mut contained_type: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); + ortsys![unsafe GetOptionalContainedTypeInfo(info_ptr, &mut contained_type)]; // infallible + + ValueType::Optional(Box::new(ValueType::from_type_info(contained_type))) + } + _ => unreachable!() + }; + ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; + io_type + } + /// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map. + /// + /// ``` + /// # use ort::value::Tensor; + /// # fn main() -> ort::Result<()> { + /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; + /// assert_eq!(value.dtype().tensor_dimensions(), Some(&vec![5])); + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub fn tensor_dimensions(&self) -> Option<&Vec> { + match self { + ValueType::Tensor { dimensions, .. } => Some(dimensions), + _ => None + } + } + + /// Returns the element type of this value type if it is a tensor, or `None` if it is a sequence or map. + /// + /// ``` + /// # use ort::{tensor::TensorElementType, value::Tensor}; + /// # fn main() -> ort::Result<()> { + /// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; + /// assert_eq!(value.dtype().tensor_type(), Some(TensorElementType::Int64)); + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub fn tensor_type(&self) -> Option { + match self { + ValueType::Tensor { ty, .. } => Some(*ty), + _ => None + } + } + + /// Returns `true` if this value type is a tensor. + #[inline] + #[must_use] + pub fn is_tensor(&self) -> bool { + matches!(self, ValueType::Tensor { .. }) + } + + /// Returns `true` if this value type is a sequence. + #[inline] + #[must_use] + pub fn is_sequence(&self) -> bool { + matches!(self, ValueType::Sequence { .. }) + } + + /// Returns `true` if this value type is a map. + #[inline] + #[must_use] + pub fn is_map(&self) -> bool { + matches!(self, ValueType::Map { .. }) + } +} + +impl fmt::Display for ValueType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ValueType::Tensor { ty, dimensions, dimension_symbols } => { + write!( + f, + "Tensor<{ty}>({})", + dimensions + .iter() + .enumerate() + .map(|(i, c)| if *c == -1 { + dimension_symbols[i].clone().unwrap_or_else(|| String::from("dyn")) + } else { + c.to_string() + }) + .collect::>() + .join(", ") + ) + } + ValueType::Map { key, value } => write!(f, "Map<{key}, {value}>"), + ValueType::Sequence(inner) => write!(f, "Sequence<{inner}>"), + ValueType::Optional(inner) => write!(f, "Option<{inner}>") + } + } +} + +unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo) -> ValueType { + let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ortsys![GetTensorElementType(info_ptr, &mut type_sys)]; + assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + // This transmute should be safe since its value is read from GetTensorElementType, which we must trust + let mut num_dims = 0; + ortsys![GetDimensionsCount(info_ptr, &mut num_dims)]; + + let mut node_dims: Vec = vec![0; num_dims]; + ortsys![GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims)]; + + let mut symbolic_dims: Vec<*const c_char> = vec![ptr::null(); num_dims]; + ortsys![GetSymbolicDimensions(info_ptr, symbolic_dims.as_mut_ptr(), num_dims)]; + + let dimension_symbols = symbolic_dims + .into_iter() + .map(|c| if !c.is_null() { CStr::from_ptr(c).to_str().ok().map(str::to_string) } else { None }) + .collect(); + + ValueType::Tensor { + ty: type_sys.into(), + dimensions: node_dims, + dimension_symbols + } +} + +unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeInfo) -> ValueType { + let mut key_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ortsys![GetMapKeyType(info_ptr, &mut key_type_sys)]; // infallible + assert_ne!(key_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + + let mut value_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); + ortsys![GetMapValueType(info_ptr, &mut value_type_info)]; // infallible + let mut value_info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToTensorInfo(value_type_info, &mut value_info_ptr)]; // infallible + let mut value_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ortsys![GetTensorElementType(value_info_ptr, &mut value_type_sys)]; // infallible + assert_ne!(value_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + + ValueType::Map { + key: key_type_sys.into(), + value: value_type_sys.into() + } +}