Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JC committed Jul 6, 2024
1 parent 2c4ca7c commit 6181b4c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
5 changes: 4 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,14 @@ mod tests {
[
[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
[21., 22., 23., 24., 25., 26., 27., 28., 29., 30.],
[31., 32., 33., 34., 35., 36., 37., 38., 39., 40.],
[41., 42., 43., 44., 45., 46., 47., 48., 49., 50.],
],
&device,
);
let output = model.forward(input);
let expected = TensorData::from([[1f32, 2., 3., 4., 5.]]);
let expected = TensorData::from([[1f32, 2., 3., 4., 5.], [11f32, 12., 13., 14., 15.]]);

output.to_data().assert_eq(&expected, true);
}
Expand Down
Binary file modified crates/burn-import/onnx-tests/tests/slice/slice.onnx
Binary file not shown.
8 changes: 4 additions & 4 deletions crates/burn-import/onnx-tests/tests/slice/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def main() -> None:
# Starts
starts_val = [-2,0] # Example shape value
starts_val = [-5,0] # Example shape value
starts_tensor = helper.make_tensor(
name="starts",
data_type=TensorProto.INT64,
Expand All @@ -23,7 +23,7 @@ def main() -> None:
)

# Ends
ends_val = [-1,5] # Example shape value
ends_val = [-3,-5] # Example shape value
ends_tensor = helper.make_tensor(
name="ends",
data_type=TensorProto.INT64,
Expand Down Expand Up @@ -83,10 +83,10 @@ def main() -> None:
nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
name="SliceGraph",
inputs=[
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]),
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [5, 10]),
],
outputs=[
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5])
helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 5])
],
)

Expand Down

0 comments on commit 6181b4c

Please sign in to comment.