@@ -181,6 +181,8 @@ def __init__(self,
         self.emb_size = emb_size
         self.lila1 = nn.Linear(emb_size, hidden_size)
         self.lila2 = nn.Linear(hidden_size, hidden_size)
+        self.lila3 = nn.Linear(hidden_size, hidden_size)
+        self.lila4 = nn.Linear(hidden_size, hidden_size)
         self.activation = activation
         self.last_activation = last_activation
         self.conv1 = nn.Sequential(
@@ -272,9 +274,9 @@ def forward(self, embed_src: Tensor, src_length: Tensor, mask: Tensor, \
         conv_out1 = self.norm1(conv_out1)
 
         if self.activation == "tanh":
-            lila_out3 = torch.tanh(self.lila2(conv_out1))
+            lila_out3 = torch.tanh(self.lila3(conv_out1))
         else:
-            lila_out3 = torch.relu(self.lila2(conv_out1))
+            lila_out3 = torch.relu(self.lila3(conv_out1))
         lila_out3 = lila_out3.transpose(1, 2)
 
         conv_out2 = self.conv2(lila_out3)
@@ -285,9 +287,9 @@ def forward(self, embed_src: Tensor, src_length: Tensor, mask: Tensor, \
         conv_out2 = self.norm2(conv_out2)
 
         if self.activation == "tanh":
-            lila_out4 = torch.tanh(self.lila2(conv_out2))
+            lila_out4 = torch.tanh(self.lila4(conv_out2))
         else:
-            lila_out4 = torch.relu(self.lila2(conv_out2))
+            lila_out4 = torch.relu(self.lila4(conv_out2))
 
         # apply dropout to the rnn input
         lila_do = self.rnn_input_dropout(lila_out4)
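For context, a minimal self-contained sketch of the wiring this patch produces, assuming toy sizes and plain Conv1d/LayerNorm stand-ins for the module's conv1/conv2/norm1/norm2 blocks; the class name, default sizes, and the _act helper are illustrative and not part of the original code. The point it shows is that conv_out1 and conv_out2 now pass through their own projections (lila3, lila4) instead of reusing lila2's weights.

import torch
from torch import nn, Tensor

class ConvBlockSketch(nn.Module):
    """Illustrative sketch, not the original module's full signature."""

    def __init__(self, emb_size: int = 32, hidden_size: int = 64,
                 activation: str = "tanh", dropout: float = 0.1):
        super().__init__()
        self.lila1 = nn.Linear(emb_size, hidden_size)
        self.lila2 = nn.Linear(hidden_size, hidden_size)
        self.lila3 = nn.Linear(hidden_size, hidden_size)  # added: own weights for conv1 output
        self.lila4 = nn.Linear(hidden_size, hidden_size)  # added: own weights for conv2 output
        self.activation = activation
        # assumed stand-ins: 1D convs over time with shape-preserving padding
        self.conv1 = nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.rnn_input_dropout = nn.Dropout(dropout)

    def _act(self, x: Tensor) -> Tensor:
        return torch.tanh(x) if self.activation == "tanh" else torch.relu(x)

    def forward(self, embed_src: Tensor) -> Tensor:
        # embed_src: (batch, time, emb_size)
        lila_out1 = self._act(self.lila1(embed_src))
        lila_out2 = self._act(self.lila2(lila_out1))
        conv_out1 = self.conv1(lila_out2.transpose(1, 2)).transpose(1, 2)
        conv_out1 = self.norm1(conv_out1)
        lila_out3 = self._act(self.lila3(conv_out1))   # was lila2 before this change
        conv_out2 = self.conv2(lila_out3.transpose(1, 2)).transpose(1, 2)
        conv_out2 = self.norm2(conv_out2)
        lila_out4 = self._act(self.lila4(conv_out2))   # was lila2 before this change
        return self.rnn_input_dropout(lila_out4)

# usage: out = ConvBlockSketch()(torch.randn(2, 10, 32))  # -> (2, 10, 64)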