Skip to content

Commit

Permalink
feat: added expand to import
Browse files Browse the repository at this point in the history
  • Loading branch information
JachymPutta committed May 31, 2024
1 parent da8a522 commit 10cda31
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
11 changes: 5 additions & 6 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1124,13 +1124,12 @@ mod tests {
let device = Default::default();
let model: expand::Model<Backend> = expand::Model::new(&device);

let input1 = Tensor::<Backend, 4>::from_floats([[[[-1.0, 1.0, 42.0, 3.0]]]], &device);
let input2 = Tensor::<Backend, 1, Int>::from_ints([3, 2], &device);
let input1 = Tensor::<Backend, 2>::from_floats([[-1.0], [1.0]], &device);

// let output = model.forward(input1, input2);
// let expected_shape = Shape::from([3, 2]);
//
// assert_eq!(output.shape(), expected_shape);
let output = model.forward(input1);

Check failure on line 1129 in crates/burn-import/onnx-tests/tests/onnx_tests.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu-22.04, stable, std)

[clippy] reported by reviewdog 🐶 error[E0308]: mismatched types --> crates/burn-import/onnx-tests/tests/onnx_tests.rs:1129:36 | 1129 | let output = model.forward(input1); | ------- ^^^^^^ expected `1`, found `2` | | | arguments to this method are incorrect | = note: expected struct `burn::tensor::Tensor<_, _, 1>` found struct `burn::tensor::Tensor<_, _, 2>` note: method defined here --> /home/runner/work/burn/burn/target/debug/build/onnx-tests-e235b20d449f04aa/out/model/expand.rs:42:12 | 42 | pub fn forward(&self, input1: Tensor<B, 1>) -> Tensor<B, 2> { | ^^^^^^^ -------------------- Raw Output: crates/burn-import/onnx-tests/tests/onnx_tests.rs:1129:36:e:error[E0308]: mismatched types --> crates/burn-import/onnx-tests/tests/onnx_tests.rs:1129:36 | 1129 | let output = model.forward(input1); | ------- ^^^^^^ expected `1`, found `2` | | | arguments to this method are incorrect | = note: expected struct `burn::tensor::Tensor<_, _, 1>` found struct `burn::tensor::Tensor<_, _, 2>` note: method defined here --> /home/runner/work/burn/burn/target/debug/build/onnx-tests-e235b20d449f04aa/out/model/expand.rs:42:12 | 42 | pub fn forward(&self, input1: Tensor<B, 1>) -> Tensor<B, 2> { | ^^^^^^^ -------------------- __END__
let expected_shape = Shape::from([2, 2]);

assert_eq!(output.shape(), expected_shape);
}

#[test]
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-import/src/onnx/from_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ use super::ir::{ArgType, Argument, Node, NodeType};

use protobuf::Message;

const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 9] = [
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
NodeType::BatchNormalization,
NodeType::Clip,
NodeType::Conv1d,
NodeType::Conv2d,
NodeType::Dropout,
NodeType::Expand,
NodeType::Reshape,
NodeType::Unsqueeze,
NodeType::ReduceSum,
Expand Down
1 change: 0 additions & 1 deletion crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,6 @@ impl OnnxGraph {
fn expand_conversion(node: Node) -> ExpandNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
println!("{:?}", node);
let shape = expand_config(&node);

ExpandNode::new(input, output, shape)
Expand Down

0 comments on commit 10cda31

Please sign in to comment.