diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index 0ca32e2b..74a277d2 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -256,7 +256,10 @@ pub struct OrtShapeInferContext { pub struct OrtLoraAdapter { _unused: [u8; 0] } -pub type OrtStatusPtr = *mut OrtStatus; +#[repr(transparent)] +#[derive(Debug, Copy, Clone)] +#[must_use = "statuses must be freed with `OrtApi::ReleaseStatus` if they are not null"] +pub struct OrtStatusPtr(pub *mut OrtStatus); #[doc = " \\brief Memory allocation interface\n\n Structure of function pointers that defines a memory allocator. This can be created and filled in by the user for custom allocators.\n\n When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed."] #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -628,7 +631,7 @@ pub type RunAsyncCallbackFn = #[derive(Debug, Copy, Clone)] pub struct OrtApi { #[doc = " \\brief Create an OrtStatus from a null terminated string\n\n \\param[in] code\n \\param[in] msg A null-terminated string. Its contents will be copied.\n \\return A new OrtStatus object, must be destroyed with OrtApi::ReleaseStatus"] - pub CreateStatus: unsafe extern "system" fn(code: OrtErrorCode, msg: *const core::ffi::c_char) -> *mut OrtStatus, + pub CreateStatus: unsafe extern "system" fn(code: OrtErrorCode, msg: *const core::ffi::c_char) -> OrtStatusPtr, #[doc = " \\brief Get OrtErrorCode from OrtStatus\n\n \\param[in] status\n \\return OrtErrorCode that \\p status was created with"] pub GetErrorCode: unsafe extern "system" fn(status: *const OrtStatus) -> OrtErrorCode, #[doc = " \\brief Get error string from OrtStatus\n\n \\param[in] status\n \\return The error message inside the `status`. Do not free the returned value."] diff --git a/src/environment.rs b/src/environment.rs index 761e24da..25ec47ed 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -77,7 +77,7 @@ pub struct GlobalThreadPoolOptions { impl Default for GlobalThreadPoolOptions { fn default() -> Self { let mut ptr = ptr::null_mut(); - ortsys![unsafe CreateThreadingOptions(&mut ptr)]; + ortsys![unsafe CreateThreadingOptions(&mut ptr).expect("failed to create threading options")]; Self { ptr, thread_manager: None } } } diff --git a/src/error.rs b/src/error.rs index 41f627a2..939a7682 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,13 +11,13 @@ use crate::{char_p_to_string, ortsys}; pub type Result = core::result::Result; pub(crate) trait IntoStatus { - fn into_status(self) -> *mut ort_sys::OrtStatus; + fn into_status(self) -> ort_sys::OrtStatusPtr; } impl IntoStatus for Result { - fn into_status(self) -> *mut ort_sys::OrtStatus { + fn into_status(self) -> ort_sys::OrtStatusPtr { let (code, message) = match &self { - Ok(_) => return ptr::null_mut(), + Ok(_) => return ort_sys::OrtStatusPtr(ptr::null_mut()), Err(e) => (ort_sys::OrtErrorCode::ORT_FAIL, Some(e.to_string())) }; let message = message.map(|c| CString::new(c).unwrap_or_else(|_| unreachable!())); @@ -177,7 +177,8 @@ pub(crate) fn assert_non_null_pointer(ptr: *const T, name: &'static str) -> R /// Converts an [`ort_sys::OrtStatus`] to a [`Result`]. /// /// Note that this frees `status`! -pub(crate) unsafe fn status_to_result(status: *mut ort_sys::OrtStatus) -> Result<(), Error> { +pub(crate) unsafe fn status_to_result(status: ort_sys::OrtStatusPtr) -> Result<(), Error> { + let status = status.0; if status.is_null() { Ok(()) } else { diff --git a/src/memory.rs b/src/memory.rs index 27386770..d4b6aba9 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -403,11 +403,11 @@ impl MemoryInfo { pub(crate) fn from_value(value_ptr: *mut ort_sys::OrtValue) -> Option { let mut is_tensor = 0; - ortsys![unsafe IsTensor(value_ptr, &mut is_tensor)]; // infallible + ortsys![unsafe IsTensor(value_ptr, &mut is_tensor).expect("infallible")]; if is_tensor != 0 { let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = ptr::null_mut(); // infallible, and `memory_info_ptr` will never be null - ortsys![unsafe GetTensorMemoryInfo(value_ptr, &mut memory_info_ptr)]; + ortsys![unsafe GetTensorMemoryInfo(value_ptr, &mut memory_info_ptr).expect("infallible")]; Some(Self::from_raw(unsafe { NonNull::new_unchecked(memory_info_ptr.cast_mut()) }, false)) } else { None @@ -433,7 +433,7 @@ impl MemoryInfo { /// ``` pub fn memory_type(&self) -> MemoryType { let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault; - ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type)]; + ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type).expect("infallible")]; MemoryType::from(raw_type) } @@ -448,7 +448,7 @@ impl MemoryInfo { /// ``` pub fn allocator_type(&self) -> AllocatorType { let mut raw_type: ort_sys::OrtAllocatorType = ort_sys::OrtAllocatorType::OrtInvalidAllocator; - ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type)]; + ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type).expect("infallible")]; match raw_type { ort_sys::OrtAllocatorType::OrtArenaAllocator => AllocatorType::Arena, ort_sys::OrtAllocatorType::OrtDeviceAllocator => AllocatorType::Device, @@ -467,7 +467,7 @@ impl MemoryInfo { /// ``` pub fn allocation_device(&self) -> AllocationDevice { let mut name_ptr: *const c_char = ptr::null_mut(); - ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr)]; + ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr).expect("infallible")]; // SAFETY: `name_ptr` can never be null - `CreateMemoryInfo` internally checks against builtin device names, erroring // if a non-builtin device is passed @@ -494,7 +494,7 @@ impl MemoryInfo { /// ``` pub fn device_id(&self) -> i32 { let mut raw: ort_sys::c_int = 0; - ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw)]; + ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw).expect("infallible")]; raw as _ } @@ -521,7 +521,7 @@ impl Clone for MemoryInfo { impl PartialEq for MemoryInfo { fn eq(&self, other: &MemoryInfo) -> bool { let mut out = 0; - ortsys![unsafe CompareMemoryInfo(self.ptr.as_ptr(), other.ptr.as_ptr(), &mut out)]; // implementation always returns ok status + ortsys![unsafe CompareMemoryInfo(self.ptr.as_ptr(), other.ptr.as_ptr(), &mut out).expect("infallible")]; // implementation always returns ok status out == 0 } } diff --git a/src/metadata.rs b/src/metadata.rs index 142fc9bb..7225305a 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -1,6 +1,7 @@ use alloc::{ffi::CString, string::String, vec::Vec}; use core::{ ffi::c_char, + marker::PhantomData, ptr::{self, NonNull}, slice }; @@ -10,12 +11,17 @@ use crate::{AsPointer, char_p_to_string, error::Result, memory::Allocator, ortsy /// Container for model metadata, including name & producer information. pub struct ModelMetadata<'s> { metadata_ptr: NonNull, - allocator: &'s Allocator + allocator: Allocator, + _p: PhantomData<&'s ()> } -impl<'s> ModelMetadata<'s> { - pub(crate) fn new(metadata_ptr: NonNull, allocator: &'s Allocator) -> Self { - ModelMetadata { metadata_ptr, allocator } +impl ModelMetadata<'_> { + pub(crate) fn new(metadata_ptr: NonNull) -> Self { + ModelMetadata { + metadata_ptr, + allocator: Allocator::default(), + _p: PhantomData + } } /// Gets the model description, returning an error if no description is present. @@ -34,6 +40,22 @@ impl<'s> ModelMetadata<'s> { Ok(value) } + /// Gets the description of the graph. + pub fn graph_description(&self) -> Result { + let mut str_bytes: *mut c_char = ptr::null_mut(); + ortsys![unsafe ModelMetadataGetGraphDescription(self.metadata_ptr.as_ptr(), self.allocator.ptr().cast_mut(), &mut str_bytes)?; nonNull(str_bytes)]; + + let value = match char_p_to_string(str_bytes) { + Ok(value) => value, + Err(e) => { + unsafe { self.allocator.free(str_bytes) }; + return Err(e); + } + }; + unsafe { self.allocator.free(str_bytes) }; + Ok(value) + } + /// Gets the model producer name, returning an error if no producer name is present. pub fn producer(&self) -> Result { let mut str_bytes: *mut c_char = ptr::null_mut(); @@ -66,6 +88,22 @@ impl<'s> ModelMetadata<'s> { Ok(value) } + /// Returns the model's domain, returning an error if no name is present. + pub fn domain(&self) -> Result { + let mut str_bytes: *mut c_char = ptr::null_mut(); + ortsys![unsafe ModelMetadataGetDomain(self.metadata_ptr.as_ptr(), self.allocator.ptr().cast_mut(), &mut str_bytes)?; nonNull(str_bytes)]; + + let value = match char_p_to_string(str_bytes) { + Ok(value) => value, + Err(e) => { + unsafe { self.allocator.free(str_bytes) }; + return Err(e); + } + }; + unsafe { self.allocator.free(str_bytes) }; + Ok(value) + } + /// Gets the model version, returning an error if no version is present. pub fn version(&self) -> Result { let mut ver = 0i64; diff --git a/src/operator/bound.rs b/src/operator/bound.rs index db1473ad..802ccc48 100644 --- a/src/operator/bound.rs +++ b/src/operator/bound.rs @@ -72,7 +72,7 @@ impl BoundOperator { _: *const ort_sys::OrtApi, info: *const ort_sys::OrtKernelInfo, kernel_ptr: *mut *mut ort_sys::c_void - ) -> *mut ort_sys::OrtStatus { + ) -> ort_sys::OrtStatusPtr { let safe = Self::safe(op); let kernel = match safe.operator.create_kernel(&KernelAttributes::new(info)) { Ok(kernel) => kernel, @@ -82,7 +82,7 @@ impl BoundOperator { Ok(()).into_status() } - pub(crate) extern "system" fn compute_kernel(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus { + pub(crate) extern "system" fn compute_kernel(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> ort_sys::OrtStatusPtr { let context = KernelContext::new(context); unsafe { &mut *kernel_ptr.cast::>() }.compute(&context).into_status() } @@ -194,7 +194,7 @@ impl BoundOperator { .into() } - pub(crate) extern "system" fn infer_output_shape(op: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus { + pub(crate) extern "system" fn infer_output_shape(op: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> ort_sys::OrtStatusPtr { let safe = Self::safe(op); let mut ctx = ShapeInferenceContext { ptr: ctx }; safe.operator.infer_shape(&mut ctx).into_status() diff --git a/src/session/async.rs b/src/session/async.rs index 374a13c8..97443800 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -1,7 +1,7 @@ use alloc::{ffi::CString, sync::Arc}; use core::{ cell::UnsafeCell, - ffi::c_char, + ffi::{c_char, c_void}, future::Future, marker::PhantomData, ops::Deref, @@ -11,8 +11,6 @@ use core::{ }; use std::sync::Mutex; -use ort_sys::{OrtStatus, c_void}; - use crate::{ error::Result, session::{RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner}, @@ -137,7 +135,7 @@ pub(crate) struct AsyncInferenceContext<'r, 's, 'v> { pub(crate) output_value_ptrs: Vec<*mut ort_sys::OrtValue> } -pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: *mut OrtStatus) { +pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: ort_sys::OrtStatusPtr) { let ctx = unsafe { Box::from_raw(user_data.cast::>()) }; // Reconvert name ptrs to CString so drop impl is called and memory is freed diff --git a/src/session/mod.rs b/src/session/mod.rs index 887bf354..3a78dd7b 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -520,7 +520,14 @@ impl Session { pub fn metadata(&self) -> Result> { let mut metadata_ptr: *mut ort_sys::OrtModelMetadata = ptr::null_mut(); ortsys![unsafe SessionGetModelMetadata(self.inner.session_ptr.as_ptr(), &mut metadata_ptr)?; nonNull(metadata_ptr)]; - Ok(ModelMetadata::new(unsafe { NonNull::new_unchecked(metadata_ptr) }, &self.inner.allocator)) + Ok(ModelMetadata::new(unsafe { NonNull::new_unchecked(metadata_ptr) })) + } + + /// Returns the time that profiling was started, in nanoseconds. + pub fn profiling_start_ns(&self) -> Result { + let mut out = 0; + ortsys![unsafe SessionGetProfilingStartTimeNs(self.inner.session_ptr.as_ptr(), &mut out)?]; + Ok(out) } /// Ends profiling for this session. @@ -529,7 +536,7 @@ impl Session { pub fn end_profiling(&mut self) -> Result { let mut profiling_name: *mut c_char = ptr::null_mut(); - ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), self.inner.allocator.ptr().cast_mut(), &mut profiling_name)]; + ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), self.inner.allocator.ptr().cast_mut(), &mut profiling_name)?]; assert_non_null_pointer(profiling_name, "ProfilingName")?; dangerous::raw_pointer_to_string(&self.inner.allocator, profiling_name) } @@ -619,7 +626,7 @@ mod dangerous { } fn extract_io_count( - f: unsafe extern "system" fn(*const ort_sys::OrtSession, *mut usize) -> *mut ort_sys::OrtStatus, + f: unsafe extern "system" fn(*const ort_sys::OrtSession, *mut usize) -> ort_sys::OrtStatusPtr, session_ptr: NonNull ) -> Result { let mut num_nodes = 0; @@ -651,7 +658,7 @@ mod dangerous { } fn extract_io_name( - f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut ort_sys::OrtAllocator, *mut *mut c_char) -> *mut ort_sys::OrtStatus, + f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut ort_sys::OrtAllocator, *mut *mut c_char) -> ort_sys::OrtStatusPtr, session_ptr: NonNull, allocator: &Allocator, i: usize @@ -680,7 +687,7 @@ mod dangerous { } fn extract_io( - f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut *mut ort_sys::OrtTypeInfo) -> *mut ort_sys::OrtStatus, + f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut *mut ort_sys::OrtTypeInfo) -> ort_sys::OrtStatusPtr, session_ptr: NonNull, i: usize ) -> Result { diff --git a/src/value/mod.rs b/src/value/mod.rs index d88c5a1a..0cfe6995 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -315,7 +315,7 @@ impl Value { #[must_use] pub unsafe fn from_ptr(ptr: NonNull, session: Option>) -> Value { let mut typeinfo_ptr = ptr::null_mut(); - ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)]; + ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr).expect("infallible")]; Value { inner: Arc::new(ValueInner { ptr, @@ -333,7 +333,7 @@ impl Value { #[must_use] pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull, session: Option>) -> Value { let mut typeinfo_ptr = ptr::null_mut(); - ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)]; + ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr).expect("infallible")]; Value { inner: Arc::new(ValueInner { ptr, @@ -383,7 +383,7 @@ impl Value { /// ``` pub fn is_tensor(&self) -> bool { let mut result = 0; - ortsys![unsafe IsTensor(self.ptr(), &mut result)]; // infallible + ortsys![unsafe IsTensor(self.ptr(), &mut result).expect("infallible")]; result == 1 } diff --git a/src/value/type.rs b/src/value/type.rs index 4afc8a8e..5ad099ad 100644 --- a/src/value/type.rs +++ b/src/value/type.rs @@ -87,33 +87,33 @@ pub enum ValueType { impl ValueType { pub(crate) fn from_type_info(typeinfo_ptr: *mut ort_sys::OrtTypeInfo) -> Self { let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty)]; // infallible + ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty).expect("infallible")]; let io_type = match ty { ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr).expect("infallible")]; unsafe { extract_data_type_from_tensor_info(info_ptr) } } ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = ptr::null_mut(); - ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr).expect("infallible")]; let mut element_type_info: *mut ort_sys::OrtTypeInfo = ptr::null_mut(); - ortsys![unsafe GetSequenceElementType(info_ptr, &mut element_type_info)]; // infallible + ortsys![unsafe GetSequenceElementType(info_ptr, &mut element_type_info).expect("infallible")]; let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - ortsys![unsafe GetOnnxTypeFromTypeInfo(element_type_info, &mut ty)]; // infallible + ortsys![unsafe GetOnnxTypeFromTypeInfo(element_type_info, &mut ty).expect("infallible")]; match ty { ort_sys::ONNXType::ONNX_TYPE_TENSOR => { let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(element_type_info, &mut info_ptr)]; // infallible + ortsys![unsafe CastTypeInfoToTensorInfo(element_type_info, &mut info_ptr).expect("infallible")]; let ty = unsafe { extract_data_type_from_tensor_info(info_ptr) }; ValueType::Sequence(Box::new(ty)) } ort_sys::ONNXType::ONNX_TYPE_MAP => { let mut info_ptr: *const ort_sys::OrtMapTypeInfo = ptr::null_mut(); - ortsys![unsafe CastTypeInfoToMapTypeInfo(element_type_info, &mut info_ptr)]; // infallible + ortsys![unsafe CastTypeInfoToMapTypeInfo(element_type_info, &mut info_ptr).expect("infallible")]; let ty = unsafe { extract_data_type_from_map_info(info_ptr) }; ValueType::Sequence(Box::new(ty)) } @@ -122,15 +122,15 @@ impl ValueType { } ort_sys::ONNXType::ONNX_TYPE_MAP => { let mut info_ptr: *const ort_sys::OrtMapTypeInfo = ptr::null_mut(); - ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr).expect("infallible")]; unsafe { extract_data_type_from_map_info(info_ptr) } } ort_sys::ONNXType::ONNX_TYPE_OPTIONAL => { let mut info_ptr: *const ort_sys::OrtOptionalTypeInfo = ptr::null_mut(); - ortsys![unsafe CastTypeInfoToOptionalTypeInfo(typeinfo_ptr, &mut info_ptr)]; // infallible + ortsys![unsafe CastTypeInfoToOptionalTypeInfo(typeinfo_ptr, &mut info_ptr).expect("infallible")]; let mut contained_type: *mut ort_sys::OrtTypeInfo = ptr::null_mut(); - ortsys![unsafe GetOptionalContainedTypeInfo(info_ptr, &mut contained_type)]; // infallible + ortsys![unsafe GetOptionalContainedTypeInfo(info_ptr, &mut contained_type).expect("infallible")]; ValueType::Optional(Box::new(ValueType::from_type_info(contained_type))) } @@ -144,16 +144,16 @@ impl ValueType { match self { Self::Tensor { ty, dimensions, dimension_symbols } => { let mut info_ptr = ptr::null_mut(); - ortsys![unsafe CreateTensorTypeAndShapeInfo(&mut info_ptr)]; - ortsys![unsafe SetTensorElementType(info_ptr, (*ty).into())]; - ortsys![unsafe SetDimensions(info_ptr, dimensions.as_ptr(), dimensions.len())]; + ortsys![unsafe CreateTensorTypeAndShapeInfo(&mut info_ptr).expect("infallible")]; + ortsys![unsafe SetTensorElementType(info_ptr, (*ty).into()).expect("infallible")]; + ortsys![unsafe SetDimensions(info_ptr, dimensions.as_ptr(), dimensions.len()).expect("infallible")]; let dimension_symbols: Vec<*const c_char> = dimension_symbols .iter() .cloned() .map(|s| CString::new(s.unwrap_or_default())) .map(|s| s.map_or(ptr::null(), |s| s.into_raw().cast_const())) .collect(); - ortsys![unsafe SetSymbolicDimensions(info_ptr, dimension_symbols.as_ptr().cast_mut(), dimension_symbols.len())]; + ortsys![unsafe SetSymbolicDimensions(info_ptr, dimension_symbols.as_ptr().cast_mut(), dimension_symbols.len()).expect("infallible")]; for p in dimension_symbols { if !p.is_null() { drop(unsafe { CString::from_raw(p.cast_mut().cast()) }); @@ -251,17 +251,17 @@ impl fmt::Display for ValueType { pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo) -> ValueType { let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(info_ptr, &mut type_sys)]; + ortsys![unsafe GetTensorElementType(info_ptr, &mut type_sys).expect("infallible")]; assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); // This transmute should be safe since its value is read from GetTensorElementType, which we must trust let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(info_ptr, &mut num_dims)]; + ortsys![unsafe GetDimensionsCount(info_ptr, &mut num_dims).expect("infallible")]; let mut node_dims: Vec = vec![0; num_dims]; - ortsys![unsafe GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims)]; + ortsys![unsafe GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims).expect("infallible")]; let mut symbolic_dims: Vec<*const c_char> = vec![ptr::null(); num_dims]; - ortsys![unsafe GetSymbolicDimensions(info_ptr, symbolic_dims.as_mut_ptr(), num_dims)]; + ortsys![unsafe GetSymbolicDimensions(info_ptr, symbolic_dims.as_mut_ptr(), num_dims).expect("infallible")]; let dimension_symbols = symbolic_dims .into_iter() @@ -283,15 +283,15 @@ pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeInfo) -> ValueType { let mut key_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetMapKeyType(info_ptr, &mut key_type_sys)]; // infallible + ortsys![unsafe GetMapKeyType(info_ptr, &mut key_type_sys).expect("infallible")]; assert_ne!(key_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); let mut value_type_info: *mut ort_sys::OrtTypeInfo = ptr::null_mut(); - ortsys![unsafe GetMapValueType(info_ptr, &mut value_type_info)]; // infallible + ortsys![unsafe GetMapValueType(info_ptr, &mut value_type_info).expect("infallible")]; let mut value_info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(value_type_info, &mut value_info_ptr)]; // infallible + ortsys![unsafe CastTypeInfoToTensorInfo(value_type_info, &mut value_info_ptr).expect("infallible")]; let mut value_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(value_info_ptr, &mut value_type_sys)]; // infallible + ortsys![unsafe GetTensorElementType(value_info_ptr, &mut value_type_sys).expect("infallible")]; assert_ne!(value_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); ValueType::Map {