
Commit 786695a

bug fix to lila weights
1 parent: 8d74f50


joeynmt/encoders.py

Lines changed: 6 additions & 4 deletions
@@ -181,6 +181,8 @@ def __init__(self,
         self.emb_size = emb_size
         self.lila1 = nn.Linear(emb_size, hidden_size)
         self.lila2 = nn.Linear(hidden_size, hidden_size)
+        self.lila3 = nn.Linear(hidden_size, hidden_size)
+        self.lila4 = nn.Linear(hidden_size, hidden_size)
         self.activation = activation
         self.last_activation = last_activation
         self.conv1 = nn.Sequential(
@@ -272,9 +274,9 @@ def forward(self, embed_src: Tensor, src_length: Tensor, mask: Tensor, \
         conv_out1 = self.norm1(conv_out1)

         if self.activation == "tanh":
-            lila_out3 = torch.tanh(self.lila2(conv_out1))
+            lila_out3 = torch.tanh(self.lila3(conv_out1))
         else:
-            lila_out3 = torch.relu(self.lila2(conv_out1))
+            lila_out3 = torch.relu(self.lila3(conv_out1))
         lila_out3 = lila_out3.transpose(1,2)

         conv_out2 = self.conv2(lila_out3)
@@ -285,9 +287,9 @@ def forward(self, embed_src: Tensor, src_length: Tensor, mask: Tensor, \
         conv_out2 = self.norm2(conv_out2)

         if self.activation == "tanh":
-            lila_out4 = torch.tanh(self.lila2(conv_out2))
+            lila_out4 = torch.tanh(self.lila4(conv_out2))
         else:
-            lila_out4 = torch.relu(self.lila2(conv_out2))
+            lila_out4 = torch.relu(self.lila4(conv_out2))

         # apply dropout to the rnn input
         lila_do = self.rnn_input_dropout(lila_out4)
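
What the fix does: before this commit, the projections applied after both convolution blocks reused self.lila2, so lila_out3 and lila_out4 were computed with the same weight matrix as the pre-convolution projection. The commit registers separate lila3 and lila4 layers and routes each projection through its own weights. Below is a minimal standalone sketch of the corrected pattern; the LilaProjections class, its _act helper, and the simplified forward signature are illustrative stand-ins, not the actual joeynmt encoder.

import torch
from torch import nn

class LilaProjections(nn.Module):
    """Illustrative fragment mirroring the fixed layer wiring."""

    def __init__(self, emb_size: int, hidden_size: int, activation: str = "tanh"):
        super().__init__()
        self.lila1 = nn.Linear(emb_size, hidden_size)
        self.lila2 = nn.Linear(hidden_size, hidden_size)
        # The fix: each post-convolution projection gets its own
        # nn.Linear instead of reusing self.lila2's weights.
        self.lila3 = nn.Linear(hidden_size, hidden_size)
        self.lila4 = nn.Linear(hidden_size, hidden_size)
        self.activation = activation

    def _act(self, x: torch.Tensor) -> torch.Tensor:
        # Same activation switch as in the diff above.
        return torch.tanh(x) if self.activation == "tanh" else torch.relu(x)

    def forward(self, conv_out1: torch.Tensor, conv_out2: torch.Tensor):
        lila_out3 = self._act(self.lila3(conv_out1))  # was self.lila2 before the fix
        lila_out4 = self._act(self.lila4(conv_out2))  # was self.lila2 before the fix
        return lila_out3, lila_out4

Because each nn.Linear instance registers its own weight matrix, lila3 and lila4 can now be trained independently of lila2 rather than silently sharing (and co-updating) one set of parameters.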
