Skip to content

Commit

Permalink
added onxx avgpool1d
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjun31415 committed May 7, 2024
1 parent bd06b38 commit b9c086e
Show file tree
Hide file tree
Showing 12 changed files with 320 additions and 11 deletions.
4 changes: 2 additions & 2 deletions crates/burn-core/src/nn/pool/avg_pool1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::tensor::Tensor;
use burn_tensor::module::avg_pool1d;

/// Configuration to create a [1D avg pooling](AvgPool1d) layer.
#[derive(Config)]
#[derive(Config, Debug)]
pub struct AvgPool1dConfig {
/// The size of the kernel.
pub kernel_size: usize,
Expand All @@ -20,7 +20,7 @@ pub struct AvgPool1dConfig {
pub padding: PaddingConfig1d,
/// If the padding is counted in the denominator when computing the average.
#[config(default = "true")]
count_include_pad: bool,
pub count_include_pad: bool,
}

/// Applies a 1D avg pooling over input tensors.
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ represent the corresponding Burn Op.
| [Asinh][9] |||
| [Atan][10] |||
| [Atanh][11] |||
| [AveragePool1d][12] | ||
| [AveragePool1d][12] | ||
| [AveragePool2d][12] |||
| [BatchNormalization][14] |||
| [Bernoulli][15] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ fn main() {
ModelGen::new()
.input("tests/add/add_int.onnx")
.input("tests/add/add.onnx")
.input("tests/avg_pool1d/avg_pool1d.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/cast/cast.onnx")
Expand Down
Binary file not shown.
58 changes: 58 additions & 0 deletions crates/burn-import/onnx-tests/tests/avg_pool1d/avg_pool1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python3

# used to generate model: avg_pool1d.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.pool1 = nn.AvgPool1d(4, stride=2)

self.pool2 = nn.AvgPool1d(4, stride=2, padding=2, count_include_pad=True)

self.pool3 = nn.AvgPool1d(4, stride=2, padding=2, count_include_pad=False)

def forward(self, x1, x2, x3):
y1 = self.pool1(x1)
y2 = self.pool2(x2)
y3 = self.pool3(x3)
return y1, y2, y3


def main():
# Set seed for reproducibility
torch.manual_seed(1)

# Print options
torch.set_printoptions(precision=3)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "avg_pool1d.onnx"
input1 = torch.randn(1, 5, 5, device=device)
torch.onnx.export(model, (input1, input1, input1), file_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape: {}".format(input1.shape))
print("Test input data: {}".format(input1))
output1, output2, output3 = model.forward(input1, input1, input1)
print("Test output1 data shape: {}".format(output1.shape))
print("Test output2 data shape: {}".format(output2.shape))
print("Test output3 data shape: {}".format(output3.shape))
print("Test output1: {}".format(output1))
print("Test output2: {}".format(output2))
print("Test output3: {}".format(output3))


if __name__ == '__main__':
main()
48 changes: 48 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include_models!(
add_int,
add,
avg_pool2d,
avg_pool1d,
batch_norm,
cast,
clip_opset16,
Expand Down Expand Up @@ -498,6 +499,53 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn avg_pool1d() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: avg_pool1d::Model<Backend> = avg_pool1d::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 3>::from_floats(
[[
[-1.526, -0.750, -0.654, -1.609, -0.100],
[-0.609, -0.980, -1.609, -0.712, 1.171],
[1.767, -0.095, 0.139, -1.579, -0.321],
[-0.299, 1.879, 0.336, 0.275, 1.716],
[-0.056, 0.911, -1.392, 2.689, -0.111],
]],
&device,
);
let (output1, output2, output3) = model.forward(input.clone(), input.clone(), input);
let expected1 = Data::from([[[-1.135], [-0.978], [0.058], [0.548], [0.538]]]);
let expected2 = Data::from([[
[-0.569, -1.135, -0.591],
[-0.397, -0.978, -0.288],
[0.418, 0.058, -0.440],
[0.395, 0.548, 0.582],
[0.214, 0.538, 0.296],
]]);
let expected3 = Data::from([[
[-1.138, -1.135, -0.788],
[-0.794, -0.978, -0.383],
[0.836, 0.058, -0.587],
[0.790, 0.548, 0.776],
[0.427, 0.538, 0.395],
]]);

let expected_shape1 = Shape::from([1, 5, 1]);
let expected_shape2 = Shape::from([1, 5, 3]);
let expected_shape3 = Shape::from([1, 5, 3]);

assert_eq!(output1.shape(), expected_shape1);
assert_eq!(output2.shape(), expected_shape2);
assert_eq!(output3.shape(), expected_shape3);

output1.to_data().assert_approx_eq(&expected1, 3);
output2.to_data().assert_approx_eq(&expected2, 3);
output3.to_data().assert_approx_eq(&expected3, 3);
}

#[test]
fn avg_pool2d() {
// Initialize the model without weights (because the exported file does not contain them)
Expand Down
155 changes: 155 additions & 0 deletions crates/burn-import/src/burn/node/avg_pool1d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
use proc_macro2::TokenStream;
use quote::quote;

use burn::{nn::pool::AvgPool1dConfig, record::PrecisionSettings};

use super::{Node, NodeCodegen};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};

#[derive(Debug, Clone)]
pub struct AvgPool1dNode {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub config: AvgPool1dConfig,
}

impl AvgPool1dNode {
pub fn new<S: AsRef<str>>(
name: S,
input: TensorType,
output: TensorType,
config: AvgPool1dConfig,
) -> Self {
Self {
field: OtherType::new(
name,
quote! {
AvgPool1d
},
),
input,
output,
config,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool1dNode {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;
let kernel_size = self.config.kernel_size.to_tokens();
let strides = self.config.stride.to_tokens();
let padding = self.config.padding.to_tokens();
let count_include_pad = self.config.count_include_pad;

let tokens = quote! {
let #name = AvgPool1dConfig::new(#kernel_size)
.with_stride(#strides)
.with_padding(#padding)
.with_count_include_pad(#count_include_pad)
.init();
};

Some(tokens)
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let field = &self.field.name;

quote! {
let #output = self.#field.forward(#input);
}
}

fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::PaddingConfig1d");
imports.register("burn::nn::pool::AvgPool1d");
imports.register("burn::nn::pool::AvgPool1dConfig");
}

fn into_node(self) -> Node<PS> {
Node::AvgPool1d(self)
}

fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
S::serialize_none(serializer)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
use burn::{nn::PaddingConfig1d, record::FullPrecisionSettings};

#[test]
fn test_codegen() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(AvgPool1dNode::new(
"avg_pool1d",
TensorType::new_float("input", 3),
TensorType::new_float("output", 3),
AvgPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Valid),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
use burn::nn::PaddingConfig1d;
use burn::nn::pool::AvgPool1d;
use burn::nn::pool::AvgPool1dConfig;

#[derive(Module, Debug)]
pub struct Model <B: Backend> {
avg_pool1d: AvgPool1d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let avg_pool1d = AvgPool1dConfig::new(3)
.with_stride(1)
.with_padding(PaddingConfig1d::Valid)
.with_count_include_pad(true)
.init();

Self {
avg_pool1d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let output = self.avg_pool1d.forward(input);

output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
17 changes: 10 additions & 7 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use super::{
avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -75,6 +75,7 @@ pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {

#[derive(Debug, Clone)]
pub enum Node<PS: PrecisionSettings> {
AvgPool1d(AvgPool1dNode),
AvgPool2d(AvgPool2dNode),
BatchNorm(BatchNormNode<PS>),
Binary(BinaryNode),
Expand Down Expand Up @@ -103,6 +104,7 @@ macro_rules! match_all {
($self:expr, $func:expr) => {{
#[allow(clippy::redundant_closure_call)]
match $self {
Node::AvgPool1d(node) => $func(node),
Node::AvgPool2d(node) => $func(node),
Node::BatchNorm(node) => $func(node),
Node::Binary(node) => $func(node),
Expand Down Expand Up @@ -141,6 +143,7 @@ impl<PS: PrecisionSettings> Serialize for Node<PS> {
impl<PS: PrecisionSettings> Node<PS> {
pub fn name(&self) -> &str {
match self {
Node::AvgPool1d(_) => "avg_pool1d",
Node::AvgPool2d(_) => "avg_pool2d",
Node::BatchNorm(_) => "batch_norm",
Node::Binary(binary) => binary.binary_type.as_str(),
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod base;

pub(crate) mod avg_pool1d;
pub(crate) mod avg_pool2d;
pub(crate) mod batch_norm;
pub(crate) mod binary;
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use super::{
pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
match node.node_type {
NodeType::Add => same_as_input(node),
NodeType::AveragePool1d => same_as_input(node),
NodeType::AveragePool2d => same_as_input(node),
NodeType::BatchNormalization => same_as_input(node),
NodeType::Cast => cast_update_outputs(node),
Expand All @@ -38,6 +39,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Log => same_as_input(node),
NodeType::LogSoftmax => same_as_input(node),
NodeType::MatMul => matmul_update_outputs(node),
NodeType::MaxPool1d => same_as_input(node),
NodeType::MaxPool2d => same_as_input(node),
NodeType::Mul => same_as_input(node),
NodeType::Neg => same_as_input(node),
Expand Down
Loading

0 comments on commit b9c086e

Please sign in to comment.