-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencoder.py
More file actions
22 lines (18 loc) · 1.12 KB
/
encoder.py
File metadata and controls
22 lines (18 loc) · 1.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn as nn
from generalRnn import BaseCoder
class Encoder(BaseCoder):
    """Recurrent encoder for variable-length token sequences.

    Embeds an integer token batch, runs it through the RNN chosen by
    ``BaseCoder`` (``self.baseModel``, e.g. LSTM/GRU), and returns the
    padded per-step outputs together with the final hidden state.

    Args:
        vocab_size: size of the input vocabulary.
        hidden_size: RNN hidden dimension.
        embedding_size: token embedding dimension.
        input_dropout: dropout probability; forwarded to the RNN's
            inter-layer ``dropout`` (only effective when ``n_layers > 1``).
        output_dropout: forwarded to ``BaseCoder`` (not used directly here).
        n_layers: number of stacked RNN layers.
        bidirectional: whether the RNN runs in both directions.
        rnn: RNN type selector understood by ``BaseCoder`` (e.g. "lstm").
        vocab, embeddings: forwarded to ``BaseCoder``.
    """

    def __init__(self, vocab_size, hidden_size, embedding_size, input_dropout=0.0,
                 output_dropout=0.0, n_layers=1, bidirectional=True, rnn="lstm",
                 vocab=None, embeddings=None):
        super(Encoder, self).__init__(vocab_size, hidden_size, embedding_size,
                                      input_dropout, output_dropout, n_layers,
                                      rnn, vocab, embeddings)
        # TODO: add pretrained embeddings
        # Inter-layer dropout is a no-op for a single-layer RNN and makes
        # PyTorch emit a warning, so only pass it when it can take effect.
        rnn_dropout = input_dropout if n_layers > 1 else 0.0
        self.rnn = self.baseModel(input_size=embedding_size,
                                  hidden_size=hidden_size,
                                  num_layers=n_layers,
                                  batch_first=True,
                                  bidirectional=bidirectional,
                                  dropout=rnn_dropout)
        # Small uniform init for all RNN weights and biases.
        for weight in self.rnn.parameters():
            nn.init.uniform_(weight, -0.1, 0.1)

    def forward(self, input_seq, input_lengths=None):
        """Encode a batch of token-id sequences.

        Args:
            input_seq: LongTensor of shape (batch, max_len) of token ids.
            input_lengths: optional per-sequence lengths; when given, the
                batch is packed so padding steps are skipped by the RNN.

        Returns:
            (output, hidden): padded outputs of shape
            (batch, max_len, hidden * num_directions) and the RNN's final
            hidden state.
        """
        embedded = self.embedding(input_seq)
        # NOTE(review): BaseCoder receives input_dropout but no dropout is
        # applied to `embedded` here — confirm whether
        # `self.input_dropout(embedded)` was meant to be active.
        if input_lengths is None:
            # No lengths supplied: run the RNN directly on the padded batch.
            output, hidden = self.rnn(embedded)
            return output, hidden
        # pack_padded_sequence requires an int64 CPU tensor of lengths;
        # enforce_sorted=False lets callers pass unsorted batches (PyTorch
        # permutes internally and restores the original order on output).
        lengths = torch.as_tensor(input_lengths, dtype=torch.int64).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths,
                                                   batch_first=True,
                                                   enforce_sorted=False)
        output, hidden = self.rnn(packed)
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        return output, hidden