From 4680bb764edab459348fb851e0d62fe44f737b1a Mon Sep 17 00:00:00 2001
From: Glycogen W <109408857+XihWang@users.noreply.github.com>
Date: Thu, 2 Jan 2025 14:58:32 +0800
Subject: [PATCH] Update transformer.md
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The new version of the d2l code has switched to LazyLinear. This kind of
linear layer does not take the number of input channels; for details, see the
d2l source code at
https://github.com/d2l-ai/d2l-en/blob/master/d2l/torch.py#L1149
Without this change, the code raises an error about too many arguments.
A minimal sketch of the LazyLinear behavior is included after the diff below.
---
 chapter_attention-mechanisms/transformer.md | 83 ++++++++++-----------
 1 file changed, 40 insertions(+), 43 deletions(-)

diff --git a/chapter_attention-mechanisms/transformer.md b/chapter_attention-mechanisms/transformer.md
index 9c0411669..3d31a73c1 100644
--- a/chapter_attention-mechanisms/transformer.md
+++ b/chapter_attention-mechanisms/transformer.md
@@ -297,12 +297,12 @@ class EncoderBlock(nn.Block):
 #@save
 class EncoderBlock(nn.Module):
     """Transformer编码器块"""
-    def __init__(self, key_size, query_size, value_size, num_hiddens,
+    def __init__(self, num_hiddens,
                  norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                  dropout, use_bias=False, **kwargs):
         super(EncoderBlock, self).__init__(**kwargs)
         self.attention = d2l.MultiHeadAttention(
-            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
+            num_hiddens, num_heads, dropout,
             use_bias)
         self.addnorm1 = AddNorm(norm_shape, dropout)
         self.ffn = PositionWiseFFN(
@@ -319,10 +319,10 @@ class EncoderBlock(nn.Module):
 #@save
 class EncoderBlock(tf.keras.layers.Layer):
     """Transformer编码器块"""
-    def __init__(self, key_size, query_size, value_size, num_hiddens,
+    def __init__(self, num_hiddens,
                  norm_shape, ffn_num_hiddens, num_heads, dropout, bias=False, **kwargs):
         super().__init__(**kwargs)
-        self.attention = d2l.MultiHeadAttention(key_size, query_size, value_size, num_hiddens,
+        self.attention = d2l.MultiHeadAttention(num_hiddens,
                                                 num_heads, dropout, bias)
         self.addnorm1 = AddNorm(norm_shape, dropout)
         self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
@@ -338,12 +338,12 @@ class EncoderBlock(tf.keras.layers.Layer):
 #@save
 class EncoderBlock(nn.Layer):
     """transformer编码器块"""
-    def __init__(self, key_size, query_size, value_size, num_hiddens,
+    def __init__(self, num_hiddens,
                  norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                  dropout, use_bias=False, **kwargs):
         super(EncoderBlock, self).__init__(**kwargs)
         self.attention = d2l.MultiHeadAttention(
-            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
+            num_hiddens, num_heads, dropout,
             use_bias)
         self.addnorm1 = AddNorm(norm_shape, dropout)
         self.ffn = PositionWiseFFN(
@@ -369,7 +369,7 @@ encoder_blk(X, valid_lens).shape
 #@tab pytorch, paddle
 X = d2l.ones((2, 100, 24))
 valid_lens = d2l.tensor([3, 2])
-encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
+encoder_blk = EncoderBlock(24, [100, 24], 24, 48, 8, 0.5)
 encoder_blk.eval()
 encoder_blk(X, valid_lens).shape
 ```
@@ -379,7 +379,7 @@ encoder_blk(X, valid_lens).shape
 X = tf.ones((2, 100, 24))
 valid_lens = tf.constant([3, 2])
 norm_shape = [i for i in range(len(X.shape))][1:]
-encoder_blk = EncoderBlock(24, 24, 24, 24, norm_shape, 48, 8, 0.5)
+encoder_blk = EncoderBlock(24, norm_shape, 48, 8, 0.5)
 encoder_blk(X, valid_lens, training=False).shape
 ```
 
@@ -419,7 +419,7 @@ class TransformerEncoder(d2l.Encoder):
 #@save
 class TransformerEncoder(d2l.Encoder):
     """Transformer编码器"""
-    def __init__(self, vocab_size, key_size, query_size, value_size,
+    def __init__(self, vocab_size,
                  num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                  num_heads, num_layers, dropout, use_bias=False, **kwargs):
         super(TransformerEncoder, self).__init__(**kwargs)
@@ -429,7 +429,7 @@ class TransformerEncoder(d2l.Encoder):
         self.blks = nn.Sequential()
         for i in range(num_layers):
             self.blks.add_module("block"+str(i),
-                EncoderBlock(key_size, query_size, value_size, num_hiddens,
+                EncoderBlock(num_hiddens,
                              norm_shape, ffn_num_input, ffn_num_hiddens,
                              num_heads, dropout, use_bias))
 
@@ -451,7 +451,7 @@ class TransformerEncoder(d2l.Encoder):
 #@save
 class TransformerEncoder(d2l.Encoder):
     """Transformer编码器"""
-    def __init__(self, vocab_size, key_size, query_size, value_size,
+    def __init__(self, vocab_size,
                  num_hiddens, norm_shape, ffn_num_hiddens, num_heads,
                  num_layers, dropout, bias=False, **kwargs):
         super().__init__(**kwargs)
@@ -459,7 +459,7 @@ class TransformerEncoder(d2l.Encoder):
         self.embedding = tf.keras.layers.Embedding(vocab_size, num_hiddens)
         self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
         self.blks = [EncoderBlock(
-            key_size, query_size, value_size, num_hiddens, norm_shape,
+            num_hiddens, norm_shape,
             ffn_num_hiddens, num_heads, dropout, bias) for _ in range(
             num_layers)]
 
@@ -482,7 +482,7 @@ class TransformerEncoder(d2l.Encoder):
 #@save
 class TransformerEncoder(d2l.Encoder):
     """transformer编码器"""
-    def __init__(self, vocab_size, key_size, query_size, value_size,
+    def __init__(self, vocab_size,
                  num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                  num_heads, num_layers, dropout, use_bias=False, **kwargs):
         super(TransformerEncoder, self).__init__(**kwargs)
@@ -492,7 +492,7 @@ class TransformerEncoder(d2l.Encoder):
         self.blks = nn.Sequential()
         for i in range(num_layers):
             self.blks.add_sublayer(str(i),
-                EncoderBlock(key_size, query_size, value_size, num_hiddens,
+                EncoderBlock(num_hiddens,
                              norm_shape, ffn_num_input, ffn_num_hiddens,
                              num_heads, dropout, use_bias))
 
@@ -521,21 +521,21 @@ encoder(np.ones((2, 100)), valid_lens).shape
 ```{.python .input}
 #@tab pytorch
 encoder = TransformerEncoder(
-    200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
+    200, 24, [100, 24], 24, 48, 8, 2, 0.5)
 encoder.eval()
 encoder(d2l.ones((2, 100), dtype=torch.long), valid_lens).shape
 ```
 
 ```{.python .input}
 #@tab tensorflow
-encoder = TransformerEncoder(200, 24, 24, 24, 24, [1, 2], 48, 8, 2, 0.5)
+encoder = TransformerEncoder(200, 24, [1, 2], 48, 8, 2, 0.5)
 encoder(tf.ones((2, 100)), valid_lens, training=False).shape
 ```
 
 ```{.python .input}
 #@tab paddle
 encoder = TransformerEncoder(
-    200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
+    200, 24, [100, 24], 24, 48, 8, 2, 0.5)
 encoder.eval()
 encoder(d2l.ones((2, 100), dtype=paddle.int64), valid_lens).shape
 ```
@@ -597,16 +597,16 @@ class DecoderBlock(nn.Block):
 #@tab pytorch
 class DecoderBlock(nn.Module):
     """解码器中第i个块"""
-    def __init__(self, key_size, query_size, value_size, num_hiddens,
+    def __init__(self, num_hiddens,
                  norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                  dropout, i, **kwargs):
         super(DecoderBlock, self).__init__(**kwargs)
         self.i = i
         self.attention1 = d2l.MultiHeadAttention(
-            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
+            num_hiddens, num_heads, dropout)
         self.addnorm1 = AddNorm(norm_shape, dropout)
         self.attention2 = d2l.MultiHeadAttention(
-            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
+            num_hiddens, num_heads, dropout)
         self.addnorm2 = AddNorm(norm_shape, dropout)
         self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                    num_hiddens)
@@ -646,13 +646,13 @@ class DecoderBlock(nn.Module):
 #@tab tensorflow
 class DecoderBlock(tf.keras.layers.Layer):
     """解码器中第i个块"""
-    def __init__(self, key_size, query_size, value_size, num_hiddens,
+    def __init__(self, num_hiddens,
                  norm_shape, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
         super().__init__(**kwargs)
         self.i = i
-        self.attention1 = d2l.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
+        self.attention1 = d2l.MultiHeadAttention(num_hiddens, num_heads, dropout)
         self.addnorm1 = AddNorm(norm_shape, dropout)
-        self.attention2 = d2l.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
+        self.attention2 = d2l.MultiHeadAttention(num_hiddens, num_heads, dropout)
         self.addnorm2 = AddNorm(norm_shape, dropout)
         self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
         self.addnorm3 = AddNorm(norm_shape, dropout)
@@ -692,16 +692,16 @@ class DecoderBlock(tf.keras.layers.Layer):
 #@tab paddle
 class DecoderBlock(nn.Layer):
     """解码器中第i个块"""
-    def __init__(self, key_size, query_size, value_size, num_hiddens,
+    def __init__(self, num_hiddens,
                  norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                  dropout, i, **kwargs):
         super(DecoderBlock, self).__init__(**kwargs)
         self.i = i
         self.attention1 = d2l.MultiHeadAttention(
-            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
+            num_hiddens, num_heads, dropout)
         self.addnorm1 = AddNorm(norm_shape, dropout)
         self.attention2 = d2l.MultiHeadAttention(
-            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
+            num_hiddens, num_heads, dropout)
         self.addnorm2 = AddNorm(norm_shape, dropout)
         self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                    num_hiddens)
@@ -749,7 +749,7 @@ decoder_blk(X, state)[0].shape
 ```{.python .input}
 #@tab pytorch, paddle
-decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
+decoder_blk = DecoderBlock(24, [100, 24], 24, 48, 8, 0.5, 0)
 decoder_blk.eval()
 X = d2l.ones((2, 100, 24))
 state = [encoder_blk(X, valid_lens), valid_lens, [None]]
 decoder_blk(X, state)[0].shape
 ```
@@ -758,7 +758,7 @@
 ```{.python .input}
 #@tab tensorflow
-decoder_blk = DecoderBlock(24, 24, 24, 24, [1, 2], 48, 8, 0.5, 0)
+decoder_blk = DecoderBlock(24, [1, 2], 48, 8, 0.5, 0)
 X = tf.ones((2, 100, 24))
 state = [encoder_blk(X, valid_lens), valid_lens, [None]]
 decoder_blk(X, state, training=False)[0].shape
@@ -806,7 +806,7 @@ class TransformerDecoder(d2l.AttentionDecoder):
 ```{.python .input}
 #@tab pytorch
 class TransformerDecoder(d2l.AttentionDecoder):
-    def __init__(self, vocab_size, key_size, query_size, value_size,
+    def __init__(self, vocab_size,
                  num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                  num_heads, num_layers, dropout, **kwargs):
         super(TransformerDecoder, self).__init__(**kwargs)
@@ -817,7 +817,7 @@ class TransformerDecoder(d2l.AttentionDecoder):
         self.blks = nn.Sequential()
         for i in range(num_layers):
             self.blks.add_module("block"+str(i),
-                DecoderBlock(key_size, query_size, value_size, num_hiddens,
+                DecoderBlock(num_hiddens,
                              norm_shape, ffn_num_input, ffn_num_hiddens,
                              num_heads, dropout, i))
         self.dense = nn.Linear(num_hiddens, vocab_size)
@@ -846,14 +846,14 @@ class TransformerDecoder(d2l.AttentionDecoder):
 ```{.python .input}
 #@tab tensorflow
 class TransformerDecoder(d2l.AttentionDecoder):
-    def __init__(self, vocab_size, key_size, query_size, value_size,
+    def __init__(self, vocab_size,
                  num_hiddens, norm_shape, ffn_num_hidens, num_heads, num_layers, dropout, **kwargs):
         super().__init__(**kwargs)
         self.num_hiddens = num_hiddens
         self.num_layers = num_layers
         self.embedding = tf.keras.layers.Embedding(vocab_size, num_hiddens)
         self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
-        self.blks = [DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
+        self.blks = [DecoderBlock(num_hiddens, norm_shape,
                                   ffn_num_hiddens, num_heads, dropout, i)
                      for i in range(num_layers)]
         self.dense = tf.keras.layers.Dense(vocab_size)
@@ -879,7 +879,7 @@ class TransformerDecoder(d2l.AttentionDecoder):
 ```{.python .input}
 #@tab paddle
 class TransformerDecoder(d2l.AttentionDecoder):
-    def __init__(self, vocab_size, key_size, query_size, value_size,
+    def __init__(self, vocab_size,
                  num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                  num_heads, num_layers, dropout, **kwargs):
         super(TransformerDecoder, self).__init__(**kwargs)
@@ -890,7 +890,7 @@ class TransformerDecoder(d2l.AttentionDecoder):
         self.blks = nn.Sequential()
         for i in range(num_layers):
             self.blks.add_sublayer(str(i),
-                DecoderBlock(key_size, query_size, value_size, num_hiddens,
+                DecoderBlock(num_hiddens,
                              norm_shape, ffn_num_input, ffn_num_hiddens,
                              num_heads, dropout, i))
         self.dense = nn.Linear(num_hiddens, vocab_size)
@@ -942,17 +942,16 @@ d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
 num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
 lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
 ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
-key_size, query_size, value_size = 32, 32, 32
 norm_shape = [32]
 
 train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
 
 encoder = TransformerEncoder(
-    len(src_vocab), key_size, query_size, value_size, num_hiddens,
+    len(src_vocab), num_hiddens,
     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
     num_layers, dropout)
 decoder = TransformerDecoder(
-    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
+    len(tgt_vocab), num_hiddens,
     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
     num_layers, dropout)
 net = d2l.EncoderDecoder(encoder, decoder)
@@ -964,15 +963,14 @@ d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
 num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
 lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
 ffn_num_hiddens, num_heads = 64, 4
-key_size, query_size, value_size = 32, 32, 32
 norm_shape = [2]
 
 train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
 encoder = TransformerEncoder(
-    len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
+    len(src_vocab), num_hiddens, norm_shape,
     ffn_num_hiddens, num_heads, num_layers, dropout)
 decoder = TransformerDecoder(
-    len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
+    len(tgt_vocab), num_hiddens, norm_shape,
     ffn_num_hiddens, num_heads, num_layers, dropout)
 net = d2l.EncoderDecoder(encoder, decoder)
 d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
@@ -983,17 +981,16 @@
 num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
 lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
 ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
-key_size, query_size, value_size = 32, 32, 32
 norm_shape = [32]
 
 train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
 
 encoder = TransformerEncoder(
-    len(src_vocab), key_size, query_size, value_size, num_hiddens,
+    len(src_vocab), num_hiddens,
     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
     num_layers, dropout)
 decoder = TransformerDecoder(
-    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
+    len(tgt_vocab), num_hiddens,
     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
     num_layers, dropout)
 net = d2l.EncoderDecoder(encoder, decoder)
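
A minimal sketch of the LazyLinear behavior this patch relies on (assuming
PyTorch >= 1.8, where torch.nn.LazyLinear is available; the name W_q and the
tensor shapes below are illustrative only, not taken from the chapter):

    import torch
    from torch import nn

    # LazyLinear takes only the output dimension; in_features is inferred from
    # the first tensor that flows through the layer, which is why the updated
    # d2l attention layers no longer need key_size/query_size/value_size.
    W_q = nn.LazyLinear(24, bias=False)

    queries = torch.ones(2, 100, 48)   # input feature size 48 is never passed to the layer
    print(W_q(queries).shape)          # torch.Size([2, 100, 24])
    print(W_q.weight.shape)            # torch.Size([24, 48]), materialized after the first call

The same shape inference is what lets the constructors in the diff above drop
the key_size, query_size, and value_size arguments.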