Skip to content

Commit

Permalink
fix(candle,tract): new OrtStatus changes
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Feb 23, 2025
1 parent de932b5 commit 0539cd5
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 105 deletions.
94 changes: 47 additions & 47 deletions backends/candle/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{
tensor::TypeInfo
};

unsafe extern "system" fn CreateStatus(code: OrtErrorCode, msg: *const ::std::os::raw::c_char) -> *mut OrtStatus {
unsafe extern "system" fn CreateStatus(code: OrtErrorCode, msg: *const ::std::os::raw::c_char) -> OrtStatusPtr {
let msg = CString::from_raw(msg.cast_mut());
Error::new_sys(code, msg.to_string_lossy())
}
Expand All @@ -40,7 +40,7 @@ unsafe extern "system" fn GetErrorMessage(status: *const OrtStatus) -> *const ::

unsafe extern "system" fn CreateEnv(log_severity_level: OrtLoggingLevel, logid: *const ::std::os::raw::c_char, out: *mut *mut OrtEnv) -> OrtStatusPtr {
*out = Environment::new_sys();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn CreateEnvWithCustomLogger(
Expand All @@ -51,15 +51,15 @@ unsafe extern "system" fn CreateEnvWithCustomLogger(
out: *mut *mut OrtEnv
) -> OrtStatusPtr {
*out = Environment::new_sys();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn EnableTelemetryEvents(env: *const OrtEnv) -> OrtStatusPtr {
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn DisableTelemetryEvents(env: *const OrtEnv) -> OrtStatusPtr {
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn CreateSession(
Expand Down Expand Up @@ -88,7 +88,7 @@ unsafe extern "system" fn CreateSession(
match Session::from_buffer(options, &buf) {
Ok(session) => {
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
ptr::null_mut()
OrtStatusPtr::default()
}
Err(e) => Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to parse model: {e}"))
}
Expand All @@ -108,7 +108,7 @@ unsafe extern "system" fn CreateSessionFromArray(
match Session::from_buffer(options, buf) {
Ok(session) => {
*out = (Box::leak(Box::new(session)) as *mut Session).cast();
ptr::null_mut()
OrtStatusPtr::default()
}
Err(e) => Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to parse model: {e}"))
}
Expand Down Expand Up @@ -154,15 +154,15 @@ unsafe extern "system" fn Run(
}
}

ptr::null_mut()
OrtStatusPtr::default()
}
Err(e) => Error::new_sys(OrtErrorCode::ORT_FAIL, format!("Failed to run session: {e}"))
}
}

unsafe extern "system" fn CreateSessionOptions(options: *mut *mut OrtSessionOptions) -> OrtStatusPtr {
*options = (Box::leak(Box::new(SessionOptions)) as *mut SessionOptions).cast();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn SetOptimizedModelFilePath(options: *mut OrtSessionOptions, optimized_model_filepath: *const ortchar) -> OrtStatusPtr {
Expand All @@ -172,7 +172,7 @@ unsafe extern "system" fn SetOptimizedModelFilePath(options: *mut OrtSessionOpti
unsafe extern "system" fn CloneSessionOptions(in_options: *const OrtSessionOptions, out_options: *mut *mut OrtSessionOptions) -> OrtStatusPtr {
let options = unsafe { &*in_options.cast::<SessionOptions>() };
*out_options = (Box::leak(Box::new(options.clone())) as *mut SessionOptions).cast();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn SetSessionExecutionMode(options: *mut OrtSessionOptions, execution_mode: ExecutionMode) -> OrtStatusPtr {
Expand Down Expand Up @@ -252,7 +252,7 @@ unsafe extern "system" fn SessionGetInputCount(session: *const OrtSession, out:
match session.model.graph.as_ref() {
Some(graph) => {
*out = graph.input.len();
ptr::null_mut()
OrtStatusPtr::default()
}
None => Error::new_sys(OrtErrorCode::ORT_NO_MODEL, "Graph is missing")
}
Expand All @@ -263,15 +263,15 @@ unsafe extern "system" fn SessionGetOutputCount(session: *const OrtSession, out:
match session.model.graph.as_ref() {
Some(graph) => {
*out = graph.output.len();
ptr::null_mut()
OrtStatusPtr::default()
}
None => Error::new_sys(OrtErrorCode::ORT_NO_MODEL, "Graph is missing")
}
}

unsafe extern "system" fn SessionGetOverridableInitializerCount(session: *const OrtSession, out: *mut usize) -> OrtStatusPtr {
*out = 0;
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn SessionGetInputTypeInfo(session: *const OrtSession, index: usize, type_info: *mut *mut OrtTypeInfo) -> OrtStatusPtr {
Expand Down Expand Up @@ -301,7 +301,7 @@ unsafe extern "system" fn SessionGetInputTypeInfo(session: *const OrtSession, in
}

*type_info = TypeInfo::new_sys(dtype, shape_out);
ptr::null_mut()
OrtStatusPtr::default()
}
_ => Error::new_sys(OrtErrorCode::ORT_FAIL, "Invalid type; only tensors are supported")
}
Expand Down Expand Up @@ -337,7 +337,7 @@ unsafe extern "system" fn SessionGetOutputTypeInfo(session: *const OrtSession, i
}

*type_info = TypeInfo::new_sys(dtype, shape_out);
ptr::null_mut()
OrtStatusPtr::default()
}
_ => Error::new_sys(OrtErrorCode::ORT_FAIL, "Invalid type; only tensors are supported")
}
Expand All @@ -361,7 +361,7 @@ unsafe extern "system" fn SessionGetInputName(
Some(graph) => {
let name = CString::new(&*graph.input[index].name).unwrap();
*value = name.into_raw();
ptr::null_mut()
OrtStatusPtr::default()
}
None => Error::new_sys(OrtErrorCode::ORT_NO_MODEL, "Graph is missing")
}
Expand All @@ -378,7 +378,7 @@ unsafe extern "system" fn SessionGetOutputName(
Some(graph) => {
let name = CString::new(&*graph.output[index].name).unwrap();
*value = name.into_raw();
ptr::null_mut()
OrtStatusPtr::default()
}
None => Error::new_sys(OrtErrorCode::ORT_NO_MODEL, "Graph is missing")
}
Expand Down Expand Up @@ -450,7 +450,7 @@ unsafe extern "system" fn CreateTensorAsOrtValue(
match Tensor::zeros(shape, dtype, mem_info.device()) {
Ok(tensor) => {
*out = (Box::leak(Box::new(tensor)) as *mut Tensor).cast();
ptr::null_mut()
OrtStatusPtr::default()
}
Err(e) => Error::new_sys(OrtErrorCode::ORT_EP_FAIL, format!("Failed to create tensor: {e}"))
}
Expand Down Expand Up @@ -479,15 +479,15 @@ unsafe extern "system" fn CreateTensorWithDataAsOrtValue(
match Tensor::from_raw_buffer(data_slice, dtype, &shape, mem_info.device()) {
Ok(tensor) => {
*out = (Box::leak(Box::new(tensor)) as *mut Tensor).cast();
ptr::null_mut()
OrtStatusPtr::default()
}
Err(e) => Error::new_sys(OrtErrorCode::ORT_EP_FAIL, format!("Failed to create tensor: {e}"))
}
}

unsafe extern "system" fn IsTensor(value: *const OrtValue, out: *mut ::std::os::raw::c_int) -> OrtStatusPtr {
*out = 1;
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetTensorMutableData(value: *mut OrtValue, out: *mut *mut ::std::os::raw::c_void) -> OrtStatusPtr {
Expand All @@ -504,7 +504,7 @@ unsafe extern "system" fn GetTensorMutableData(value: *mut OrtValue, out: *mut *
CpuStorage::F16(v) => v.as_ptr() as *mut _,
CpuStorage::BF16(v) => v.as_ptr() as *mut _
};
ptr::null_mut()
OrtStatusPtr::default()
}
_ => Error::new_sys(OrtErrorCode::ORT_NOT_IMPLEMENTED, "Unimplemented")
}
Expand All @@ -530,25 +530,25 @@ unsafe extern "system" fn GetStringTensorContent(

unsafe extern "system" fn CastTypeInfoToTensorInfo(type_info: *const OrtTypeInfo, out: *mut *const OrtTensorTypeAndShapeInfo) -> OrtStatusPtr {
*out = type_info.cast();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetOnnxTypeFromTypeInfo(type_info: *const OrtTypeInfo, out: *mut ONNXType) -> OrtStatusPtr {
*out = ONNXType::ONNX_TYPE_TENSOR;
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn CreateTensorTypeAndShapeInfo(out: *mut *mut OrtTensorTypeAndShapeInfo) -> OrtStatusPtr {
*out = TypeInfo::new_sys(DType::F32, Vec::new()).cast();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn SetTensorElementType(info: *mut OrtTensorTypeAndShapeInfo, type_: ONNXTensorElementDataType) -> OrtStatusPtr {
let info = unsafe { &mut *info.cast::<TypeInfo>() };
match convert_sys_to_dtype(type_) {
Ok(dtype) => {
info.dtype = dtype;
ptr::null_mut()
OrtStatusPtr::default()
}
Err(e) => e.into_sys()
}
Expand All @@ -557,27 +557,27 @@ unsafe extern "system" fn SetTensorElementType(info: *mut OrtTensorTypeAndShapeI
unsafe extern "system" fn SetDimensions(info: *mut OrtTensorTypeAndShapeInfo, dim_values: *const i64, dim_count: usize) -> OrtStatusPtr {
let info = unsafe { &mut *info.cast::<TypeInfo>() };
info.shape = unsafe { std::slice::from_raw_parts(dim_values.cast(), dim_count) }.to_vec();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetTensorElementType(info: *const OrtTensorTypeAndShapeInfo, out: *mut ONNXTensorElementDataType) -> OrtStatusPtr {
let info = unsafe { &*info.cast::<TypeInfo>() };
*out = convert_dtype_to_sys(info.dtype);
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetDimensionsCount(info: *const OrtTensorTypeAndShapeInfo, out: *mut usize) -> OrtStatusPtr {
let info = unsafe { &*info.cast::<TypeInfo>() };
*out = info.shape.len();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetDimensions(info: *const OrtTensorTypeAndShapeInfo, dim_values: *mut i64, dim_values_length: usize) -> OrtStatusPtr {
let info = unsafe { &*info.cast::<TypeInfo>() };
for (i, dim) in info.shape.iter().enumerate().take(dim_values_length) {
*dim_values.add(i) = *dim as _;
}
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetSymbolicDimensions(
Expand All @@ -588,7 +588,7 @@ unsafe extern "system" fn GetSymbolicDimensions(
for i in 0..dim_params_length {
*dim_params.add(i) = ptr::null();
}
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetTensorShapeElementCount(info: *const OrtTensorTypeAndShapeInfo, out: *mut usize) -> OrtStatusPtr {
Expand All @@ -598,24 +598,24 @@ unsafe extern "system" fn GetTensorShapeElementCount(info: *const OrtTensorTypeA
size *= *dim as usize;
}
*out = size;
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetTensorTypeAndShape(value: *const OrtValue, out: *mut *mut OrtTensorTypeAndShapeInfo) -> OrtStatusPtr {
let tensor = unsafe { &*value.cast::<Tensor>() };
*out = TypeInfo::new_sys(tensor.dtype(), tensor.shape().dims().iter().map(|c| *c as i64).collect()).cast();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetTypeInfo(value: *const OrtValue, out: *mut *mut OrtTypeInfo) -> OrtStatusPtr {
let tensor = unsafe { &*value.cast::<Tensor>() };
*out = TypeInfo::new_sys(tensor.dtype(), tensor.shape().dims().iter().map(|c| *c as i64).collect());
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetValueType(value: *const OrtValue, out: *mut ONNXType) -> OrtStatusPtr {
*out = ONNXType::ONNX_TYPE_TENSOR;
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn CreateMemoryInfo(
Expand All @@ -629,7 +629,7 @@ unsafe extern "system" fn CreateMemoryInfo(
match MemoryInfo::new(device_name.to_string_lossy(), id as _, mem_type) {
Ok(inf) => {
unsafe { *out = (Box::leak(Box::new(inf)) as *mut MemoryInfo).cast() };
ptr::null_mut()
OrtStatusPtr::default()
}
Err(e) => e.into_sys()
}
Expand All @@ -639,7 +639,7 @@ unsafe extern "system" fn CreateCpuMemoryInfo(type_: OrtAllocatorType, mem_type:
match MemoryInfo::new("Cpu", 0, mem_type) {
Ok(inf) => {
unsafe { *out = (Box::leak(Box::new(inf)) as *mut MemoryInfo).cast() };
ptr::null_mut()
OrtStatusPtr::default()
}
Err(e) => e.into_sys()
}
Expand All @@ -649,53 +649,53 @@ unsafe extern "system" fn CompareMemoryInfo(info1: *const OrtMemoryInfo, info2:
let info1 = unsafe { &*info1.cast::<MemoryInfo>() };
let info2 = unsafe { &*info2.cast::<MemoryInfo>() };
*out = if info1 == info2 { 0 } else { -1 };
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn MemoryInfoGetName(ptr: *const OrtMemoryInfo, out: *mut *const ::std::os::raw::c_char) -> OrtStatusPtr {
let info = unsafe { &*ptr.cast::<MemoryInfo>() };
*out = info.device_name_sys().as_ptr().cast();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn MemoryInfoGetId(ptr: *const OrtMemoryInfo, out: *mut ::std::os::raw::c_int) -> OrtStatusPtr {
let info = unsafe { &*ptr.cast::<MemoryInfo>() };
*out = info.device_id() as _;
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn MemoryInfoGetMemType(ptr: *const OrtMemoryInfo, out: *mut OrtMemType) -> OrtStatusPtr {
let info = unsafe { &*ptr.cast::<MemoryInfo>() };
*out = info.memory_type() as _;
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn MemoryInfoGetType(ptr: *const OrtMemoryInfo, out: *mut OrtAllocatorType) -> OrtStatusPtr {
*out = OrtAllocatorType::OrtDeviceAllocator;
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn AllocatorAlloc(ort_allocator: *mut OrtAllocator, size: usize, out: *mut *mut ::std::os::raw::c_void) -> OrtStatusPtr {
*out = unsafe { &*ort_allocator }.Alloc.unwrap()(ort_allocator, size);
if unsafe { *out }.is_null() {
return Error::new_sys(OrtErrorCode::ORT_RUNTIME_EXCEPTION, "Allocation failed");
}
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn AllocatorFree(ort_allocator: *mut OrtAllocator, p: *mut ::std::os::raw::c_void) -> OrtStatusPtr {
unsafe { &*ort_allocator }.Free.unwrap()(ort_allocator, p);
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn AllocatorGetInfo(ort_allocator: *const OrtAllocator, out: *mut *const OrtMemoryInfo) -> OrtStatusPtr {
*out = unsafe { &*ort_allocator }.Info.unwrap()(ort_allocator);
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetAllocatorWithDefaultOptions(out: *mut *mut OrtAllocator) -> OrtStatusPtr {
*out = (&crate::memory::DEFAULT_CPU_ALLOCATOR as *const Allocator).cast_mut().cast();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn AddFreeDimensionOverride(
Expand Down Expand Up @@ -972,7 +972,7 @@ unsafe extern "system" fn AddSessionConfigEntry(
unsafe extern "system" fn CreateAllocator(session: *const OrtSession, mem_info: *const OrtMemoryInfo, out: *mut *mut OrtAllocator) -> OrtStatusPtr {
let mem_info = unsafe { &*mem_info.cast::<MemoryInfo>() };
*out = (Box::leak(Box::new(Allocator::new(mem_info))) as *mut Allocator).cast();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn ReleaseAllocator(input: *mut OrtAllocator) {
Expand Down Expand Up @@ -1370,7 +1370,7 @@ unsafe extern "system" fn GetTensorMemoryInfo(value: *const OrtValue, mem_info:
let tensor = unsafe { &*value.cast::<Tensor>() };
// `MemoryInfo` is #[repr(transparent)], so &MemoryInfo is &Device.
*mem_info = (tensor.device() as *const Device).cast();
ptr::null_mut()
OrtStatusPtr::default()
}

unsafe extern "system" fn GetExecutionProviderApi(
Expand Down
Loading

0 comments on commit 0539cd5

Please sign in to comment.