Skip to content

Commit

Permalink
feat: shape inference attributes, Clone for KernelAttributes
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Mar 5, 2025
1 parent d7cebd6 commit b4699a2
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 19 deletions.
7 changes: 5 additions & 2 deletions src/operator/bound.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use alloc::{boxed::Box, ffi::CString, vec::Vec};
use core::ptr;
use core::ptr::{self, NonNull};

use super::{
Operator, ShapeInferenceContext,
Expand Down Expand Up @@ -74,7 +74,10 @@ impl BoundOperator {
kernel_ptr: *mut *mut ort_sys::c_void
) -> ort_sys::OrtStatusPtr {
let safe = Self::safe(op);
let kernel = match safe.operator.create_kernel(&KernelAttributes::new(info)) {
let kernel = match safe
.operator
.create_kernel(&KernelAttributes::from_ptr(NonNull::new(info.cast_mut()).expect("infallible"), false))
{
Ok(kernel) => kernel,
e => return e.into_status()
};
Expand Down
150 changes: 134 additions & 16 deletions src/operator/kernel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use alloc::{boxed::Box, ffi::CString, string::String, vec, vec::Vec};
use core::{
ffi::{c_char, c_void},
mem::size_of,
ops::{Deref, DerefMut},
ptr::{self, NonNull},
slice
Expand Down Expand Up @@ -28,34 +29,37 @@ where
}
}

pub struct KernelAttributes(NonNull<ort_sys::OrtKernelInfo>);
pub struct KernelAttributes {
ptr: NonNull<ort_sys::OrtKernelInfo>,
should_release: bool
}

impl KernelAttributes {
pub(crate) fn new(info: *const ort_sys::OrtKernelInfo) -> Self {
Self(NonNull::from(unsafe { &*info }))
pub(crate) fn from_ptr(ptr: NonNull<ort_sys::OrtKernelInfo>, should_release: bool) -> Self {
Self { ptr, should_release }
}

pub fn get<'s, T: GetKernelAttribute<'s>>(&'s self, name: impl AsRef<str>) -> Option<T> {
let name = CString::new(name.as_ref()).ok()?;
unsafe { T::get_from(self.0.as_ptr(), name.as_ptr()) }
unsafe { T::get_from(self.ptr.as_ptr(), name.as_ptr()) }
}

pub fn inputs(&self) -> Result<Vec<Input>> {
let mut num_inputs = 0;
ortsys![unsafe KernelInfo_GetInputCount(self.0.as_ptr(), &mut num_inputs)?];
ortsys![unsafe KernelInfo_GetInputCount(self.ptr.as_ptr(), &mut num_inputs)?];

let mut inputs = Vec::with_capacity(num_inputs);
for idx in 0..num_inputs {
let mut name_len = 0;
ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx, ptr::null_mut(), &mut name_len)?];
ortsys![unsafe KernelInfo_GetInputName(self.ptr.as_ptr(), idx, ptr::null_mut(), &mut name_len)?];
let mut name = vec![0u8; name_len];
ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx, name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
ortsys![unsafe KernelInfo_GetInputName(self.ptr.as_ptr(), idx, name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
let name = CString::from_vec_with_nul(name)
.map_err(Error::wrap)?
.into_string()
.map_err(Error::wrap)?;
let mut type_info = ptr::null_mut();
ortsys![unsafe KernelInfo_GetInputTypeInfo(self.0.as_ptr(), idx, &mut type_info)?; nonNull(type_info)];
ortsys![unsafe KernelInfo_GetInputTypeInfo(self.ptr.as_ptr(), idx, &mut type_info)?; nonNull(type_info)];
let input_type = ValueType::from_type_info(type_info);
inputs.push(Input { name, input_type })
}
Expand All @@ -64,56 +68,112 @@ impl KernelAttributes {

pub fn outputs(&self) -> Result<Vec<Output>> {
let mut num_outputs = 0;
ortsys![unsafe KernelInfo_GetOutputCount(self.0.as_ptr(), &mut num_outputs)?];
ortsys![unsafe KernelInfo_GetOutputCount(self.ptr.as_ptr(), &mut num_outputs)?];

let mut outputs = Vec::with_capacity(num_outputs);
for idx in 0..num_outputs {
let mut name_len = 0;
ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx, ptr::null_mut(), &mut name_len)?];
ortsys![unsafe KernelInfo_GetOutputName(self.ptr.as_ptr(), idx, ptr::null_mut(), &mut name_len)?];
let mut name = vec![0u8; name_len];
ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx, name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
ortsys![unsafe KernelInfo_GetOutputName(self.ptr.as_ptr(), idx, name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
let name = CString::from_vec_with_nul(name)
.map_err(Error::wrap)?
.into_string()
.map_err(Error::wrap)?;
let mut type_info = ptr::null_mut();
ortsys![unsafe KernelInfo_GetOutputTypeInfo(self.0.as_ptr(), idx, &mut type_info)?; nonNull(type_info)];
ortsys![unsafe KernelInfo_GetOutputTypeInfo(self.ptr.as_ptr(), idx, &mut type_info)?; nonNull(type_info)];
let output_type = ValueType::from_type_info(type_info);
outputs.push(Output { name, output_type })
}
Ok(outputs)
}

pub fn constant_input<T: DowncastableTarget>(&self, idx: usize) -> Result<ValueRef<'_, T>> {
let mut value_ptr: *const ort_sys::OrtValue = ptr::null();
let mut is_constant = 0;
ortsys![unsafe KernelInfoGetConstantInput_tensor(self.ptr.as_ptr(), idx, &mut is_constant, &mut value_ptr)?];
if is_constant == 0 || value_ptr.is_null() {
return Err(Error::new("input index out of bounds or input is not constant"));
}

unsafe { ValueRef::new(DynValue::from_ptr_nodrop(NonNull::new_unchecked(value_ptr.cast_mut()), None)) }.downcast()
}

pub fn node_name(&self) -> Result<String> {
let mut name_len = 0;
ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), ptr::null_mut(), &mut name_len)?];
ortsys![unsafe KernelInfo_GetNodeName(self.ptr.as_ptr(), ptr::null_mut(), &mut name_len)?];
let mut name = vec![0u8; name_len];
ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
ortsys![unsafe KernelInfo_GetNodeName(self.ptr.as_ptr(), name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
CString::from_vec_with_nul(name).map_err(Error::wrap)?.into_string().map_err(Error::wrap)
}

pub fn allocator(&self, mem_type: MemoryType) -> Result<Allocator> {
let mut ptr: *mut ort_sys::OrtAllocator = ptr::null_mut();
ortsys![unsafe KernelInfoGetAllocator(self.0.as_ptr(), mem_type.into(), &mut ptr)?];
ortsys![unsafe KernelInfoGetAllocator(self.ptr.as_ptr(), mem_type.into(), &mut ptr)?];
Ok(unsafe { Allocator::from_raw_unchecked(ptr) })
}
}

impl Clone for KernelAttributes {
fn clone(&self) -> Self {
let mut out = ptr::null_mut();
ortsys![unsafe CopyKernelInfo(self.ptr.as_ptr(), &mut out).expect("failed to clone KernelAttributes")];
Self {
ptr: NonNull::new(out).expect("failed to clone KernelAttributes"),
should_release: true
}
}
}

impl AsPointer for KernelAttributes {
type Sys = ort_sys::OrtKernelInfo;

fn ptr(&self) -> *const Self::Sys {
self.0.as_ptr()
self.ptr.as_ptr()
}
}

impl Drop for KernelAttributes {
fn drop(&mut self) {
if self.should_release {
ortsys![unsafe ReleaseKernelInfo(self.ptr.as_ptr())];
}
}
}

pub trait GetKernelAttribute<'s> {
fn attr_type() -> Option<ort_sys::OrtOpAttrType> {
None
}

unsafe fn from_read_op(attr: *const ort_sys::OrtOpAttr, len: usize) -> Result<Self>
where
Self: Sized
{
let _ = (attr, len);
Err(Error::new("not implemented"))
}

unsafe fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option<Self>
where
Self: Sized;
}

impl GetKernelAttribute<'_> for f32 {
fn attr_type() -> Option<ort_sys::OrtOpAttrType> {
Some(ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOAT)
}

unsafe fn from_read_op(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let mut out = 0.0_f32;
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOAT, (&mut out as *mut f32).cast(), size_of::<f32>(), &mut len)?];
assert_eq!(len, size_of::<f32>());
Ok(out)
}

