diff --git a/src/tensor/types.rs b/src/tensor/types.rs index ae27eaa1..73b8e6bc 100644 --- a/src/tensor/types.rs +++ b/src/tensor/types.rs @@ -27,9 +27,7 @@ pub enum TensorElementType { String, /// Boolean, equivalent to Rust's `bool`. Bool, - /// 16-bit floating point number, equivalent to [`half::f16`] (requires the `half` feature). - #[cfg(feature = "half")] - #[cfg_attr(docsrs, doc(cfg(feature = "half")))] + /// 16-bit floating point number, equivalent to [`half::f16`] (with the `half` feature). Float16, /// 64-bit floating point number, equivalent to Rust's `f64`. Also known as `double`. Float64, @@ -37,31 +35,45 @@ pub enum TensorElementType { Uint32, /// Unsigned 64-bit integer, equivalent to Rust's `u64`. Uint64, - /// Brain 16-bit floating point number, equivalent to [`half::bf16`] (requires the `half` feature). - #[cfg(feature = "half")] - #[cfg_attr(docsrs, doc(cfg(feature = "half")))] - Bfloat16 + /// Brain 16-bit floating point number, equivalent to [`half::bf16`] (with the `half` feature). + Bfloat16, + Complex64, + Complex128, + /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values and no infinite + /// values. + Float8E4M3FN, + /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values, no infinite + /// values, and no negative zero. + Float8E4M3FNUZ, + /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits. + Float8E5M2, + /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits, with only NaN values, no infinite + /// values, and no negative zero. + Float8E5M2FNUZ, + /// 4-bit unsigned integer. + Uint4, + /// 4-bit signed integer. + Int4 } impl TensorElementType { - pub fn size(&self) -> usize { + /// Returns the size in bytes that a container of this type occupies according to its total capacity. + pub fn byte_size(&self, container_capacity: usize) -> usize { match self { - TensorElementType::Bool => 1, - #[cfg(feature = "half")] - TensorElementType::Bfloat16 => 2, - #[cfg(feature = "half")] - TensorElementType::Float16 => 2, - TensorElementType::Float32 => 4, - TensorElementType::Float64 => 8, - TensorElementType::Int16 => 2, - TensorElementType::Int32 => 4, - TensorElementType::Int64 => 8, - TensorElementType::Int8 => 1, - TensorElementType::String => 0, - TensorElementType::Uint16 => 2, - TensorElementType::Uint32 => 4, - TensorElementType::Uint64 => 8, - TensorElementType::Uint8 => 1 + TensorElementType::Uint4 | TensorElementType::Int4 => container_capacity / 2, + TensorElementType::Bool | TensorElementType::Int8 | TensorElementType::Uint8 => container_capacity, + TensorElementType::Int16 | TensorElementType::Uint16 => container_capacity * 2, + TensorElementType::Int32 | TensorElementType::Uint32 => container_capacity * 4, + TensorElementType::Int64 | TensorElementType::Uint64 => container_capacity * 8, + TensorElementType::String => 0, // unsure what to do about this... + TensorElementType::Float8E4M3FN | TensorElementType::Float8E4M3FNUZ | TensorElementType::Float8E5M2 | TensorElementType::Float8E5M2FNUZ => { + container_capacity * 4 + } + TensorElementType::Float16 | TensorElementType::Bfloat16 => container_capacity * 2, + TensorElementType::Float32 => container_capacity * 4, + TensorElementType::Float64 => container_capacity * 8, + TensorElementType::Complex64 => container_capacity * 8, + TensorElementType::Complex128 => container_capacity * 16 } } } @@ -70,9 +82,7 @@ impl fmt::Display for TensorElementType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { TensorElementType::Bool => "bool", - #[cfg(feature = "half")] TensorElementType::Bfloat16 => "bf16", - #[cfg(feature = "half")] TensorElementType::Float16 => "f16", TensorElementType::Float32 => "f32", TensorElementType::Float64 => "f64", @@ -80,11 +90,20 @@ impl fmt::Display for TensorElementType { TensorElementType::Int32 => "i32", TensorElementType::Int64 => "i64", TensorElementType::Int8 => "i8", + TensorElementType::Int4 => "i4", TensorElementType::String => "String", TensorElementType::Uint16 => "u16", TensorElementType::Uint32 => "u32", TensorElementType::Uint64 => "u64", - TensorElementType::Uint8 => "u8" + TensorElementType::Uint8 => "u8", + TensorElementType::Uint4 => "u4", + TensorElementType::Complex64 => "c64", + TensorElementType::Complex128 => "c128", + // these really need more memorable (and easier to type) names. like Gerald or perhaps Alexa + TensorElementType::Float8E4M3FN => "f8_e4m3fn", + TensorElementType::Float8E4M3FNUZ => "f8_e4m3fnuz", + TensorElementType::Float8E5M2 => "f8_e5m2", + TensorElementType::Float8E5M2FNUZ => "f8_e5m2fnuz" }) } } @@ -101,19 +120,26 @@ impl From for ort_sys::ONNXTensorElementDataType { TensorElementType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, TensorElementType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, TensorElementType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, - #[cfg(feature = "half")] TensorElementType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, TensorElementType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, TensorElementType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, TensorElementType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, - #[cfg(feature = "half")] - TensorElementType::Bfloat16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 + TensorElementType::Bfloat16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, + TensorElementType::Int4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, + TensorElementType::Uint4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, + TensorElementType::Complex64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, + TensorElementType::Complex128 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, + TensorElementType::Float8E4M3FN => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, + TensorElementType::Float8E4M3FNUZ => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, + TensorElementType::Float8E5M2 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, + TensorElementType::Float8E5M2FNUZ => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ } } } impl From for TensorElementType { fn from(val: ort_sys::ONNXTensorElementDataType) -> Self { match val { + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED => panic!("Invalid ONNX tensor element data type"), ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => TensorElementType::Float32, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => TensorElementType::Uint8, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => TensorElementType::Int8, @@ -123,14 +149,19 @@ impl From for TensorElementType { ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => TensorElementType::Int64, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => TensorElementType::String, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => TensorElementType::Bool, - #[cfg(feature = "half")] ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => TensorElementType::Float16, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => TensorElementType::Float64, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => TensorElementType::Uint32, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => TensorElementType::Uint64, - #[cfg(feature = "half")] ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => TensorElementType::Bfloat16, - _ => panic!("Invalid ONNXTensorElementDataType value") + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 => TensorElementType::Int4, + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4 => TensorElementType::Uint4, + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 => TensorElementType::Complex64, + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 => TensorElementType::Complex128, + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN => TensorElementType::Float8E4M3FN, + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ => TensorElementType::Float8E4M3FNUZ, + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 => TensorElementType::Float8E5M2, + ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ => TensorElementType::Float8E5M2FNUZ } } } diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index fe3c03bc..12f4a5cb 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -121,7 +121,7 @@ impl DynTensor { let mut buffer_ptr: *mut ort_sys::c_void = ptr::null_mut(); ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)]; - unsafe { buffer_ptr.write_bytes(0, calculate_tensor_size(&shape) * data_type.size()) }; + unsafe { buffer_ptr.write_bytes(0, data_type.byte_size(calculate_tensor_size(&shape))) }; } Ok(Value {