-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnn_blocks.py
More file actions
131 lines (105 loc) · 4.93 KB
/
nn_blocks.py
File metadata and controls
131 lines (105 loc) · 4.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import gensim
class DAEncoder(nn.Module):
    """Embed a dialogue-act index and project it into the DA hidden space.

    Maps an integer DA id through an embedding table and a linear layer,
    squashed with tanh: (batch_size, 1) -> (batch_size, 1, da_hidden).
    """

    def __init__(self, da_input_size, da_embed_size, da_hidden):
        super(DAEncoder, self).__init__()
        self.hidden_size = da_hidden
        # id -> dense embedding
        self.xe = nn.Embedding(da_input_size, da_embed_size)
        # embedding -> hidden projection
        self.eh = nn.Linear(da_embed_size, da_hidden)

    def forward(self, DA):
        projected = self.eh(self.xe(DA))
        return torch.tanh(projected)

    def initHidden(self, batch_size, device):
        # Fresh zero state of shape (batch_size, hidden_size).
        return torch.zeros(batch_size, self.hidden_size).to(device)
class DAContextEncoder(nn.Module):
    """Single-layer GRU that tracks dialogue-act context across turns."""

    def __init__(self, da_hidden):
        super(DAContextEncoder, self).__init__()
        self.hidden_size = da_hidden
        # Input and hidden share the same dimensionality.
        self.hh = nn.GRU(da_hidden, da_hidden, batch_first=True)

    def forward(self, input_hidden, prev_hidden):
        # One recurrent step over the incoming turn representation.
        return self.hh(input_hidden, prev_hidden)

    def initHidden(self, batch_size, device):
        # h_0 shape: (num_layers * num_directions, batch_size, hidden_size)
        return torch.zeros(1, batch_size, self.hidden_size).to(device)
class DADecoder(nn.Module):
    """Decode a DA hidden state into unnormalized logits over DA classes.

    hidden (..., da_hidden) -> tanh-projected embedding -> logits of size
    (..., da_input_size). Callers apply softmax/cross-entropy outside.
    """

    def __init__(self, da_input_size, da_embed_size, da_hidden):
        super(DADecoder, self).__init__()
        # hidden -> embedding space
        self.he = nn.Linear(da_hidden, da_embed_size)
        # embedding -> vocabulary-sized logits
        self.ey = nn.Linear(da_embed_size, da_input_size)

    def forward(self, hidden):
        # torch.tanh instead of deprecated F.tanh (and consistent with the
        # rest of this file, which already uses torch.tanh).
        pred = self.ey(torch.tanh(self.he(hidden)))
        return pred
class UtteranceEncoder(nn.Module):
    """Bidirectional GRU sentence encoder over padded token-id batches.

    Tokens are embedded, projected to the hidden size, packed by true
    lengths (computed from ``padding_idx``), run through a 1-layer
    bidirectional GRU, and the output at each sequence's last real
    timestep is returned in the original batch order.

    Args:
        utt_input_size: vocabulary size.
        embed_size: token embedding dimension.
        utterance_hidden: GRU hidden size per direction.
        padding_idx: token id treated as padding when computing lengths.
        fine_tuning: when True the embedding table is frozen.
            NOTE(review): the polarity looks inverted (True -> frozen);
            kept as-is for compatibility — confirm against training code.
    """

    def __init__(self, utt_input_size, embed_size, utterance_hidden, padding_idx, fine_tuning):
        super(UtteranceEncoder, self).__init__()
        self.hidden_size = utterance_hidden
        self.padding_idx = padding_idx
        self.xe = nn.Embedding(utt_input_size, embed_size)
        # Behavior preserved: fine_tuning=True freezes the embeddings.
        self.xe.weight.requires_grad = not fine_tuning
        self.eh = nn.Linear(embed_size, utterance_hidden)
        self.hh = nn.GRU(utterance_hidden, utterance_hidden, num_layers=1,
                         batch_first=True, bidirectional=True)

    def forward(self, X, hidden):
        """Encode X (batch, max_len) given initial hidden (2, batch, hidden).

        Returns:
            output: (batch, 1, 2 * hidden) — last-real-timestep outputs.
            hidden: (2, batch, hidden) — final GRU states, unsorted.
        """
        # True lengths from the padding mask.
        lengths = (X != self.padding_idx).sum(dim=1)
        seq_len, sort_idx = lengths.sort(descending=True)
        _, unsort_idx = sort_idx.sort(descending=False)
        # Embed + project; torch.tanh replaces deprecated F.tanh.
        X = torch.tanh(self.eh(self.xe(X)))
        sorted_X = X[sort_idx]
        # pack_padded_sequence requires lengths on CPU (PyTorch >= 1.2);
        # .cpu() is a no-op when already there.
        packed_tensor = pack_padded_sequence(sorted_X, seq_len.cpu(), batch_first=True)
        output, hidden = self.hh(packed_tensor, hidden)
        # Restore a padded (batch, max_len, 2*hidden) tensor.
        output, _ = pad_packed_sequence(output, batch_first=True)
        # Gather each sequence's output at its last real timestep.
        # NOTE(review): indexes the *sorted* output with *unsorted* lengths;
        # preserved as-is — verify sorted lengths (seq_len) aren't intended.
        idx = (lengths - 1).view(-1, 1).expand(output.size(0), output.size(2)).unsqueeze(1)
        output = output.gather(1, idx)
        # Undo the length sort so rows match the caller's batch order.
        output = output[unsort_idx]
        hidden = hidden[:, unsort_idx]
        return output, hidden

    def initHidden(self, batch_size, device):
        # (num_layers * num_directions=2, batch, hidden) zero state.
        return torch.zeros(2, batch_size, self.hidden_size).to(device)