unsafe fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option<Self>
where
Self: Sized
Expand All @@ -126,6 +186,20 @@ impl GetKernelAttribute<'_> for f32 {
}

impl GetKernelAttribute<'_> for i64 {
fn attr_type() -> Option<ort_sys::OrtOpAttrType> {
Some(ort_sys::OrtOpAttrType::ORT_OP_ATTR_INT)
}

unsafe fn from_read_op(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let mut out = 0_i64;
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_INT, (&mut out as *mut i64).cast(), size_of::<i64>(), &mut len)?];
assert_eq!(len, size_of::<i64>());
Ok(out)
}

unsafe fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option<Self>
where
Self: Sized
Expand All @@ -138,6 +212,22 @@ impl GetKernelAttribute<'_> for i64 {
}

impl GetKernelAttribute<'_> for String {
fn attr_type() -> Option<ort_sys::OrtOpAttrType> {
Some(ort_sys::OrtOpAttrType::ORT_OP_ATTR_STRING)
}

unsafe fn from_read_op(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let mut out = vec![0_u8; len / size_of::<u8>()];
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_STRING, out.as_mut_ptr().cast(), len, &mut len)?];
assert_eq!(out.len(), len / size_of::<u8>());
CString::from_vec_with_nul(out)
.map_err(|_| Error::new("invalid string"))
.and_then(|f| f.into_string().map_err(|_| Error::new("invalid string")))
}

