-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnn_test.py
48 lines (42 loc) · 1.4 KB
/
nn_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from Transformer.nn_component import Lane_Encoder
import unittest
import torch
class TestNeuralNetwork(unittest.TestCase):
    """Shape test for ``Lane_Encoder`` from ``Transformer.nn_component``."""

    def setUp(self):
        """Build the hyper-parameter dictionary shared by the tests."""
        self.test_item = []
        self.dictionary = {
            "lane_encoder": {
                "VEHICLE": {
                    "layers": 1,
                    "embedding_size": 2,
                    "output_channels": 16,
                    "output_size": 32,
                    "kernel_size": 4,
                    "strides": 2,
                    # NOTE: test_shape does not pass this value to Lane_Encoder.
                    "dropout": 0.5,
                }
            }
        }

    def test_shape(self):
        """Check the encoder output reshapes to (batch, lanes, output_size)."""
        me_params = self.dictionary["lane_encoder"]["VEHICLE"]
        encoder = Lane_Encoder(
            me_params["layers"],
            me_params["embedding_size"],
            me_params["output_channels"],
            me_params["output_size"],
            me_params["kernel_size"],
            me_params["strides"],
        )

        # Every shape below is derived from these values so the input, the
        # reshape and the expectation cannot drift apart.
        batch_size = 256
        max_lane_num = 3
        # NOTE(review): presumably tied to "embedding_size" (also 2) -- confirm
        # against Lane_Encoder before replacing this with the config value.
        num_channels = 2
        input_dim = 10
        expected_shape = (batch_size, max_lane_num, me_params["output_size"])

        # Lanes are folded into the batch axis before the encoder is called.
        input_tensor = torch.randn(
            batch_size, max_lane_num, num_channels, input_dim
        ).view(batch_size * max_lane_num, num_channels, input_dim)

        # ``view`` raises if the element count does not match, so this also
        # verifies the encoder emits ``output_size`` values per lane.
        output = encoder(input_tensor).view(*expected_shape)

        # assertEqual (unlike a bare assert) survives ``python -O`` and
        # reports both shapes on failure.
        self.assertEqual(tuple(output.size()), expected_shape)
# Script entry point: when this file is executed directly (rather than
# imported), discover and run every TestCase defined in this module.
if __name__ == "__main__":
    unittest.main()