Skip to content

Commit

Permalink
feat: added range onnx import
Browse files Browse the repository at this point in the history
  • Loading branch information
JachymPutta committed May 29, 2024
1 parent e61b026 commit 99f0f56
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 3 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ represent the corresponding Burn Op.
| [RandomNormalLike][127] |||
| [RandomUniform][128] |||
| [RandomUniformLike][129] |||
| [Range][130] | ||
| [Range][130] | ||
| [Reciprocal][131] |||
| [ReduceL][132] |||
| [ReduceLogSum][133] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ fn main() {
.input("tests/squeeze/squeeze_opset13.onnx")
.input("tests/random_uniform/random_uniform.onnx")
.input("tests/random_normal/random_normal.onnx")
.input("tests/range/range.onnx")
.out_dir("model/")
.run_from_script();

Expand Down
16 changes: 16 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ include_models!(
less,
less_or_equal,
prelu,
range,
recip,
reduce_max,
reduce_mean,
Expand Down Expand Up @@ -1055,6 +1056,21 @@ mod tests {
output.to_data().assert_approx_eq(&expected, 4);
}

#[test]
fn range() {
let device = Default::default();
let model: range::Model<Backend> = range::Model::new(&device);

// Run the model
let start = 0i64;
let limit = 10i64;
let delta = 2i64;
let output = model.forward(start, limit, delta);

let expected = Data::from([0, 2, 4, 6, 8]);
assert_eq!(output.to_data(), expected);
}

#[test]
fn recip() {
// Initialize the model
Expand Down
Binary file added crates/burn-import/onnx-tests/tests/range/range.onnx
Binary file not shown.
34 changes: 34 additions & 0 deletions crates/burn-import/onnx-tests/tests/range/range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/range/range.onnx

import onnx
from onnx import helper, TensorProto

def main():
node = onnx.helper.make_node(
'Range',
name='range',
inputs=['start', 'end', 'step'],
outputs=['output']
)

graph_def = helper.make_graph(
nodes=[node],
name='RangeGraph',
inputs=[
helper.make_tensor_value_info('start', TensorProto.INT64, []),
helper.make_tensor_value_info('end', TensorProto.INT64, []),
helper.make_tensor_value_info('step', TensorProto.INT64, [])
],
outputs=[
helper.make_tensor_value_info('output', TensorProto.INT64, [5])
],
)

model_def = helper.make_model(graph_def, producer_name='range')

onnx.save(model_def, 'range.onnx')

if __name__ == '__main__':
main()
7 changes: 5 additions & 2 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use super::{
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, reshape::ReshapeNode,
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -95,6 +95,7 @@ pub enum Node<PS: PrecisionSettings> {
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Range(RangeNode),
Reshape(ReshapeNode),
Squeeze(SqueezeNode),
Unary(UnaryNode),
Expand Down Expand Up @@ -127,6 +128,7 @@ macro_rules! match_all {
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Squeeze(node) => $func(node),
Node::Unary(node) => $func(node),
Expand Down Expand Up @@ -169,6 +171,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Node::Squeeze(_) => "squeeze",
Node::Unary(unary) => unary.kind.as_str(),
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub(crate) mod max_pool2d;
pub(crate) mod prelu;
pub(crate) mod random_normal;
pub(crate) mod random_uniform;
pub(crate) mod range;
pub(crate) mod reshape;
pub(crate) mod squeeze;
pub(crate) mod unary;
Expand Down
117 changes: 117 additions & 0 deletions crates/burn-import/src/burn/node/range.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct RangeNode {
pub start: Type,
pub end: Type,
pub step: Type,
pub output: TensorType,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for RangeNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn input_types(&self) -> Vec<Type> {
vec![self.start.clone(), self.end.clone(), self.step.clone()]
}

fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
let output = &self.output.name;

let start = match &self.start {
Type::Scalar(s) => {
let name = s.name.clone();
quote! { #name }
}
_ => panic!("Start must be a scalar"),
};

let end = match &self.end {
Type::Scalar(s) => {
let name = s.name.clone();
quote! { #name }
}
_ => panic!("End must be a scalar"),
};

let step = match &self.step {
Type::Scalar(s) => {
let name = s.name.clone();
quote! { #name }
}
_ => panic!("Step must be a scalar"),
};

quote! {
let #output = Tensor::arange_step(#start..#end, #step as usize, &*self.device);
}
}
fn into_node(self) -> Node<PS> {
Node::Range(self)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::graph::BurnGraph;
use crate::burn::node::test::assert_tokens;
use crate::burn::{ScalarKind, ScalarType};
use burn::record::FullPrecisionSettings;

#[test]
fn codegen_nodes_range() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(
RangeNode::new(
Type::Scalar(ScalarType::new("start", ScalarKind::Int64)),
Type::Scalar(ScalarType::new("end", ScalarKind::Int64)),
Type::Scalar(ScalarType::new("step", ScalarKind::Int64)),
TensorType::new_int("output", 1),
)
.into_node(),
);
graph.register_input_output(
vec!["start".to_string(), "end".to_string(), "step".to_string()],
vec!["output".to_string()],
);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, start: i64, end: i64, step: i64) -> Tensor<B, 1> {
let output = Tensor::arange_step(start..end, step as usize, &*self.device);

output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
13 changes: 13 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node),
NodeType::Less => less_update_outputs(node),
NodeType::LessOrEqual => less_or_equal_update_outputs(node),
NodeType::Range => range_update_outputs(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::ReduceMax => reduce_max_update_outputs(node),
NodeType::ReduceMean => reduce_mean_update_outputs(node),
Expand Down Expand Up @@ -567,6 +568,18 @@ fn matmul_update_outputs(node: &mut Node) {
}
}

fn range_update_outputs(node: &mut Node) {
if node.inputs.len() != 3 {
panic!("Range: expected 3 inputs, found {}", node.inputs.len());
}

node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: ElementType::Int64,
dim: 1,
shape: None,
});
}

/// Infers the shape of a ReduceMax node and replaces the shape of the output tensor.
fn reduce_max_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
Expand Down
11 changes: 11 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::{
prelu::PReluNode,
random_normal::RandomNormalNode,
random_uniform::RandomUniformNode,
range::RangeNode,
reshape::ReshapeNode,
squeeze::SqueezeNode,
unary::UnaryNode,
Expand Down Expand Up @@ -277,6 +278,7 @@ impl OnnxGraph {
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::Min => graph.register(Self::min_conversion(node)),
NodeType::Range => graph.register(Self::range_conversion(node)),
NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)),
Expand Down Expand Up @@ -573,6 +575,15 @@ impl OnnxGraph {
BinaryNode::min_pair(lhs, rhs, output)
}

fn range_conversion(node: Node) -> RangeNode {
let output = node.outputs.first().unwrap().to_tensor_type();
let start = node.inputs.get(0).unwrap().to_type();

Check failure on line 580 in crates/burn-import/src/onnx/to_burn.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu-22.04, stable, std)

[clippy] reported by reviewdog 🐶 error: accessing first element with `node.inputs.get(0)` --> crates/burn-import/src/onnx/to_burn.rs:580:21 | 580 | let start = node.inputs.get(0).unwrap().to_type(); | ^^^^^^^^^^^^^^^^^^ help: try: `node.inputs.first()` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#get_first = note: `-D clippy::get-first` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::get_first)]` Raw Output: crates/burn-import/src/onnx/to_burn.rs:580:21:e:error: accessing first element with `node.inputs.get(0)` --> crates/burn-import/src/onnx/to_burn.rs:580:21 | 580 | let start = node.inputs.get(0).unwrap().to_type(); | ^^^^^^^^^^^^^^^^^^ help: try: `node.inputs.first()` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#get_first = note: `-D clippy::get-first` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::get_first)]` __END__
let end = node.inputs.get(1).unwrap().to_type();
let step = node.inputs.get(2).unwrap().to_type();

RangeNode::new(start, end, step, output)
}

fn reduce_max_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
Expand Down

0 comments on commit 99f0f56

Please sign in to comment.