
Commit 30cd283

extended to two weight options
1 parent 786695a commit 30cd283

2 files changed: +29, -11 lines changed

configs/testme.yaml

Lines changed: 4 additions & 3 deletions
@@ -63,15 +63,16 @@ model: # specify your model architecture here
             embedding_dim: 20    # size of embeddings
             scale: False         # scale the embeddings by sqrt of their size, default: False
             freeze: True         # if True, embeddings are not updated during training
-        hidden_size: 50          # size of RNN
+        hidden_size: 256         # size of RNN
         bidirectional: True      # use a bi-directional encoder, default: True
         dropout: 0.2             # apply dropout to the inputs to the RNN, default: 0.0
-        num_layers: 1            # stack this many layers of equal size, default: 1
+        num_layers: 3            # stack this many layers of equal size, default: 1
         freeze: False            # if True, encoder parameters are not updated during training (does not include embedding parameters)
         activation: "tanh"       # activation type for 2 layers following the src embeddings (only for speech), default: "relu", other options: "tanh"
         last_activation: "relu"  # non-linear activation after RNNs in speech encoder, default: "None", other options: "tanh", "relu"
-        layer_norm: False        # layer normalization layers for 2 CNNs and RNN layer, default: False
+        layer_norm: True         # layer normalization layers for 2 CNNs and RNN layer, default: False
         emb_norm: False          # layer normalization layers for embeddings, default: False
+        same_weights: True       # use same weights for linear layers, default: False
     decoder:
         rnn_type: "gru"
         embeddings:
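For reference, a minimal sketch of how the new same_weights key would be read from this config and handed on to the encoder. The file path matches this commit, but the model -> encoder nesting is an assumption based on the hunk header and surrounding keys, and this is plain PyYAML rather than JoeyNMT's own config loader; extra keys such as same_weights can reach the constructor through its **kwargs.

```python
# Minimal sketch (assumed config nesting), not JoeyNMT's loader.
import yaml  # PyYAML

with open("configs/testme.yaml", "r", encoding="utf-8") as f:  # path as in this commit
    cfg = yaml.safe_load(f)

enc_cfg = cfg["model"]["encoder"]                  # assumed nesting: model -> encoder
print(enc_cfg.get("same_weights", False))          # True with the values committed here
print(enc_cfg.get("hidden_size"), enc_cfg.get("num_layers"))  # 256 3
```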

joeynmt/encoders.py

Lines changed: 25 additions & 8 deletions
@@ -160,6 +160,7 @@ def __init__(self,
                  last_activation: str = "None",
                  layer_norm: bool = False,
                  emb_norm: bool = False,
+                 same_weights: bool = False,
                  **kwargs) -> None:
         """
         Create a new recurrent encoder.
@@ -181,8 +182,10 @@ def __init__(self,
         self.emb_size = emb_size
         self.lila1 = nn.Linear(emb_size, hidden_size)
         self.lila2 = nn.Linear(hidden_size, hidden_size)
-        self.lila3 = nn.Linear(hidden_size, hidden_size)
-        self.lila4 = nn.Linear(hidden_size, hidden_size)
+        self.same_weights = same_weights
+        if not self.same_weights:
+            self.lila3 = nn.Linear(hidden_size, hidden_size)
+            self.lila4 = nn.Linear(hidden_size, hidden_size)
         self.activation = activation
         self.last_activation = last_activation
         self.conv1 = nn.Sequential(
@@ -201,6 +204,7 @@ def __init__(self,
         self.norm_out = nn.LayerNorm(2 * hidden_size if bidirectional else hidden_size)
         if self.emb_norm:
             self.norm_emb = nn.LayerNorm(emb_size)
+        self.same_weights = same_weights

         rnn = nn.GRU if rnn_type == "gru" else nn.LSTM

@@ -273,10 +277,17 @@ def forward(self, embed_src: Tensor, src_length: Tensor, mask: Tensor, \
         if self.layer_norm:
             conv_out1 = self.norm1(conv_out1)

-        if self.activation == "tanh":
-            lila_out3 = torch.tanh(self.lila3(conv_out1))
+        if not self.same_weights:
+            if self.activation == "tanh":
+                lila_out3 = torch.tanh(self.lila3(conv_out1))
+            else:
+                lila_out3 = torch.relu(self.lila3(conv_out1))
         else:
-            lila_out3 = torch.relu(self.lila3(conv_out1))
+            if self.activation == "tanh":
+                lila_out3 = torch.tanh(self.lila2(conv_out1))
+            else:
+                lila_out3 = torch.relu(self.lila2(conv_out1))
+
         lila_out3 = lila_out3.transpose(1,2)

         conv_out2 = self.conv2(lila_out3)
@@ -286,10 +297,16 @@ def forward(self, embed_src: Tensor, src_length: Tensor, mask: Tensor, \
         if self.layer_norm:
             conv_out2 = self.norm2(conv_out2)

-        if self.activation == "tanh":
-            lila_out4 = torch.tanh(self.lila4(conv_out2))
+        if not self.same_weights:
+            if self.activation == "tanh":
+                lila_out4 = torch.tanh(self.lila4(conv_out1))
+            else:
+                lila_out4 = torch.relu(self.lila4(conv_out1))
         else:
-            lila_out4 = torch.relu(self.lila4(conv_out2))
+            if self.activation == "tanh":
+                lila_out4 = torch.tanh(self.lila2(conv_out1))
+            else:
+                lila_out4 = torch.relu(self.lila2(conv_out1))

         # apply dropout to the rnn input
         lila_do = self.rnn_input_dropout(lila_out4)
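Condensed, the branching above amounts to the two weight options named in the commit message: dedicated lila3/lila4 projections when same_weights is False, or reusing lila2 for both projections when it is True. The standalone sketch below illustrates only that choice; it is not the encoder class changed here, the class and helper names other than lila2/lila3/lila4 and same_weights are made up, and it simply chains the two projections for brevity.

```python
import torch
import torch.nn as nn


class WeightOptionDemo(nn.Module):
    """Toy module showing the two weight options: separate layers vs. shared lila2."""

    def __init__(self, hidden_size: int = 256, same_weights: bool = False,
                 activation: str = "tanh"):
        super().__init__()
        self.same_weights = same_weights
        self.activation = activation
        self.lila2 = nn.Linear(hidden_size, hidden_size)
        if not self.same_weights:
            # extra projections, only allocated when weights are not shared
            self.lila3 = nn.Linear(hidden_size, hidden_size)
            self.lila4 = nn.Linear(hidden_size, hidden_size)

    def _project(self, layer: nn.Linear, x: torch.Tensor) -> torch.Tensor:
        out = layer(x)
        return torch.tanh(out) if self.activation == "tanh" else torch.relu(out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pick lila3/lila4 or fall back to lila2, mirroring the committed branches
        out3 = self._project(self.lila2 if self.same_weights else self.lila3, x)
        out4 = self._project(self.lila2 if self.same_weights else self.lila4, out3)
        return out4


x = torch.randn(2, 10, 256)
print(WeightOptionDemo(same_weights=True)(x).shape)   # torch.Size([2, 10, 256])
print(WeightOptionDemo(same_weights=False)(x).shape)  # torch.Size([2, 10, 256])
```

With same_weights=True the module holds a single hidden-to-hidden weight matrix instead of three, which is the parameter saving the flag trades against having independently trained projections.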
