diff --git a/d2l/torch.py b/d2l/torch.py
index 77e5f3577..79c1681c8 100644
--- a/d2l/torch.py
+++ b/d2l/torch.py
@@ -2276,17 +2276,6 @@ def forward(self, X, pred_positions):
         mlm_Y_hat = self.mlp(masked_X)
         return mlm_Y_hat
 
-class NextSentencePred(nn.Module):
-    """The next sentence prediction task of BERT
-
-    Defined in :numref:`subsec_mlm`"""
-    def __init__(self, num_inputs, **kwargs):
-        super(NextSentencePred, self).__init__(**kwargs)
-        self.output = nn.Linear(num_inputs, 2)
-
-    def forward(self, X):
-        # Shape of X: (batch_size, num_hiddens)
-        return self.output(X)
 
 class BERTModel(nn.Module):
     """The BERT model
@@ -2295,17 +2284,16 @@ class BERTModel(nn.Module):
     def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                  ffn_num_hiddens, num_heads, num_layers, dropout,
                  max_len=1000, key_size=768, query_size=768, value_size=768,
-                 hid_in_features=768, mlm_in_features=768,
-                 nsp_in_features=768):
+                 hid_in_features=768, mlm_in_features=768):
         super(BERTModel, self).__init__()
         self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,
                     ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                     dropout, max_len=max_len, key_size=key_size,
                     query_size=query_size, value_size=value_size)
         self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
-                                    nn.Tanh())
+                                    nn.Tanh(),
+                                    nn.Linear(num_hiddens, 2))
         self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
-        self.nsp = NextSentencePred(nsp_in_features)
 
     def forward(self, tokens, segments, valid_lens=None,
                 pred_positions=None):
@@ -2315,7 +2303,7 @@ def forward(self, tokens, segments, valid_lens=None,
         else:
             mlm_Y_hat = None
-        # Hidden layer of the MLP classifier for next sentence prediction; 0 is the index of the '<cls>' token
-        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
+        # MLP classifier for next sentence prediction; 0 is the index of the '<cls>' token
+        nsp_Y_hat = self.hidden(encoded_X[:, 0, :])
         return encoded_X, mlm_Y_hat, nsp_Y_hat
 
 d2l.DATA_HUB['wikitext-2'] = (
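
With this change the standalone NextSentencePred head is gone: self.hidden now ends in nn.Linear(num_hiddens, 2) and emits the binary next-sentence-prediction logits directly, so BERTModel.forward returns the same (encoded_X, mlm_Y_hat, nsp_Y_hat) triple as before. Below is a minimal smoke test of the refactored model, assuming the patch has been applied to d2l/torch.py; the hyperparameter values and toy inputs are illustrative, not part of this patch:

    # Smoke test for the refactored BERTModel (illustrative sizes only).
    import torch
    from d2l import torch as d2l

    vocab_size = 10000
    net = d2l.BERTModel(vocab_size, num_hiddens=128, norm_shape=[128],
                        ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,
                        num_layers=2, dropout=0.2, key_size=128, query_size=128,
                        value_size=128, hid_in_features=128, mlm_in_features=128)

    tokens = torch.randint(0, vocab_size, (2, 8))            # (batch_size, seq_len)
    segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1]] * 2)  # sentence A/B segment ids
    valid_lens = torch.tensor([8, 8])                        # no padding in this toy batch
    pred_positions = torch.tensor([[1, 5], [2, 6]])          # masked token positions

    encoded_X, mlm_Y_hat, nsp_Y_hat = net(tokens, segments, valid_lens,
                                          pred_positions)
    print(encoded_X.shape)  # torch.Size([2, 8, 128])
    print(mlm_Y_hat.shape)  # torch.Size([2, 2, 10000])
    print(nsp_Y_hat.shape)  # torch.Size([2, 2]); NSP logits now come from self.hidden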