feat: support all tensor element types
Now that we have `DynTensor`, we don't need the tensor element types to strictly map to Rust primitives anymore.

+ We might want to look into `num-complex` integration like we do with `half` (this could be especially useful for custom ops in audio applications).
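As a very rough sketch of what that integration could build on (nothing here is ort API; it relies only on `num_complex::Complex<T>` being `#[repr(C)]` with `re`/`im` fields, which matches the interleaved layout ONNX uses for `COMPLEX64`):

```rust
use num_complex::Complex;

/// Reinterprets an even-length `f32` buffer as ONNX COMPLEX64 values.
/// Hypothetical helper, not part of this commit.
fn complex64_view(buffer: &[f32]) -> &[Complex<f32>] {
    assert!(buffer.len() % 2 == 0, "COMPLEX64 data is stored as (re, im) pairs");
    // SAFETY: `Complex<f32>` is `#[repr(C)] { re: f32, im: f32 }`, so it has
    // the same size as two `f32`s and the same alignment as one.
    unsafe { std::slice::from_raw_parts(buffer.as_ptr().cast(), buffer.len() / 2) }
}
```

A real integration would presumably mirror the `half` feature: gate it behind a `num-complex` Cargo feature and implement the element-type traits for `Complex<f32>`/`Complex<f64>`.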
decahedron1 committed Mar 5, 2025
1 parent b4699a2 commit 51d9bf7
Showing 2 changed files with 65 additions and 34 deletions.
97 changes: 64 additions & 33 deletions src/tensor/types.rs
@@ -27,41 +27,53 @@ pub enum TensorElementType {
 	String,
 	/// Boolean, equivalent to Rust's `bool`.
 	Bool,
-	/// 16-bit floating point number, equivalent to [`half::f16`] (requires the `half` feature).
-	#[cfg(feature = "half")]
-	#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
+	/// 16-bit floating point number, equivalent to [`half::f16`] (with the `half` feature).
 	Float16,
 	/// 64-bit floating point number, equivalent to Rust's `f64`. Also known as `double`.
 	Float64,
 	/// Unsigned 32-bit integer, equivalent to Rust's `u32`.
 	Uint32,
 	/// Unsigned 64-bit integer, equivalent to Rust's `u64`.
 	Uint64,
-	/// Brain 16-bit floating point number, equivalent to [`half::bf16`] (requires the `half` feature).
-	#[cfg(feature = "half")]
-	#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
-	Bfloat16
+	/// Brain 16-bit floating point number, equivalent to [`half::bf16`] (with the `half` feature).
+	Bfloat16,
+	/// Complex number with 32-bit floating point real and imaginary components.
+	Complex64,
+	/// Complex number with 64-bit floating point real and imaginary components.
+	Complex128,
+	/// 8-bit floating point number with 4 exponent bits and 3 mantissa bits; has NaN but no infinite values.
+	Float8E4M3FN,
+	/// 8-bit floating point number with 4 exponent bits and 3 mantissa bits; has NaN but no infinite values or negative zero.
+	Float8E4M3FNUZ,
+	/// 8-bit floating point number with 5 exponent bits and 2 mantissa bits.
+	Float8E5M2,
+	/// 8-bit floating point number with 5 exponent bits and 2 mantissa bits; has NaN but no infinite values or negative zero.
+	Float8E5M2FNUZ,
+	/// 4-bit unsigned integer.
+	Uint4,
+	/// 4-bit signed integer.
+	Int4
 }
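As context for the four new `Float8*` variants, a minimal decoder for `Float8E4M3FN` shows how its bits map to a value (1 sign, 4 exponent, and 3 mantissa bits, exponent bias 7 per the OCP FP8 spec; this helper is illustrative only and not part of the commit):

```rust
/// Decodes one `Float8E4M3FN` value. Illustrative only; not part of this commit.
fn f8_e4m3fn_to_f32(bits: u8) -> f32 {
    let sign = if bits & 0x80 != 0 { -1.0f32 } else { 1.0 };
    let exp = (bits >> 3) & 0x0f;
    let man = bits & 0x07;
    if exp == 0x0f && man == 0x07 {
        // all-ones exponent with all-ones mantissa is NaN; the format has no
        // infinities, which is what the "FN" (finite + NaN) suffix refers to
        return f32::NAN;
    }
    let frac = man as f32 / 8.0;
    if exp == 0 {
        sign * frac * 2f32.powi(-6) // subnormal
    } else {
        sign * (1.0 + frac) * 2f32.powi(i32::from(exp) - 7) // normal
    }
}
```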

 impl TensorElementType {
-	pub fn size(&self) -> usize {
+	/// Returns the number of bytes occupied by `container_capacity` elements of this type.
+	pub fn byte_size(&self, container_capacity: usize) -> usize {
 		match self {
-			TensorElementType::Bool => 1,
-			#[cfg(feature = "half")]
-			TensorElementType::Bfloat16 => 2,
-			#[cfg(feature = "half")]
-			TensorElementType::Float16 => 2,
-			TensorElementType::Float32 => 4,
-			TensorElementType::Float64 => 8,
-			TensorElementType::Int16 => 2,
-			TensorElementType::Int32 => 4,
-			TensorElementType::Int64 => 8,
-			TensorElementType::Int8 => 1,
-			TensorElementType::String => 0,
-			TensorElementType::Uint16 => 2,
-			TensorElementType::Uint32 => 4,
-			TensorElementType::Uint64 => 8,
-			TensorElementType::Uint8 => 1
+			// two 4-bit elements pack into one byte; round up so odd capacities aren't undercounted
+			TensorElementType::Uint4 | TensorElementType::Int4 => (container_capacity + 1) / 2,
+			TensorElementType::Bool | TensorElementType::Int8 | TensorElementType::Uint8 => container_capacity,
+			TensorElementType::Int16 | TensorElementType::Uint16 => container_capacity * 2,
+			TensorElementType::Int32 | TensorElementType::Uint32 => container_capacity * 4,
+			TensorElementType::Int64 | TensorElementType::Uint64 => container_capacity * 8,
+			TensorElementType::String => 0, // strings are dynamically sized, so they have no meaningful fixed byte size
+			// all four 8-bit float formats occupy one byte per element
+			TensorElementType::Float8E4M3FN | TensorElementType::Float8E4M3FNUZ | TensorElementType::Float8E5M2 | TensorElementType::Float8E5M2FNUZ => container_capacity,
+			TensorElementType::Float16 | TensorElementType::Bfloat16 => container_capacity * 2,
+			TensorElementType::Float32 => container_capacity * 4,
+			TensorElementType::Float64 => container_capacity * 8,
+			TensorElementType::Complex64 => container_capacity * 8,
+			TensorElementType::Complex128 => container_capacity * 16
 		}
 	}
 }
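A few worked values under the arithmetic above (a sketch, not a test from this commit; assumes `TensorElementType` is in scope):

```rust
assert_eq!(TensorElementType::Float32.byte_size(3), 12);
assert_eq!(TensorElementType::Complex128.byte_size(2), 32);
assert_eq!(TensorElementType::Float8E5M2.byte_size(7), 7); // one byte per element
assert_eq!(TensorElementType::Int4.byte_size(5), 3);       // two per byte, rounded up
```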
@@ -70,21 +82,28 @@ impl fmt::Display for TensorElementType {
 	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 		f.write_str(match self {
 			TensorElementType::Bool => "bool",
-			#[cfg(feature = "half")]
 			TensorElementType::Bfloat16 => "bf16",
-			#[cfg(feature = "half")]
 			TensorElementType::Float16 => "f16",
 			TensorElementType::Float32 => "f32",
 			TensorElementType::Float64 => "f64",
 			TensorElementType::Int16 => "i16",
 			TensorElementType::Int32 => "i32",
 			TensorElementType::Int64 => "i64",
 			TensorElementType::Int8 => "i8",
+			TensorElementType::Int4 => "i4",
 			TensorElementType::String => "String",
 			TensorElementType::Uint16 => "u16",
 			TensorElementType::Uint32 => "u32",
 			TensorElementType::Uint64 => "u64",
-			TensorElementType::Uint8 => "u8"
+			TensorElementType::Uint8 => "u8",
+			TensorElementType::Uint4 => "u4",
+			TensorElementType::Complex64 => "c64",
+			TensorElementType::Complex128 => "c128",
+			// these really need more memorable (and easier to type) names. like Gerald or perhaps Alexa
+			TensorElementType::Float8E4M3FN => "f8_e4m3fn",
+			TensorElementType::Float8E4M3FNUZ => "f8_e4m3fnuz",
+			TensorElementType::Float8E5M2 => "f8_e5m2",
+			TensorElementType::Float8E5M2FNUZ => "f8_e5m2fnuz"
 		})
 	}
 }
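Since `Display` backs `ToString`, the new names are easy to check (illustrative):

```rust
assert_eq!(TensorElementType::Complex64.to_string(), "c64");
assert_eq!(TensorElementType::Uint4.to_string(), "u4");
assert_eq!(TensorElementType::Float8E4M3FNUZ.to_string(), "f8_e4m3fnuz");
```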
@@ -101,19 +120,26 @@ impl From<TensorElementType> for ort_sys::ONNXTensorElementDataType {
 			TensorElementType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
 			TensorElementType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
 			TensorElementType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
-			#[cfg(feature = "half")]
 			TensorElementType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
 			TensorElementType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
 			TensorElementType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
 			TensorElementType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
-			#[cfg(feature = "half")]
-			TensorElementType::Bfloat16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
+			TensorElementType::Bfloat16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16,
+			TensorElementType::Int4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4,
+			TensorElementType::Uint4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4,
+			TensorElementType::Complex64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64,
+			TensorElementType::Complex128 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128,
+			TensorElementType::Float8E4M3FN => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN,
+			TensorElementType::Float8E4M3FNUZ => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ,
+			TensorElementType::Float8E5M2 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2,
+			TensorElementType::Float8E5M2FNUZ => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ
 		}
 	}
 }
 impl From<ort_sys::ONNXTensorElementDataType> for TensorElementType {
 	fn from(val: ort_sys::ONNXTensorElementDataType) -> Self {
 		match val {
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED => panic!("Invalid ONNX tensor element data type"),
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => TensorElementType::Float32,
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => TensorElementType::Uint8,
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => TensorElementType::Int8,
@@ -123,14 +149,19 @@ impl From<ort_sys::ONNXTensorElementDataType> for TensorElementType {
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => TensorElementType::Int64,
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => TensorElementType::String,
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => TensorElementType::Bool,
-			#[cfg(feature = "half")]
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => TensorElementType::Float16,
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => TensorElementType::Float64,
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => TensorElementType::Uint32,
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => TensorElementType::Uint64,
-			#[cfg(feature = "half")]
 			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => TensorElementType::Bfloat16,
-			_ => panic!("Invalid ONNXTensorElementDataType value")
+			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 => TensorElementType::Int4,
+			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4 => TensorElementType::Uint4,
+			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 => TensorElementType::Complex64,
+			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 => TensorElementType::Complex128,
+			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN => TensorElementType::Float8E4M3FN,
+			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ => TensorElementType::Float8E4M3FNUZ,
+			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 => TensorElementType::Float8E5M2,
+			ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ => TensorElementType::Float8E5M2FNUZ
 		}
 	}
 }
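With the `_ => panic!` catch-all replaced by explicit arms, the mapping is exhaustive in both directions, so a round trip through the sys enum is lossless. A sketch, assuming the enum derives `Copy` and `PartialEq`:

```rust
let dt = TensorElementType::Float8E5M2;
let sys: ort_sys::ONNXTensorElementDataType = dt.into();
assert_eq!(TensorElementType::from(sys), dt);
```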
2 changes: 1 addition & 1 deletion src/value/impl_tensor/mod.rs
@@ -121,7 +121,7 @@ impl DynTensor {
 		let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut();
 		ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];
 
-		unsafe { buffer_ptr.write_bytes(0, calculate_tensor_size(&shape) * data_type.size()) };
+		unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(calculate_tensor_size(&shape))) };
 	}
 
 	Ok(Value {
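This call-site change is the point of the capacity-based API: a per-element `size()` cannot describe the packed 4-bit types, where two elements share one byte. Illustrative, assuming `TensorElementType` is in scope:

```rust
let element_count = 10usize; // e.g. a `[10]`-shaped tensor
assert_eq!(TensorElementType::Uint8.byte_size(element_count), 10); // one byte each
assert_eq!(TensorElementType::Int4.byte_size(element_count), 5);   // two elements per byte
// no whole number of bytes per element could produce 5 from 10 Int4 elements,
// which is why byte_size takes the total capacity instead
```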
