Skip to content

Commit

Permalink
feat: allow creating DynTensor with arbitrary element type
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Feb 22, 2025
1 parent 45f75ac commit cbeeb77
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 47 deletions.
23 changes: 23 additions & 0 deletions src/tensor/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,29 @@ pub enum TensorElementType {
Bfloat16
}

impl TensorElementType {
pub fn size(&self) -> usize {
match self {
TensorElementType::Bool => 1,
#[cfg(feature = "half")]
TensorElementType::Bfloat16 => 2,
#[cfg(feature = "half")]
TensorElementType::Float16 => 2,
TensorElementType::Float32 => 4,
TensorElementType::Float64 => 8,
TensorElementType::Int16 => 2,
TensorElementType::Int32 => 4,
TensorElementType::Int64 => 8,
TensorElementType::Int8 => 1,
TensorElementType::String => 0,
TensorElementType::Uint16 => 2,
TensorElementType::Uint32 => 4,
TensorElementType::Uint64 => 8,
TensorElementType::Uint8 => 1
}
}
}

impl fmt::Display for TensorElementType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Expand Down
45 changes: 3 additions & 42 deletions src/value/impl_tensor/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use core::{
#[cfg(feature = "ndarray")]
use ndarray::{ArcArray, Array, ArrayView, ArrayViewMut, CowArray, Dimension};

use super::{Tensor, TensorRef, TensorRefMut, calculate_tensor_size};
use super::{DynTensor, Tensor, TensorRef, TensorRefMut, calculate_tensor_size};
use crate::{
AsPointer,
error::{Error, ErrorCode, Result, assert_non_null_pointer},
Expand Down Expand Up @@ -109,47 +109,8 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
/// # }
/// ```
pub fn new(allocator: &Allocator, shape: impl ToDimensions) -> Result<Tensor<T>> {
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();

let shape = shape.to_dimensions(None)?;
let shape_ptr: *const i64 = shape.as_ptr();
let shape_len = shape.len();

ortsys![
unsafe CreateTensorAsOrtValue(
allocator.ptr().cast_mut(),
shape_ptr,
shape_len,
T::into_tensor_element_type().into(),
&mut value_ptr
)?;
nonNull(value_ptr)
];

// `CreateTensorAsOrtValue` actually does not guarantee that the data allocated is zero'd out, so if we can, we should
// do it manually.
let memory_info = MemoryInfo::from_value(value_ptr).expect("CreateTensorAsOrtValue returned non-tensor");
if memory_info.is_cpu_accessible() {
let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];

unsafe { buffer_ptr.write_bytes(0, calculate_tensor_size(&shape) * size_of::<T>()) };
}

Ok(Value {
inner: Arc::new(ValueInner {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
dtype: ValueType::Tensor {
ty: T::into_tensor_element_type(),
dimensions: shape,
dimension_symbols: vec![None; shape_len]
},
drop: true,
memory_info: MemoryInfo::from_value(value_ptr),
_backing: None
}),
_markers: PhantomData
})
let tensor = DynTensor::new(allocator, T::into_tensor_element_type(), shape)?;
Ok(unsafe { core::mem::transmute::<DynTensor, Tensor<T>>(tensor) })
}

/// Construct an owned tensor from an array of data.
Expand Down
78 changes: 74 additions & 4 deletions src/value/impl_tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@ use core::{
marker::PhantomData,
mem,
ops::{Index, IndexMut},
ptr
ptr::{self, NonNull}
};

pub use self::create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts};
use super::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType, ValueTypeMarker};
use crate::{AsPointer, error::Result, memory::MemoryInfo, ortsys, tensor::IntoTensorElementType};
pub use self::create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts, ToDimensions};
use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker};
use crate::{
AsPointer,
error::Result,
memory::{Allocator, MemoryInfo},
ortsys,
tensor::{IntoTensorElementType, TensorElementType}
};

pub trait TensorValueTypeMarker: ValueTypeMarker {
private_trait!();
Expand Down Expand Up @@ -70,6 +76,70 @@ impl DowncastableTarget for DynTensorValueType {
private_impl!();
}

impl DynTensor {
/// Construct a tensor via a given allocator with a given shape and datatype. The data in the tensor will be
/// **uninitialized**.
///
/// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned
/// (CPU) memory for use with CUDA:
/// ```no_run
/// # use ort::{memory::{Allocator, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}, session::Session, tensor::TensorElementType, value::DynTensor};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let allocator = Allocator::new(
/// &session,
/// MemoryInfo::new(AllocationDevice::CUDA_PINNED, 0, AllocatorType::Device, MemoryType::CPUInput)?
/// )?;
///
/// let mut img_input = DynTensor::new(&allocator, TensorElementType::Float32, [1, 128, 128, 3])?;
/// # Ok(())
/// # }
/// ```
pub fn new(allocator: &Allocator, data_type: TensorElementType, shape: impl ToDimensions) -> Result<DynTensor> {
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();

let shape = shape.to_dimensions(None)?;
let shape_ptr: *const i64 = shape.as_ptr();
let shape_len = shape.len();

ortsys![
unsafe CreateTensorAsOrtValue(
allocator.ptr().cast_mut(),
shape_ptr,
shape_len,
data_type.into(),
&mut value_ptr
)?;
nonNull(value_ptr)
];

// `CreateTensorAsOrtValue` actually does not guarantee that the data allocated is zero'd out, so if we can, we should
// do it manually.
let memory_info = MemoryInfo::from_value(value_ptr).expect("CreateTensorAsOrtValue returned non-tensor");
if memory_info.is_cpu_accessible() && data_type != TensorElementType::String {
let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];

unsafe { buffer_ptr.write_bytes(0, calculate_tensor_size(&shape) * data_type.size()) };
}

Ok(Value {
inner: Arc::new(ValueInner {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
dtype: ValueType::Tensor {
ty: data_type,
dimensions: shape,
dimension_symbols: vec![None; shape_len]
},
drop: true,
memory_info: MemoryInfo::from_value(value_ptr),
_backing: None
}),
_markers: PhantomData
})
}
}

impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// Returns a mutable pointer to the tensor's data.
///
Expand Down
2 changes: 1 addition & 1 deletion src/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub use self::{
},
impl_tensor::{
DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, OwnedTensorArrayData, Tensor, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts,
TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker
TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker, ToDimensions
},
r#type::ValueType
};
Expand Down

0 comments on commit cbeeb77

Please sign in to comment.