from embedding import *
from collections import OrderedDict
import torch
import torch.nn as nn  # used directly below for Module, Sequential, Linear, etc.


class RelationMetaLearner(nn.Module):
    """Learns the relation meta from the (head, tail) entity-pair embeddings of a support set."""

    def __init__(self, few, embed_size=100, num_hidden1=500, num_hidden2=200, out_size=100, dropout_p=0.5):
        super(RelationMetaLearner, self).__init__()
        self.embed_size = embed_size
        self.few = few
        self.out_size = out_size
        self.rel_fc1 = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(2*embed_size, num_hidden1)),
            ('bn', nn.BatchNorm1d(few)),
            ('relu', nn.LeakyReLU()),
            ('drop', nn.Dropout(p=dropout_p)),
        ]))
        self.rel_fc2 = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(num_hidden1, num_hidden2)),
            ('bn', nn.BatchNorm1d(few)),
            ('relu', nn.LeakyReLU()),
            ('drop', nn.Dropout(p=dropout_p)),
        ]))
        self.rel_fc3 = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(num_hidden2, out_size)),
            ('bn', nn.BatchNorm1d(few)),
        ]))
        nn.init.xavier_normal_(self.rel_fc1.fc.weight)
        nn.init.xavier_normal_(self.rel_fc2.fc.weight)
        nn.init.xavier_normal_(self.rel_fc3.fc.weight)

    def forward(self, inputs):
        size = inputs.shape
        x = inputs.contiguous().view(size[0], size[1], -1)  # [batch, few, 2*embed_size]
        x = self.rel_fc1(x)
        x = self.rel_fc2(x)
        x = self.rel_fc3(x)
        x = torch.mean(x, 1)  # average the per-example relation metas over the support set
        return x.view(size[0], 1, 1, self.out_size)
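
# Shape sketch (added note; shapes inferred from MetaR.forward below): RelationMetaLearner
# takes support entity pairs of shape [batch, few, 2, embed_size], flattens each pair to
# 2*embed_size, passes them through three fc blocks that keep the `few` axis (hence
# BatchNorm1d(few)), averages over the support examples, and returns the relation meta
# with shape [batch, 1, 1, out_size].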


class EmbeddingLearner(nn.Module):
    def __init__(self):
        super(EmbeddingLearner, self).__init__()

    def forward(self, h, t, r, pos_num):
        score = -torch.norm(h + r - t, 2, -1).squeeze(2)
        p_score = score[:, :pos_num]
        n_score = score[:, pos_num:]
        return p_score, n_score
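
# Added note: EmbeddingLearner computes a TransE-style translation score,
# score(h, r, t) = -||h + r - t||_2, over the concatenated positive/negative pairs;
# the first `pos_num` columns of `score` are the positives and the rest are negatives.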


class MetaR(nn.Module):
    def __init__(self, dataset, parameter):
        super(MetaR, self).__init__()
        self.device = parameter['device']
        self.beta = parameter['beta']
        self.dropout_p = parameter['dropout_p']
        self.embed_dim = parameter['embed_dim']
        self.margin = parameter['margin']
        self.abla = parameter['ablation']
        self.embedding = Embedding(dataset, parameter)

        # sizes match each dataset's embedding dimension (50 for Wiki-One, 100 for NELL-One)
        if parameter['dataset'] == 'Wiki-One':
            self.relation_learner = RelationMetaLearner(parameter['few'], embed_size=50, num_hidden1=250,
                                                        num_hidden2=100, out_size=50, dropout_p=self.dropout_p)
        elif parameter['dataset'] == 'NELL-One':
            self.relation_learner = RelationMetaLearner(parameter['few'], embed_size=100, num_hidden1=500,
                                                        num_hidden2=200, out_size=100, dropout_p=self.dropout_p)
        self.embedding_learner = EmbeddingLearner()
        self.loss_func = nn.MarginRankingLoss(self.margin)
        self.rel_q_sharing = dict()

    def split_concat(self, positive, negative):
        # split on e1/e2 and concat on pos/neg; both outputs are [bs, n_pos + n_neg, 1, emb]
        pos_neg_e1 = torch.cat([positive[:, :, 0, :],
                                negative[:, :, 0, :]], 1).unsqueeze(2)
        pos_neg_e2 = torch.cat([positive[:, :, 1, :],
                                negative[:, :, 1, :]], 1).unsqueeze(2)
        return pos_neg_e1, pos_neg_e2
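
    # Forward-pass sketch (added note): the support pairs produce a relation meta via
    # RelationMetaLearner; one margin-ranking step on the support set gives the gradient
    # of that meta, and rel_q = rel - beta * grad is the refined relation used to score
    # the query pairs. A hedged runnable sketch of these steps sits at the bottom of the file.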

    def forward(self, task, iseval=False, curr_rel=''):
        # convert the task triples into entity-pair embeddings
        support, support_negative, query, negative = [self.embedding(t) for t in task]

        few = support.shape[1]              # num of support triples
        num_sn = support_negative.shape[1]  # num of support negatives
        num_q = query.shape[1]              # num of query triples
        num_n = negative.shape[1]           # num of query negatives

        rel = self.relation_learner(support)
        rel.retain_grad()

        # relation meta broadcast over the support positives and negatives
        rel_s = rel.expand(-1, few + num_sn, -1, -1)

        # in the dev and test steps the same relation reuses the same support set,
        # so the relation-meta learning step does not need to be repeated
        if iseval and curr_rel != '' and curr_rel in self.rel_q_sharing.keys():
            rel_q = self.rel_q_sharing[curr_rel]
        else:
            if not self.abla:
                # split on e1/e2 and concat on pos/neg
                sup_neg_e1, sup_neg_e2 = self.split_concat(support, support_negative)
                p_score, n_score = self.embedding_learner(sup_neg_e1, sup_neg_e2, rel_s, few)

                y = torch.Tensor([1]).to(self.device)
                self.zero_grad()
                loss = self.loss_func(p_score, n_score, y)
                loss.backward(retain_graph=True)

                grad_meta = rel.grad
                # one gradient step on the support loss refines the relation meta
                rel_q = rel - self.beta * grad_meta
            else:
                rel_q = rel

            self.rel_q_sharing[curr_rel] = rel_q

        rel_q = rel_q.expand(-1, num_q + num_n, -1, -1)

        que_neg_e1, que_neg_e2 = self.split_concat(query, negative)  # [bs, nq+nn, 1, es]
        p_score, n_score = self.embedding_learner(que_neg_e1, que_neg_e2, rel_q, num_q)

        return p_score, n_score
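

# --- Added usage sketch (not part of the original file) ---
# A minimal smoke test of the pieces that do not need a dataset, assuming random tensors
# with the shapes used in MetaR.forward: support pairs are [batch, few, 2, emb] and the
# refined relation meta is rel - beta * grad, mirroring the gradient step above. All
# sizes here are illustrative assumptions.
if __name__ == "__main__":
    batch, few, num_neg, emb, beta = 2, 3, 3, 100, 1.0
    learner = RelationMetaLearner(few, embed_size=emb, num_hidden1=500, num_hidden2=200, out_size=emb)
    scorer = EmbeddingLearner()

    support = torch.randn(batch, few, 2, emb)            # (head, tail) support pairs
    support_neg = torch.randn(batch, num_neg, 2, emb)    # corrupted support pairs

    rel = learner(support)                               # [batch, 1, 1, emb]
    rel.retain_grad()
    rel_s = rel.expand(-1, few + num_neg, -1, -1)

    pairs = torch.cat([support, support_neg], 1)
    h = pairs[:, :, 0, :].unsqueeze(2)                   # [batch, few + num_neg, 1, emb]
    t = pairs[:, :, 1, :].unsqueeze(2)
    p_score, n_score = scorer(h, t, rel_s, few)

    loss = nn.MarginRankingLoss(1.0)(p_score, n_score, torch.Tensor([1]))
    loss.backward(retain_graph=True)
    rel_q = rel - beta * rel.grad                        # refined relation meta
    print(p_score.shape, n_score.shape, rel_q.shape)     # torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 1, 1, 100])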