Skip to content

Commit

Permalink
custom op attrs test
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Mar 6, 2025
1 parent e46689f commit aaed913
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions src/operator/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,95 @@ fn test_custom_ops() -> crate::Result<()> {

Ok(())
}

struct AttrTesterIntFloat;

impl Operator for AttrTesterIntFloat {
fn name(&self) -> &str {
"AttrTesterIntFloat"
}

fn inputs(&self) -> Vec<OperatorInput> {
vec![OperatorInput::required(TensorElementType::Float32)]
}

fn outputs(&self) -> Vec<OperatorOutput> {
vec![OperatorOutput::required(TensorElementType::Float32)]
}

fn infer_shape(&self, ctx: &mut super::ShapeInferenceContext) -> crate::Result<()> {
assert!(matches!(ctx.attr("a_int"), Ok(1_i64)));
assert!(matches!(ctx.attr("a_float"), Ok(2.0_f32)));
assert!(matches!(ctx.attr::<Vec<i64>>("ints").as_deref(), Ok(&[3, 4, 5])));
assert!(matches!(ctx.attr::<Vec<f32>>("floats").as_deref(), Ok(&[6., 7., 8.])));

ctx.set_output(0, &ctx.inputs()[0])?;

Ok(())
}

fn create_kernel(&self, _: &KernelAttributes) -> crate::Result<Box<dyn Kernel>> {
Ok(Box::new(|ctx: &KernelContext| {
let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?;
let (x_shape, x) = x.try_extract_raw_tensor::<f32>()?;
let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?;
let (_, z_ref) = z.try_extract_raw_tensor_mut::<f32>()?;
for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize {
z_ref[i] = x[i] * 2.;
}
Ok(())
}))
}
}

struct AttrTesterString;

impl Operator for AttrTesterString {
fn name(&self) -> &str {
"AttrTesterString"
}

fn inputs(&self) -> Vec<OperatorInput> {
vec![OperatorInput::required(TensorElementType::Float32)]
}

fn outputs(&self) -> Vec<OperatorOutput> {
vec![OperatorOutput::required(TensorElementType::Float32)]
}

fn infer_shape(&self, ctx: &mut super::ShapeInferenceContext) -> crate::Result<()> {
assert!(matches!(ctx.attr::<String>("a_string").as_deref(), Ok("iamastring")));

ctx.set_output(0, &ctx.inputs()[0])?;

Ok(())
}

fn create_kernel(&self, _: &KernelAttributes) -> crate::Result<Box<dyn Kernel>> {
Ok(Box::new(|ctx: &KernelContext| {
let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?;
let (x_shape, x) = x.try_extract_raw_tensor::<f32>()?;
let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?;
let (_, z_ref) = z.try_extract_raw_tensor_mut::<f32>()?;
for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize {
z_ref[i] = x[i] * 3.;
}
Ok(())
}))
}
}

#[test]
fn test_op_attrs() -> crate::Result<()> {
let model = std::fs::read("tests/data/attr_tester.onnx").expect("");
let mut session = Session::builder()?
.with_operators(OperatorDomain::new("test.customop")?.add(AttrTesterIntFloat)?.add(AttrTesterString)?)?
.commit_from_memory(&model)?;

let value1 = Tensor::from_array(([5], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0]))?;

let values = session.run(crate::inputs!["input_0" => &value1])?;
assert_eq!(values[0].try_extract_raw_tensor::<f32>()?.1, [6.0, 12.0, 18.0, 24.0, 30.0]);

Ok(())
}
Binary file added tests/data/attr_tester.onnx
Binary file not shown.

0 comments on commit aaed913

Please sign in to comment.