unsafe fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option<Self>
where
Self: Sized
Expand All @@ -153,6 +243,20 @@ impl GetKernelAttribute<'_> for String {
}

impl GetKernelAttribute<'_> for Vec<f32> {
fn attr_type() -> Option<ort_sys::OrtOpAttrType> {
Some(ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOATS)
}

unsafe fn from_read_op(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let mut out = vec![0.0_f32; len / size_of::<f32>()];
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOATS, out.as_mut_ptr().cast(), len, &mut len)?];
assert_eq!(out.len(), len / size_of::<f32>());
Ok(out)
}

unsafe fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option<Self>
where
Self: Sized
Expand All @@ -168,6 +272,20 @@ impl GetKernelAttribute<'_> for Vec<f32> {
}

impl GetKernelAttribute<'_> for Vec<i64> {
fn attr_type() -> Option<ort_sys::OrtOpAttrType> {
Some(ort_sys::OrtOpAttrType::ORT_OP_ATTR_INTS)
}

unsafe fn from_read_op(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let mut out = vec![0_i64; len / size_of::<i64>()];
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_INTS, out.as_mut_ptr().cast(), len, &mut len)?];
assert_eq!(out.len(), len / size_of::<i64>());
Ok(out)
}

unsafe fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option<Self>
where
Self: Sized
Expand Down
17 changes: 16 additions & 1 deletion src/operator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod tests;
use self::{
bound::BoundOperator,
io::{OperatorInput, OperatorOutput},
kernel::{Kernel, KernelAttributes}
kernel::{GetKernelAttribute, Kernel, KernelAttributes}
};
use crate::{
AsPointer, Error,
Expand Down Expand Up @@ -82,6 +82,21 @@ impl ShapeInferenceContext {
tys
}

pub fn attr<'s, T: GetKernelAttribute<'s>>(&'s self, name: impl AsRef<str>) -> Result<T> {
let Some(attr_type) = T::attr_type() else {
return Err(Error::new("type is not supported as a ShapeInferenceContext attribute"));
};

let mut attr = ptr::null();
let name = CString::new(name.as_ref())?;
ortsys![unsafe ShapeInferContext_GetAttribute(self.ptr(), name.as_ptr(), &mut attr)?];

let mut len = 0;
let _ = ortsys![unsafe ReadOpAttr(attr, attr_type, ptr::null_mut(), 0, &mut len)];

unsafe { T::from_read_op(attr, len) }
}

pub fn set_output(&mut self, idx: usize, ty: &ValueType) -> Result<()> {
match ty.to_tensor_type_info() {
Some(ty_ptr) => {
Expand Down

0 comments on commit b4699a2

Please sign in to comment.