diff --git a/src/tensor/types.rs b/src/tensor/types.rs
index 5187e2ed..ae27eaa1 100644
--- a/src/tensor/types.rs
+++ b/src/tensor/types.rs
@@ -43,6 +43,29 @@ pub enum TensorElementType {
 	Bfloat16
 }
 
+impl TensorElementType {
+	pub fn size(&self) -> usize {
+		match self {
+			TensorElementType::Bool => 1,
+			#[cfg(feature = "half")]
+			TensorElementType::Bfloat16 => 2,
+			#[cfg(feature = "half")]
+			TensorElementType::Float16 => 2,
+			TensorElementType::Float32 => 4,
+			TensorElementType::Float64 => 8,
+			TensorElementType::Int16 => 2,
+			TensorElementType::Int32 => 4,
+			TensorElementType::Int64 => 8,
+			TensorElementType::Int8 => 1,
+			TensorElementType::String => 0,
+			TensorElementType::Uint16 => 2,
+			TensorElementType::Uint32 => 4,
+			TensorElementType::Uint64 => 8,
+			TensorElementType::Uint8 => 1
+		}
+	}
+}
+
 impl fmt::Display for TensorElementType {
 	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 		f.write_str(match self {
diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs
index cd494806..3af448e7 100644
--- a/src/value/impl_tensor/create.rs
+++ b/src/value/impl_tensor/create.rs
@@ -11,7 +11,7 @@ use core::{
 #[cfg(feature = "ndarray")]
 use ndarray::{ArcArray, Array, ArrayView, ArrayViewMut, CowArray, Dimension};
 
-use super::{Tensor, TensorRef, TensorRefMut, calculate_tensor_size};
+use super::{DynTensor, Tensor, TensorRef, TensorRefMut, calculate_tensor_size};
 use crate::{
 	AsPointer,
 	error::{Error, ErrorCode, Result, assert_non_null_pointer},
@@ -109,47 +109,8 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
 	/// # }
 	/// ```
 	pub fn new(allocator: &Allocator, shape: impl ToDimensions) -> Result<Tensor<T>> {
-		let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
-
-		let shape = shape.to_dimensions(None)?;
-		let shape_ptr: *const i64 = shape.as_ptr();
-		let shape_len = shape.len();
-
-		ortsys![
-			unsafe CreateTensorAsOrtValue(
-				allocator.ptr().cast_mut(),
-				shape_ptr,
-				shape_len,
-				T::into_tensor_element_type().into(),
-				&mut value_ptr
-			)?;
-			nonNull(value_ptr)
-		];
-
-		// `CreateTensorAsOrtValue` actually does not guarantee that the data allocated is zero'd out, so if we can, we should
-		// do it manually.
-		let memory_info = MemoryInfo::from_value(value_ptr).expect("CreateTensorAsOrtValue returned non-tensor");
-		if memory_info.is_cpu_accessible() {
-			let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
-			ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];
-
-			unsafe { buffer_ptr.write_bytes(0, calculate_tensor_size(&shape) * size_of::<T>()) };
-		}
-
-		Ok(Value {
-			inner: Arc::new(ValueInner {
-				ptr: unsafe { NonNull::new_unchecked(value_ptr) },
-				dtype: ValueType::Tensor {
-					ty: T::into_tensor_element_type(),
-					dimensions: shape,
-					dimension_symbols: vec![None; shape_len]
-				},
-				drop: true,
-				memory_info: MemoryInfo::from_value(value_ptr),
-				_backing: None
-			}),
-			_markers: PhantomData
-		})
+		let tensor = DynTensor::new(allocator, T::into_tensor_element_type(), shape)?;
+		Ok(unsafe { core::mem::transmute::<DynTensor, Tensor<T>>(tensor) })
 	}
 
 	/// Construct an owned tensor from an array of data.
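`TensorElementType::size()` reports the width of one element in bytes (`String` reports 0, since string elements have no fixed byte width), and `Tensor::<T>::new` now delegates to the type-erased constructor instead of calling `CreateTensorAsOrtValue` itself. A rough sketch of what the new helper enables; the shape and dtype are arbitrary, and only `TensorElementType` with its `size()` and `Display` impls from this patch are assumed:

```rust
use ort::tensor::TensorElementType;

fn main() {
	let dtype = TensorElementType::Float32;
	let shape = [1usize, 3, 224, 224];

	// Total byte size = element count × bytes per element (4 for f32).
	let elements: usize = shape.iter().product();
	let bytes = elements * dtype.size();

	assert_eq!(bytes, 602_112);
	println!("a {dtype} tensor of shape {shape:?} needs {bytes} bytes");
}
```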
diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs
index 68077f02..2c8e94a4 100644
--- a/src/value/impl_tensor/mod.rs
+++ b/src/value/impl_tensor/mod.rs
@@ -11,12 +11,18 @@ use core::{
 	marker::PhantomData,
 	mem,
 	ops::{Index, IndexMut},
-	ptr
+	ptr::{self, NonNull}
 };
 
-pub use self::create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts};
-use super::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType, ValueTypeMarker};
-use crate::{AsPointer, error::Result, memory::MemoryInfo, ortsys, tensor::IntoTensorElementType};
+pub use self::create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts, ToDimensions};
+use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker};
+use crate::{
+	AsPointer,
+	error::Result,
+	memory::{Allocator, MemoryInfo},
+	ortsys,
+	tensor::{IntoTensorElementType, TensorElementType}
+};
 pub trait TensorValueTypeMarker: ValueTypeMarker {
 	private_trait!();
 }
@@ -70,6 +76,70 @@ impl DowncastableTarget for DynTensorValueType {
 	private_impl!();
 }
 
+impl DynTensor {
+	/// Construct a tensor via a given allocator with a given shape and datatype. The data in the tensor will be
+	/// **uninitialized**.
+	///
+	/// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned
+	/// (CPU) memory for use with CUDA:
+	/// ```no_run
+	/// # use ort::{memory::{Allocator, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}, session::Session, tensor::TensorElementType, value::DynTensor};
+	/// # fn main() -> ort::Result<()> {
+	/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
+	/// let allocator = Allocator::new(
+	/// 	&session,
+	/// 	MemoryInfo::new(AllocationDevice::CUDA_PINNED, 0, AllocatorType::Device, MemoryType::CPUInput)?
+	/// )?;
+	///
+	/// let mut img_input = DynTensor::new(&allocator, TensorElementType::Float32, [1, 128, 128, 3])?;
+	/// # Ok(())
+	/// # }
+	/// ```
+	pub fn new(allocator: &Allocator, data_type: TensorElementType, shape: impl ToDimensions) -> Result<DynTensor> {
+		let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
+
+		let shape = shape.to_dimensions(None)?;
+		let shape_ptr: *const i64 = shape.as_ptr();
+		let shape_len = shape.len();
+
+		ortsys![
+			unsafe CreateTensorAsOrtValue(
+				allocator.ptr().cast_mut(),
+				shape_ptr,
+				shape_len,
+				data_type.into(),
+				&mut value_ptr
+			)?;
+			nonNull(value_ptr)
+		];
+
+		// `CreateTensorAsOrtValue` actually does not guarantee that the data allocated is zero'd out, so if we can, we should
+		// do it manually.
+		let memory_info = MemoryInfo::from_value(value_ptr).expect("CreateTensorAsOrtValue returned non-tensor");
+		if memory_info.is_cpu_accessible() && data_type != TensorElementType::String {
+			let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
+			ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];
+
+			unsafe { buffer_ptr.write_bytes(0, calculate_tensor_size(&shape) * data_type.size()) };
+		}
+
+		Ok(Value {
+			inner: Arc::new(ValueInner {
+				ptr: unsafe { NonNull::new_unchecked(value_ptr) },
+				dtype: ValueType::Tensor {
+					ty: data_type,
+					dimensions: shape,
+					dimension_symbols: vec![None; shape_len]
+				},
+				drop: true,
+				memory_info: MemoryInfo::from_value(value_ptr),
+				_backing: None
+			}),
+			_markers: PhantomData
+		})
+	}
+}
+
 impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
 	/// Returns a mutable pointer to the tensor's data.
 	///
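The new `DynTensor::new` is the old typed constructor with the element type moved to a runtime argument, plus a guard that skips the zero-fill for string tensors (whose element size is 0). A minimal sketch of the plain CPU case, assuming `Allocator::default()` yields the default CPU allocator; the shape and dtype here are illustrative only:

```rust
use ort::{memory::Allocator, tensor::TensorElementType, value::DynTensor};

fn main() -> ort::Result<()> {
	// Default CPU allocator; a session-backed allocator (as in the doc example
	// above) is only needed for device or pinned memory.
	let allocator = Allocator::default();

	// The element type is a runtime value, so the dtype can come straight from a
	// model's input metadata rather than a compile-time generic parameter.
	let input = DynTensor::new(&allocator, TensorElementType::Float32, [1, 3, 224, 224])?;

	// CPU-accessible buffers are zero-filled after creation; string tensors are
	// left untouched because their elements have no fixed byte width.
	let _ = input;
	Ok(())
}
```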
diff --git a/src/value/mod.rs b/src/value/mod.rs
index a31b8fbf..d88c5a1a 100644
--- a/src/value/mod.rs
+++ b/src/value/mod.rs
@@ -43,7 +43,7 @@ pub use self::{
 	},
 	impl_tensor::{
 		DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, OwnedTensorArrayData, Tensor, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts,
-		TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker
+		TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker, ToDimensions
 	},
 	r#type::ValueType
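Re-exporting `ToDimensions` from `ort::value` lets downstream code name the trait in its own signatures, so helpers can stay generic over the shape argument just like the constructors in this patch. A sketch under the same assumptions as above (`Allocator::default()` as the CPU allocator, and fixed-size integer arrays implementing `ToDimensions`, as the doc example suggests); `zeroed` is a hypothetical helper, not part of the crate:

```rust
use ort::{
	memory::Allocator,
	tensor::TensorElementType,
	value::{DynTensor, ToDimensions}
};

// Hypothetical downstream helper: generic over the shape, dtype chosen at runtime.
fn zeroed(allocator: &Allocator, dtype: TensorElementType, shape: impl ToDimensions) -> ort::Result<DynTensor> {
	DynTensor::new(allocator, dtype, shape)
}

fn main() -> ort::Result<()> {
	let allocator = Allocator::default();
	let _a = zeroed(&allocator, TensorElementType::Float32, [2, 3])?;
	let _b = zeroed(&allocator, TensorElementType::Int64, [1, 28, 28])?;
	Ok(())
}
```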