@@ -160,6 +160,7 @@ def __init__(self,
                  last_activation: str = "None",
                  layer_norm: bool = False,
                  emb_norm: bool = False,
+                 same_weights: bool = False,
                  **kwargs) -> None:
         """
         Create a new recurrent encoder.
@@ -181,8 +182,10 @@ def __init__(self,
         self.emb_size = emb_size
         self.lila1 = nn.Linear(emb_size, hidden_size)
         self.lila2 = nn.Linear(hidden_size, hidden_size)
-        self.lila3 = nn.Linear(hidden_size, hidden_size)
-        self.lila4 = nn.Linear(hidden_size, hidden_size)
+        self.same_weights = same_weights
+        if not self.same_weights:
+            self.lila3 = nn.Linear(hidden_size, hidden_size)
+            self.lila4 = nn.Linear(hidden_size, hidden_size)
         self.activation = activation
         self.last_activation = last_activation
         self.conv1 = nn.Sequential(
@@ -201,6 +204,7 @@ def __init__(self,
         self.norm_out = nn.LayerNorm(2 * hidden_size if bidirectional else hidden_size)
         if self.emb_norm:
             self.norm_emb = nn.LayerNorm(emb_size)
+        self.same_weights = same_weights

         rnn = nn.GRU if rnn_type == "gru" else nn.LSTM

@@ -273,10 +277,17 @@ def forward(self, embed_src: Tensor, src_length: Tensor, mask: Tensor, \
         if self.layer_norm:
             conv_out1 = self.norm1(conv_out1)

-        if self.activation == "tanh":
-            lila_out3 = torch.tanh(self.lila3(conv_out1))
+        if not self.same_weights:
+            if self.activation == "tanh":
+                lila_out3 = torch.tanh(self.lila3(conv_out1))
+            else:
+                lila_out3 = torch.relu(self.lila3(conv_out1))
         else:
-            lila_out3 = torch.relu(self.lila3(conv_out1))
+            if self.activation == "tanh":
+                lila_out3 = torch.tanh(self.lila2(conv_out1))
+            else:
+                lila_out3 = torch.relu(self.lila2(conv_out1))
+
         lila_out3 = lila_out3.transpose(1, 2)

         conv_out2 = self.conv2(lila_out3)
@@ -286,10 +297,16 @@ def forward(self, embed_src: Tensor, src_length: Tensor, mask: Tensor, \
         if self.layer_norm:
             conv_out2 = self.norm2(conv_out2)

-        if self.activation == "tanh":
-            lila_out4 = torch.tanh(self.lila4(conv_out2))
+        if not self.same_weights:
+            if self.activation == "tanh":
+                lila_out4 = torch.tanh(self.lila4(conv_out2))
+            else:
+                lila_out4 = torch.relu(self.lila4(conv_out2))
         else:
-            lila_out4 = torch.relu(self.lila4(conv_out2))
+            if self.activation == "tanh":
+                lila_out4 = torch.tanh(self.lila2(conv_out2))
+            else:
+                lila_out4 = torch.relu(self.lila2(conv_out2))

         # apply dropout to the rnn input
         lila_do = self.rnn_input_dropout(lila_out4)
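
The change in a nutshell: with same_weights=True the encoder never creates lila3 and lila4 and instead reuses the existing lila2 projection at both sites, so the extra hidden-to-hidden transformations share one set of parameters with lila2. The snippet below is a minimal, self-contained sketch of that weight-sharing pattern, not the encoder itself; the class name SharedProjectionDemo and the toy hidden size are made up for illustration.

import torch
import torch.nn as nn

class SharedProjectionDemo(nn.Module):
    """Toy module showing the same_weights idea from the diff above."""

    def __init__(self, hidden_size: int = 8, same_weights: bool = False):
        super().__init__()
        self.lila2 = nn.Linear(hidden_size, hidden_size)
        self.same_weights = same_weights
        if not self.same_weights:
            # separate parameters, as in the original encoder
            self.lila3 = nn.Linear(hidden_size, hidden_size)
            self.lila4 = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # choose the layer for each projection site depending on the flag
        proj3 = self.lila2 if self.same_weights else self.lila3
        proj4 = self.lila2 if self.same_weights else self.lila4
        return torch.relu(proj4(torch.relu(proj3(x))))

# the shared variant holds one hidden-to-hidden layer, the separate variant three
shared = SharedProjectionDemo(same_weights=True)
separate = SharedProjectionDemo(same_weights=False)
print(sum(p.numel() for p in shared.parameters()))    # 72  = one 8x8 weight + bias
print(sum(p.numel() for p in separate.parameters()))  # 216 = three 8x8 weights + biases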