diff --git a/src/operator/tests.rs b/src/operator/tests.rs index 3ec78a75..ed683bd5 100644 --- a/src/operator/tests.rs +++ b/src/operator/tests.rs @@ -102,3 +102,95 @@ fn test_custom_ops() -> crate::Result<()> { Ok(()) } + +struct AttrTesterIntFloat; + +impl Operator for AttrTesterIntFloat { + fn name(&self) -> &str { + "AttrTesterIntFloat" + } + + fn inputs(&self) -> Vec { + vec![OperatorInput::required(TensorElementType::Float32)] + } + + fn outputs(&self) -> Vec { + vec![OperatorOutput::required(TensorElementType::Float32)] + } + + fn infer_shape(&self, ctx: &mut super::ShapeInferenceContext) -> crate::Result<()> { + assert!(matches!(ctx.attr("a_int"), Ok(1_i64))); + assert!(matches!(ctx.attr("a_float"), Ok(2.0_f32))); + assert!(matches!(ctx.attr::>("ints").as_deref(), Ok(&[3, 4, 5]))); + assert!(matches!(ctx.attr::>("floats").as_deref(), Ok(&[6., 7., 8.]))); + + ctx.set_output(0, &ctx.inputs()[0])?; + + Ok(()) + } + + fn create_kernel(&self, _: &KernelAttributes) -> crate::Result> { + Ok(Box::new(|ctx: &KernelContext| { + let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?; + let (x_shape, x) = x.try_extract_raw_tensor::()?; + let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; + let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; + for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { + z_ref[i] = x[i] * 2.; + } + Ok(()) + })) + } +} + +struct AttrTesterString; + +impl Operator for AttrTesterString { + fn name(&self) -> &str { + "AttrTesterString" + } + + fn inputs(&self) -> Vec { + vec![OperatorInput::required(TensorElementType::Float32)] + } + + fn outputs(&self) -> Vec { + vec![OperatorOutput::required(TensorElementType::Float32)] + } + + fn infer_shape(&self, ctx: &mut super::ShapeInferenceContext) -> crate::Result<()> { + assert!(matches!(ctx.attr::("a_string").as_deref(), Ok("iamastring"))); + + ctx.set_output(0, &ctx.inputs()[0])?; + + Ok(()) + } + + fn create_kernel(&self, _: &KernelAttributes) -> crate::Result> { + Ok(Box::new(|ctx: &KernelContext| { + let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?; + let (x_shape, x) = x.try_extract_raw_tensor::()?; + let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?; + let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; + for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize { + z_ref[i] = x[i] * 3.; + } + Ok(()) + })) + } +} + +#[test] +fn test_op_attrs() -> crate::Result<()> { + let model = std::fs::read("tests/data/attr_tester.onnx").expect(""); + let mut session = Session::builder()? + .with_operators(OperatorDomain::new("test.customop")?.add(AttrTesterIntFloat)?.add(AttrTesterString)?)? + .commit_from_memory(&model)?; + + let value1 = Tensor::from_array(([5], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0]))?; + + let values = session.run(crate::inputs!["input_0" => &value1])?; + assert_eq!(values[0].try_extract_raw_tensor::()?.1, [6.0, 12.0, 18.0, 24.0, 30.0]); + + Ok(()) +} diff --git a/tests/data/attr_tester.onnx b/tests/data/attr_tester.onnx new file mode 100644 index 00000000..c99b4116 Binary files /dev/null and b/tests/data/attr_tester.onnx differ