Skip to content

Commit

Permalink
change seeed
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjun31415 committed May 7, 2024
1 parent 9f7c227 commit 52b3d51
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def forward(self, x1, x2, x3):

def main():
# Set seed for reproducibility
torch.manual_seed(42)
torch.manual_seed(1)

# Print options
torch.set_printoptions(precision=3)
Expand Down
32 changes: 16 additions & 16 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,29 +508,29 @@ mod tests {
// Run the model
let input = Tensor::<Backend, 3>::from_floats(
[[
[1.927, 1.487, 0.901, -2.106, 0.678],
[-1.235, -0.043, -1.605, -0.752, -0.687],
[-0.493, 0.241, -1.111, 0.092, -2.317],
[-0.217, -1.385, -0.396, 0.803, -0.622],
[-0.592, -0.063, -0.829, 0.331, -1.558],
[-1.526, -0.750, -0.654, -1.609, -0.100],
[-0.609, -0.980, -1.609, -0.712, 1.171],
[1.767, -0.095, 0.139, -1.579, -0.321],
[-0.299, 1.879, 0.336, 0.275, 1.716],
[-0.056, 0.911, -1.392, 2.689, -0.111],
]],
&device,
);
let (output1, output2, output3) = model.forward(input.clone(), input.clone(), input);
let expected1 = Data::from([[[0.552], [-0.909], [-0.318], [-0.298], [-0.288]]]);
let expected1 = Data::from([[[-1.135], [-0.978], [0.058], [0.548], [0.538]]]);
let expected2 = Data::from([[
[0.854, 0.552, -0.132],
[-0.319, -0.909, -0.761],
[-0.063, -0.318, -0.834],
[-0.400, -0.298, -0.053],
[-0.164, -0.288, -0.514],
[-0.569, -1.135, -0.591],
[-0.397, -0.978, -0.288],
[0.418, 0.058, -0.440],
[0.395, 0.548, 0.582],
[0.214, 0.538, 0.296],
]]);
let expected3 = Data::from([[
[1.707, 0.552, -0.175],
[-0.639, -0.909, -1.014],
[-0.126, -0.318, -1.112],
[-0.801, -0.298, -0.071],
[-0.328, -0.288, -0.685],
[-1.138, -1.135, -0.788],
[-0.794, -0.978, -0.383],
[0.836, 0.058, -0.587],
[0.790, 0.548, 0.776],
[0.427, 0.538, 0.395],
]]);

let expected_shape1 = Shape::from([1, 5, 1]);
Expand Down

0 comments on commit 52b3d51

Please sign in to comment.