diff --git a/src/memory.rs b/src/memory.rs index 8ebe0d0e..ca44e32a 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,7 +1,7 @@ //! Types for managing memory & device allocations. use std::{ - ffi::{CString, c_char, c_int, c_void}, + ffi::{c_char, c_int, c_void}, mem, ptr::NonNull, sync::Arc @@ -248,22 +248,22 @@ impl Drop for AllocatedBlock<'_> { pub struct AllocationDevice(&'static str); impl AllocationDevice { - pub const CPU: AllocationDevice = AllocationDevice("Cpu"); - pub const CUDA: AllocationDevice = AllocationDevice("Cuda"); - pub const CUDA_PINNED: AllocationDevice = AllocationDevice("CudaPinned"); - pub const CANN: AllocationDevice = AllocationDevice("Cann"); - pub const CANN_PINNED: AllocationDevice = AllocationDevice("CannPinned"); - pub const DIRECTML: AllocationDevice = AllocationDevice("DML"); - pub const DIRECTML_CPU: AllocationDevice = AllocationDevice("DML CPU"); - pub const HIP: AllocationDevice = AllocationDevice("Hip"); - pub const HIP_PINNED: AllocationDevice = AllocationDevice("HipPinned"); - pub const OPENVINO_CPU: AllocationDevice = AllocationDevice("OpenVINO_CPU"); - pub const OPENVINO_GPU: AllocationDevice = AllocationDevice("OpenVINO_GPU"); - pub const XNNPACK: AllocationDevice = AllocationDevice("XnnpackExecutionProvider"); - pub const TVM: AllocationDevice = AllocationDevice("TVM"); + pub const CPU: AllocationDevice = AllocationDevice("Cpu\0"); + pub const CUDA: AllocationDevice = AllocationDevice("Cuda\0"); + pub const CUDA_PINNED: AllocationDevice = AllocationDevice("CudaPinned\0"); + pub const CANN: AllocationDevice = AllocationDevice("Cann\0"); + pub const CANN_PINNED: AllocationDevice = AllocationDevice("CannPinned\0"); + pub const DIRECTML: AllocationDevice = AllocationDevice("DML\0"); + pub const DIRECTML_CPU: AllocationDevice = AllocationDevice("DML CPU\0"); + pub const HIP: AllocationDevice = AllocationDevice("Hip\0"); + pub const HIP_PINNED: AllocationDevice = AllocationDevice("HipPinned\0"); + pub const OPENVINO_CPU: AllocationDevice = AllocationDevice("OpenVINO_CPU\0"); + pub const OPENVINO_GPU: AllocationDevice = AllocationDevice("OpenVINO_GPU\0"); + pub const XNNPACK: AllocationDevice = AllocationDevice("XnnpackExecutionProvider\0"); + pub const TVM: AllocationDevice = AllocationDevice("TVM\0"); pub fn as_str(&self) -> &'static str { - self.0 + &self.0[..self.0.len() - 1] } } @@ -390,9 +390,8 @@ impl MemoryInfo { /// ``` pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemoryType) -> Result { let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut(); - let allocator_name = CString::new(allocation_device.as_str()).unwrap_or_else(|_| unreachable!()); ortsys![ - unsafe CreateMemoryInfo(allocator_name.as_ptr(), allocator_type.into(), device_id, memory_type.into(), &mut memory_info_ptr)?; + unsafe CreateMemoryInfo(allocation_device.as_str().as_ptr().cast(), allocator_type.into(), device_id, memory_type.into(), &mut memory_info_ptr)?; nonNull(memory_info_ptr) ]; Ok(Self { @@ -470,7 +469,7 @@ impl MemoryInfo { ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr)]; // SAFETY: `name_ptr` can never be null - `CreateMemoryInfo` internally checks against builtin device names, erroring - // if a non-builtin device is passed, and ONNX Runtime will never supply a pointer to the C++ constructor + // if a non-builtin device is passed let mut len = 0; while unsafe { *name_ptr.add(len) } != 0x00 { @@ -479,7 +478,7 @@ impl MemoryInfo { // SAFETY: ONNX Runtime internally only ever defines allocation device names as ASCII. can't wait for this to blow up // one day regardless - let name = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(name_ptr.cast::(), len)) }; + let name = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(name_ptr.cast::(), len + 1)) }; AllocationDevice(name) }