class UtteranceContextEncoder(nn.Module):
    """Single-layer GRU that accumulates utterance-level context over turns."""

    def __init__(self, utterance_hidden_size):
        super(UtteranceContextEncoder, self).__init__()
        self.hidden_size = utterance_hidden_size
        # Input and hidden dimensions are identical by construction.
        self.hh = nn.GRU(utterance_hidden_size, utterance_hidden_size,
                         batch_first=True)

    def forward(self, input_hidden, prev_hidden):
        # Advance the context state by one dialogue turn.
        return self.hh(input_hidden, prev_hidden)

    def initHidden(self, batch_size, device):
        # h_0 shape: (num_layers * num_directions, batch, hidden)
        return torch.zeros(1, batch_size, self.hidden_size).to(device)
class UtteranceDecoder(nn.Module):
    """GRU token decoder: previous token + hidden -> next-token logits.

    forward returns (logits over vocab, new GRU hidden, raw GRU output).
    Logits are unnormalized; softmax/cross-entropy is applied by callers.
    """

    def __init__(self, utterance_hidden_size, utt_embed_size, utt_vocab_size):
        super(UtteranceDecoder, self).__init__()
        self.hidden_size = utterance_hidden_size
        self.embed_size = utt_embed_size
        self.vocab_size = utt_vocab_size
        self.ye = nn.Embedding(self.vocab_size, self.embed_size)   # token -> embedding
        self.eh = nn.Linear(self.embed_size, self.hidden_size)     # embedding -> GRU input
        self.hh = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True)
        self.he = nn.Linear(self.hidden_size, self.embed_size)     # GRU output -> embedding
        self.ey = nn.Linear(self.embed_size, self.vocab_size)      # embedding -> logits

    def forward(self, Y, hidden):
        """Decode one step.

        Args:
            Y: (batch, 1) previous token ids.
            hidden: (1, batch, hidden) GRU state.
        """
        # torch.tanh replaces deprecated F.tanh — now consistent with the
        # output projection below, which already used torch.tanh.
        h = torch.tanh(self.eh(self.ye(Y)))
        output, hidden = self.hh(h, hidden)
        y_dist = self.ey(torch.tanh(self.he(output.squeeze(1))))
        return y_dist, hidden, output
class BeamNode(object):
    """One hypothesis in a beam search, linked back to its parent node.

    Ordering (``__lt__``) is by length-normalized log-probability so nodes
    can be placed directly in a heap / priority queue.
    """

    def __init__(self, hidden, previousNode, wordId, logProb, length):
        self.hidden = hidden          # decoder state at this step
        self.prevNode = previousNode  # parent hypothesis (None at root)
        self.wordid = wordId          # token chosen at this step
        self.logp = logProb           # cumulative log-probability
        self.length = length          # hypothesis length including this token

    def eval(self, alpha=1.0):
        """Score = logp / (length - 1), plus an alpha-weighted reward term.

        The reward is currently always zero (placeholder); 1e-6 guards
        against division by zero for length-1 hypotheses.
        """
        reward = 0
        denom = float(self.length - 1 + 1e-6)
        return self.logp / denom + alpha * reward

    def __lt__(self, other):
        return self.eval() < other.eval()