Skip to content

Commit

Permalink
chore: mark OrtStatusPtr as #[must_use]
Browse files Browse the repository at this point in the history
luckily most cases where a status "may" leak were instances where an error state would never be returned to begin with, but this will be useful to guard from future Sleepy 2am Deca forgetting to add a ?
  • Loading branch information
decahedron1 committed Feb 23, 2025
1 parent ca0ef4e commit 9b0d1ea
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 55 deletions.
7 changes: 5 additions & 2 deletions ort-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ pub struct OrtShapeInferContext {
pub struct OrtLoraAdapter {
_unused: [u8; 0]
}
pub type OrtStatusPtr = *mut OrtStatus;
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
#[must_use = "statuses must be freed with `OrtApi::ReleaseStatus` if they are not null"]
pub struct OrtStatusPtr(pub *mut OrtStatus);
#[doc = " \\brief Memory allocation interface\n\n Structure of function pointers that defines a memory allocator. This can be created and filled in by the user for custom allocators.\n\n When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed."]
#[repr(C)]
#[derive(Debug, Copy, Clone)]
Expand Down Expand Up @@ -628,7 +631,7 @@ pub type RunAsyncCallbackFn =
#[derive(Debug, Copy, Clone)]
pub struct OrtApi {
#[doc = " \\brief Create an OrtStatus from a null terminated string\n\n \\param[in] code\n \\param[in] msg A null-terminated string. Its contents will be copied.\n \\return A new OrtStatus object, must be destroyed with OrtApi::ReleaseStatus"]
pub CreateStatus: unsafe extern "system" fn(code: OrtErrorCode, msg: *const core::ffi::c_char) -> *mut OrtStatus,
pub CreateStatus: unsafe extern "system" fn(code: OrtErrorCode, msg: *const core::ffi::c_char) -> OrtStatusPtr,
#[doc = " \\brief Get OrtErrorCode from OrtStatus\n\n \\param[in] status\n \\return OrtErrorCode that \\p status was created with"]
pub GetErrorCode: unsafe extern "system" fn(status: *const OrtStatus) -> OrtErrorCode,
#[doc = " \\brief Get error string from OrtStatus\n\n \\param[in] status\n \\return The error message inside the `status`. Do not free the returned value."]
Expand Down
2 changes: 1 addition & 1 deletion src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub struct GlobalThreadPoolOptions {
impl Default for GlobalThreadPoolOptions {
fn default() -> Self {
let mut ptr = ptr::null_mut();
ortsys![unsafe CreateThreadingOptions(&mut ptr)];
ortsys![unsafe CreateThreadingOptions(&mut ptr).expect("failed to create threading options")];
Self { ptr, thread_manager: None }
}
}
Expand Down
9 changes: 5 additions & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ use crate::{char_p_to_string, ortsys};
pub type Result<T, E = Error> = core::result::Result<T, E>;

pub(crate) trait IntoStatus {
fn into_status(self) -> *mut ort_sys::OrtStatus;
fn into_status(self) -> ort_sys::OrtStatusPtr;
}

impl<T> IntoStatus for Result<T, Error> {
fn into_status(self) -> *mut ort_sys::OrtStatus {
fn into_status(self) -> ort_sys::OrtStatusPtr {
let (code, message) = match &self {
Ok(_) => return ptr::null_mut(),
Ok(_) => return ort_sys::OrtStatusPtr(ptr::null_mut()),
Err(e) => (ort_sys::OrtErrorCode::ORT_FAIL, Some(e.to_string()))
};
let message = message.map(|c| CString::new(c).unwrap_or_else(|_| unreachable!()));
Expand Down Expand Up @@ -177,7 +177,8 @@ pub(crate) fn assert_non_null_pointer<T>(ptr: *const T, name: &'static str) -> R
/// Converts an [`ort_sys::OrtStatus`] to a [`Result`].
///
/// Note that this frees `status`!
pub(crate) unsafe fn status_to_result(status: *mut ort_sys::OrtStatus) -> Result<(), Error> {
pub(crate) unsafe fn status_to_result(status: ort_sys::OrtStatusPtr) -> Result<(), Error> {
let status = status.0;
if status.is_null() {
Ok(())
} else {
Expand Down
14 changes: 7 additions & 7 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,11 +403,11 @@ impl MemoryInfo {

pub(crate) fn from_value(value_ptr: *mut ort_sys::OrtValue) -> Option<Self> {
let mut is_tensor = 0;
ortsys![unsafe IsTensor(value_ptr, &mut is_tensor)]; // infallible
ortsys![unsafe IsTensor(value_ptr, &mut is_tensor).expect("infallible")];
if is_tensor != 0 {
let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = ptr::null_mut();
// infallible, and `memory_info_ptr` will never be null
ortsys![unsafe GetTensorMemoryInfo(value_ptr, &mut memory_info_ptr)];
ortsys![unsafe GetTensorMemoryInfo(value_ptr, &mut memory_info_ptr).expect("infallible")];
Some(Self::from_raw(unsafe { NonNull::new_unchecked(memory_info_ptr.cast_mut()) }, false))
} else {
None
Expand All @@ -433,7 +433,7 @@ impl MemoryInfo {
/// ```
pub fn memory_type(&self) -> MemoryType {
let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault;
ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type)];
ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type).expect("infallible")];
MemoryType::from(raw_type)
}

Expand All @@ -448,7 +448,7 @@ impl MemoryInfo {
/// ```
pub fn allocator_type(&self) -> AllocatorType {
let mut raw_type: ort_sys::OrtAllocatorType = ort_sys::OrtAllocatorType::OrtInvalidAllocator;
ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type)];
ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type).expect("infallible")];
match raw_type {
ort_sys::OrtAllocatorType::OrtArenaAllocator => AllocatorType::Arena,
ort_sys::OrtAllocatorType::OrtDeviceAllocator => AllocatorType::Device,
Expand All @@ -467,7 +467,7 @@ impl MemoryInfo {
/// ```
pub fn allocation_device(&self) -> AllocationDevice {
let mut name_ptr: *const c_char = ptr::null_mut();
ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr)];
ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr).expect("infallible")];

// SAFETY: `name_ptr` can never be null - `CreateMemoryInfo` internally checks against builtin device names, erroring
// if a non-builtin device is passed
Expand All @@ -494,7 +494,7 @@ impl MemoryInfo {
/// ```
pub fn device_id(&self) -> i32 {
let mut raw: ort_sys::c_int = 0;
ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw)];
ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw).expect("infallible")];
raw as _
}

Expand All @@ -521,7 +521,7 @@ impl Clone for MemoryInfo {
impl PartialEq<MemoryInfo> for MemoryInfo {
fn eq(&self, other: &MemoryInfo) -> bool {
let mut out = 0;
ortsys![unsafe CompareMemoryInfo(self.ptr.as_ptr(), other.ptr.as_ptr(), &mut out)]; // implementation always returns ok status
ortsys![unsafe CompareMemoryInfo(self.ptr.as_ptr(), other.ptr.as_ptr(), &mut out).expect("infallible")]; // implementation always returns ok status
out == 0
}
}
Expand Down
46 changes: 42 additions & 4 deletions src/metadata.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use alloc::{ffi::CString, string::String, vec::Vec};
use core::{
ffi::c_char,
marker::PhantomData,
ptr::{self, NonNull},
slice
};
Expand All @@ -10,12 +11,17 @@ use crate::{AsPointer, char_p_to_string, error::Result, memory::Allocator, ortsy
/// Container for model metadata, including name & producer information.
pub struct ModelMetadata<'s> {
metadata_ptr: NonNull<ort_sys::OrtModelMetadata>,
allocator: &'s Allocator
allocator: Allocator,
_p: PhantomData<&'s ()>
}

impl<'s> ModelMetadata<'s> {
pub(crate) fn new(metadata_ptr: NonNull<ort_sys::OrtModelMetadata>, allocator: &'s Allocator) -> Self {
ModelMetadata { metadata_ptr, allocator }
impl ModelMetadata<'_> {
pub(crate) fn new(metadata_ptr: NonNull<ort_sys::OrtModelMetadata>) -> Self {
ModelMetadata {
metadata_ptr,
allocator: Allocator::default(),
_p: PhantomData
}
}

/// Gets the model description, returning an error if no description is present.
Expand All @@ -34,6 +40,22 @@ impl<'s> ModelMetadata<'s> {
Ok(value)
}

/// Gets the description of the graph.
pub fn graph_description(&self) -> Result<String> {
let mut str_bytes: *mut c_char = ptr::null_mut();
ortsys![unsafe ModelMetadataGetGraphDescription(self.metadata_ptr.as_ptr(), self.allocator.ptr().cast_mut(), &mut str_bytes)?; nonNull(str_bytes)];

let value = match char_p_to_string(str_bytes) {
Ok(value) => value,
Err(e) => {
unsafe { self.allocator.free(str_bytes) };
return Err(e);
}
};
unsafe { self.allocator.free(str_bytes) };
Ok(value)
}

/// Gets the model producer name, returning an error if no producer name is present.
pub fn producer(&self) -> Result<String> {
let mut str_bytes: *mut c_char = ptr::null_mut();
Expand Down Expand Up @@ -66,6 +88,22 @@ impl<'s> ModelMetadata<'s> {
Ok(value)
}

/// Returns the model's domain, returning an error if no name is present.
pub fn domain(&self) -> Result<String> {
let mut str_bytes: *mut c_char = ptr::null_mut();
ortsys![unsafe ModelMetadataGetDomain(self.metadata_ptr.as_ptr(), self.allocator.ptr().cast_mut(), &mut str_bytes)?; nonNull(str_bytes)];

let value = match char_p_to_string(str_bytes) {
Ok(value) => value,
Err(e) => {
unsafe { self.allocator.free(str_bytes) };
return Err(e);
}
};
unsafe { self.allocator.free(str_bytes) };
Ok(value)
}

/// Gets the model version, returning an error if no version is present.
pub fn version(&self) -> Result<i64> {
let mut ver = 0i64;
Expand Down
6 changes: 3 additions & 3 deletions src/operator/bound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl BoundOperator {
_: *const ort_sys::OrtApi,
info: *const ort_sys::OrtKernelInfo,
kernel_ptr: *mut *mut ort_sys::c_void
) -> *mut ort_sys::OrtStatus {
) -> ort_sys::OrtStatusPtr {
let safe = Self::safe(op);
let kernel = match safe.operator.create_kernel(&KernelAttributes::new(info)) {
Ok(kernel) => kernel,
Expand All @@ -82,7 +82,7 @@ impl BoundOperator {
Ok(()).into_status()
}

pub(crate) extern "system" fn compute_kernel(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus {
pub(crate) extern "system" fn compute_kernel(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> ort_sys::OrtStatusPtr {
let context = KernelContext::new(context);
unsafe { &mut *kernel_ptr.cast::<Box<dyn Kernel>>() }.compute(&context).into_status()
}
Expand Down Expand Up @@ -194,7 +194,7 @@ impl BoundOperator {
.into()
}

pub(crate) extern "system" fn infer_output_shape(op: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
pub(crate) extern "system" fn infer_output_shape(op: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> ort_sys::OrtStatusPtr {
let safe = Self::safe(op);
let mut ctx = ShapeInferenceContext { ptr: ctx };
safe.operator.infer_shape(&mut ctx).into_status()
Expand Down
6 changes: 2 additions & 4 deletions src/session/async.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use alloc::{ffi::CString, sync::Arc};
use core::{
cell::UnsafeCell,
ffi::c_char,
ffi::{c_char, c_void},
future::Future,
marker::PhantomData,
ops::Deref,
Expand All @@ -11,8 +11,6 @@ use core::{
};
use std::sync::Mutex;

use ort_sys::{OrtStatus, c_void};

use crate::{
error::Result,
session::{RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner},
Expand Down Expand Up @@ -137,7 +135,7 @@ pub(crate) struct AsyncInferenceContext<'r, 's, 'v> {
pub(crate) output_value_ptrs: Vec<*mut ort_sys::OrtValue>
}

pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: *mut OrtStatus) {
pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: ort_sys::OrtStatusPtr) {
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_, '_>>()) };

// Reconvert name ptrs to CString so drop impl is called and memory is freed
Expand Down
17 changes: 12 additions & 5 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,14 @@ impl Session {
pub fn metadata(&self) -> Result<ModelMetadata<'_>> {
let mut metadata_ptr: *mut ort_sys::OrtModelMetadata = ptr::null_mut();
ortsys![unsafe SessionGetModelMetadata(self.inner.session_ptr.as_ptr(), &mut metadata_ptr)?; nonNull(metadata_ptr)];
Ok(ModelMetadata::new(unsafe { NonNull::new_unchecked(metadata_ptr) }, &self.inner.allocator))
Ok(ModelMetadata::new(unsafe { NonNull::new_unchecked(metadata_ptr) }))
}

/// Returns the time that profiling was started, in nanoseconds.
pub fn profiling_start_ns(&self) -> Result<u64> {
let mut out = 0;
ortsys![unsafe SessionGetProfilingStartTimeNs(self.inner.session_ptr.as_ptr(), &mut out)?];
Ok(out)
}

/// Ends profiling for this session.
Expand All @@ -529,7 +536,7 @@ impl Session {
pub fn end_profiling(&mut self) -> Result<String> {
let mut profiling_name: *mut c_char = ptr::null_mut();

ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), self.inner.allocator.ptr().cast_mut(), &mut profiling_name)];
ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), self.inner.allocator.ptr().cast_mut(), &mut profiling_name)?];
assert_non_null_pointer(profiling_name, "ProfilingName")?;
dangerous::raw_pointer_to_string(&self.inner.allocator, profiling_name)
}
Expand Down Expand Up @@ -619,7 +626,7 @@ mod dangerous {
}

fn extract_io_count(
f: unsafe extern "system" fn(*const ort_sys::OrtSession, *mut usize) -> *mut ort_sys::OrtStatus,
f: unsafe extern "system" fn(*const ort_sys::OrtSession, *mut usize) -> ort_sys::OrtStatusPtr,
session_ptr: NonNull<ort_sys::OrtSession>
) -> Result<usize> {
let mut num_nodes = 0;
Expand Down Expand Up @@ -651,7 +658,7 @@ mod dangerous {
}

fn extract_io_name(
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut ort_sys::OrtAllocator, *mut *mut c_char) -> *mut ort_sys::OrtStatus,
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut ort_sys::OrtAllocator, *mut *mut c_char) -> ort_sys::OrtStatusPtr,
session_ptr: NonNull<ort_sys::OrtSession>,
allocator: &Allocator,
i: usize
Expand Down Expand Up @@ -680,7 +687,7 @@ mod dangerous {
}

fn extract_io(
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut *mut ort_sys::OrtTypeInfo) -> *mut ort_sys::OrtStatus,
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut *mut ort_sys::OrtTypeInfo) -> ort_sys::OrtStatusPtr,
session_ptr: NonNull<ort_sys::OrtSession>,
i: usize
) -> Result<ValueType> {
Expand Down
6 changes: 3 additions & 3 deletions src/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
#[must_use]
pub unsafe fn from_ptr(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
let mut typeinfo_ptr = ptr::null_mut();
ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)];
ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr).expect("infallible")];
Value {
inner: Arc::new(ValueInner {
ptr,
Expand All @@ -333,7 +333,7 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
#[must_use]
pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
let mut typeinfo_ptr = ptr::null_mut();
ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr)];
ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr).expect("infallible")];
Value {
inner: Arc::new(ValueInner {
ptr,
Expand Down Expand Up @@ -383,7 +383,7 @@ impl Value<DynValueTypeMarker> {
/// ```
pub fn is_tensor(&self) -> bool {
let mut result = 0;
ortsys![unsafe IsTensor(self.ptr(), &mut result)]; // infallible
ortsys![unsafe IsTensor(self.ptr(), &mut result).expect("infallible")];
result == 1
}

Expand Down
Loading

0 comments on commit 9b0d1ea

Please sign in to comment.