@@ -59,50 +59,34 @@ def test_transformer_encoder_forward(self):
59
59
torch .Size ([batch_size , time_dim , self .hidden_size ]))
60
60
self .assertEqual (hidden , None )
61
61
62
+ # yapf: disable
62
63
output_target = torch .Tensor ([
63
- [[
64
- 1.9728e-01 , - 1.2042e-01 , 8.0998e-02 , 1.3411e-03 , - 3.5960e-01 ,
65
- - 5.2988e-01 , - 5.6056e-01 , - 3.5297e-01 , 2.6680e-01 , 2.8343e-01 ,
66
- - 3.7342e-01 , - 5.9113e-03
67
- ],
68
- [
69
- 8.9687e-02 , - 1.2491e-01 , 7.7809e-02 , - 1.3499e-03 , - 2.7002e-01 ,
70
- - 4.7312e-01 , - 5.7981e-01 , - 4.1998e-01 , 1.0457e-01 , 2.9726e-01 ,
71
- - 3.9461e-01 , 8.1598e-02
72
- ],
73
- [
74
- 3.4988e-02 , - 1.3020e-01 , 6.0043e-02 , 2.7782e-02 , - 3.1483e-01 ,
75
- - 3.8940e-01 , - 5.5557e-01 , - 5.9540e-01 , - 2.9808e-02 , 3.1468e-01 ,
76
- - 4.5809e-01 , 4.3312e-03
77
- ],
78
- [
79
- 1.2234e-01 , - 1.3285e-01 , 6.3068e-02 , - 2.3343e-02 , - 2.3519e-01 ,
80
- - 4.0794e-01 , - 5.6063e-01 ,
81
- - 5.5484e-01 , - 1.1272e-01 ,
82
- 3.0103e-01 , - 4.0983e-01 , 3.3038e-02
83
- ]],
84
- [[
85
- 9.8597e-02 , - 1.2121e-01 , 1.0718e-01 , - 2.2644e-02 , - 4.0282e-01 ,
86
- - 4.2646e-01 , - 5.9981e-01 ,
87
- - 3.7200e-01 , 1.9538e-01 , 2.7036e-01 , - 3.4072e-01 , - 1.7965e-03
88
- ],
89
- [
90
- 8.8470e-02 , - 1.2618e-01 , 5.3351e-02 , - 1.8531e-02 , - 3.3834e-01 ,
91
- - 4.9047e-01 , - 5.7063e-01 , - 4.9790e-01 , 2.2070e-01 , 3.3964e-01 ,
92
- - 4.1604e-01 , 2.3519e-02
93
- ],
94
- [
95
- 5.8373e-02 , - 1.2706e-01 , 1.0598e-01 , 9.3256e-05 , - 3.0493e-01 ,
96
- - 4.4406e-01 , - 5.4723e-01 , - 5.2214e-01 , 8.0374e-02 , 2.6307e-01 ,
97
- - 4.4571e-01 , 8.7052e-02
98
- ],
99
- [
100
- 7.9567e-02 , - 1.2977e-01 , 1.1731e-01 , 2.6198e-02 , - 2.4024e-01 ,
101
- - 4.2161e-01 , - 5.7604e-01 , - 7.3298e-01 , 1.6698e-01 , 3.1454e-01 ,
102
- - 4.9189e-01 , 2.4027e-02
103
- ]]
64
+ [[1.9728e-01 , - 1.2042e-01 , 8.0998e-02 , 1.3411e-03 , - 3.5960e-01 ,
65
+ - 5.2988e-01 , - 5.6056e-01 , - 3.5297e-01 , 2.6680e-01 , 2.8343e-01 ,
66
+ - 3.7342e-01 , - 5.9112e-03 ],
67
+ [8.9687e-02 , - 1.2491e-01 , 7.7809e-02 , - 1.3500e-03 , - 2.7002e-01 ,
68
+ - 4.7312e-01 , - 5.7981e-01 , - 4.1998e-01 , 1.0457e-01 , 2.9726e-01 ,
69
+ - 3.9461e-01 , 8.1598e-02 ],
70
+ [3.4988e-02 , - 1.3020e-01 , 6.0043e-02 , 2.7782e-02 , - 3.1483e-01 ,
71
+ - 3.8940e-01 , - 5.5557e-01 , - 5.9540e-01 , - 2.9808e-02 , 3.1468e-01 ,
72
+ - 4.5809e-01 , 4.3313e-03 ],
73
+ [1.2234e-01 , - 1.3285e-01 , 6.3068e-02 , - 2.3343e-02 , - 2.3519e-01 ,
74
+ - 4.0794e-01 , - 5.6063e-01 , - 5.5484e-01 , - 1.1272e-01 , 3.0103e-01 ,
75
+ - 4.0983e-01 , 3.3038e-02 ]],
76
+ [[9.8597e-02 , - 1.2121e-01 , 1.0718e-01 , - 2.2644e-02 , - 4.0282e-01 ,
77
+ - 4.2646e-01 , - 5.9981e-01 , - 3.7200e-01 , 1.9538e-01 , 2.7036e-01 ,
78
+ - 3.4072e-01 , - 1.7965e-03 ],
79
+ [8.8470e-02 , - 1.2618e-01 , 5.3351e-02 , - 1.8531e-02 , - 3.3834e-01 ,
80
+ - 4.9047e-01 , - 5.7063e-01 , - 4.9790e-01 , 2.2070e-01 , 3.3964e-01 ,
81
+ - 4.1604e-01 , 2.3519e-02 ],
82
+ [5.8373e-02 , - 1.2706e-01 , 1.0598e-01 , 9.3255e-05 , - 3.0493e-01 ,
83
+ - 4.4406e-01 , - 5.4723e-01 , - 5.2214e-01 , 8.0374e-02 , 2.6307e-01 ,
84
+ - 4.4571e-01 , 8.7052e-02 ],
85
+ [7.9567e-02 , - 1.2977e-01 , 1.1731e-01 , 2.6198e-02 , - 2.4024e-01 ,
86
+ - 4.2161e-01 , - 5.7604e-01 , - 7.3298e-01 , 1.6698e-01 , 3.1454e-01 ,
87
+ - 4.9189e-01 , 2.4027e-02 ]],
104
88
])
105
- torch .testing .assert_close (output , output_target )
89
+ torch .testing .assert_close (output , output_target , rtol = 1e-4 , atol = 1e-4 )
106
90
107
91
for layer in encoder .layers :
108
92
self .assertTrue (isinstance (layer , TransformerEncoderLayer ))
@@ -118,7 +102,7 @@ def test_transformer_encoder_forward(self):
118
102
self .assertEqual (layer ._layer_norm_position , self .layer_norm )
119
103
120
104
121
- class TestSubsampler (TensorTestCase ):
105
+ class TestSubsampler (unittest . TestCase ):
122
106
123
107
def setUp (self ):
124
108
self .hidden_size = 12
@@ -149,32 +133,20 @@ def test_subsampler_forward(self):
149
133
# x shape [batch_size, seq_len, emb_dim]: [2, 9, 10] -> [2, 3, 12]
150
134
self .assertEqual (x .size (), torch .Size ([batch_size , 3 , self .hidden_size ]))
151
135
152
- x_target = torch .tensor ([[[
153
- - 0.4831 , - 0.0188 , - 0.0643 , 0.2323 , 0.1843 , - 0.0599 , 0.0333 , - 0.0295 , 0.0926 ,
154
- 0.0629 , 0.4416 , - 0.3737
155
- ],
156
- [
157
- - 0.0230 , 0.0513 , - 0.2007 , - 0.2211 , 0.7072 , 0.0523 ,
158
- - 0.0546 , 0.0382 , - 0.0606 , - 0.8240 , - 0.3379 ,
159
- - 0.7052
160
- ],
161
- [
162
- 0.0229 , 0.1770 , - 0.2644 , - 0.5954 , 0.8251 , - 0.0118 ,
163
- - 0.0228 , - 0.2697 , 0.1242 , 0.1570 , - 0.2263 , - 0.9022
164
- ]],
165
- [[
166
- - 0.4647 , 0.0986 , - 0.1160 , 0.0453 , 0.2717 , - 0.0112 ,
167
- 0.0018 , 0.0935 , 0.2077 , - 0.2647 , 0.3621 , - 0.4435
168
- ],
169
- [
170
- 0.0116 , - 0.1874 , - 0.0305 , - 0.5209 , 0.7063 ,
171
- - 0.0522 , 0.0577 , 0.4307 , 0.1027 , - 0.1947 , 0.0964 ,
172
- - 0.8076
173
- ],
174
- [
175
- - 0.2909 , - 0.0827 , - 0.1345 , - 0.4011 , 0.4482 ,
176
- 0.4247 , 0.2187 , - 0.2467 , 0.0096 , - 0.2841 , 0.0799 ,
177
- - 1.2243
178
- ]]])
179
- self .assertTensorAlmostEqual (x , x_target )
180
- self .assertTensorAlmostEqual (x_length , torch .tensor ([3 , 3 ]))
136
+ # yapf: disable
137
+ x_target = torch .tensor ([
138
+ [[- 0.4831 , - 0.0188 , - 0.0643 , 0.2323 , 0.1843 , - 0.0599 , 0.0333 ,
139
+ - 0.0295 , 0.0926 , 0.0629 , 0.4416 , - 0.3737 ],
140
+ [- 0.0230 , 0.0513 , - 0.2007 , - 0.2211 , 0.7072 , 0.0523 , - 0.0546 ,
141
+ 0.0382 , - 0.0606 , - 0.8240 , - 0.3379 , - 0.7052 ],
142
+ [0.0229 , 0.1770 , - 0.2644 , - 0.5954 , 0.8251 , - 0.0118 , - 0.0228 ,
143
+ - 0.2697 , 0.1242 , 0.1570 , - 0.2263 , - 0.9022 ]],
144
+ [[- 0.4647 , 0.0986 , - 0.1160 , 0.0453 , 0.2717 , - 0.0112 , 0.0018 ,
145
+ 0.0935 , 0.2077 , - 0.2647 , 0.3621 , - 0.4435 ],
146
+ [0.0116 , - 0.1874 , - 0.0305 , - 0.5209 , 0.7063 , - 0.0522 , 0.0577 ,
147
+ 0.4307 , 0.1027 , - 0.1947 , 0.0964 , - 0.8076 ],
148
+ [- 0.2909 , - 0.0827 , - 0.1345 , - 0.4011 , 0.4482 , 0.4247 , 0.2187 ,
149
+ - 0.2467 , 0.0096 , - 0.2841 , 0.0799 , - 1.2243 ]],
150
+ ])
151
+ torch .testing .assert_close (x , x_target , rtol = 1e-4 , atol = 1e-4 )
152
+ torch .testing .assert_close (x_length , torch .tensor ([3 , 3 ]))
0 commit comments