transformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class CausalAttention(nn.Module):
    def __init__(self, n_head, dim, max_seqlen=64):
        super(CausalAttention, self).__init__()
        assert dim % n_head == 0, f'{dim=} must be an integer multiple of {n_head=}'
        self.n_head = n_head
        self.dim = dim
        self.max_seqlen = max_seqlen
        self.Wq = nn.Linear(self.dim, self.dim)
        self.Wk = nn.Linear(self.dim, self.dim)
        self.Wv = nn.Linear(self.dim, self.dim)
        self.Wo = nn.Linear(self.dim, self.dim)
        # lower-triangular causal mask, sized with a margin above max_seqlen
        self.register_buffer("bias", torch.tril(torch.ones(self.max_seqlen + 10, self.max_seqlen + 10))
                                          .view(1, 1, self.max_seqlen + 10, self.max_seqlen + 10))

    def forward(self, x):
        # x: [B, T, C], e.g. [1, 16, 512]
        B, T, C = x.shape
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)
        q = q.reshape(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # -> [B, n_head, T, C // n_head]
        k = k.reshape(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
        # apply the causal mask: each position attends only to itself and earlier positions
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        y = F.softmax(att, dim=-1) @ v                       # [B, n_head, T, C // n_head]
        y = y.transpose(1, 2).contiguous().reshape(B, T, C)  # merge heads back into [B, T, C]
        return self.Wo(y)


class Block(nn.Module):
    def __init__(self, n_head, dim, max_seqlen=64):
        super(Block, self).__init__()
        self.n_head = n_head
        self.dim = dim
        self.max_seqlen = max_seqlen
        self.ln_1 = nn.LayerNorm(self.dim)
        self.attn = CausalAttention(n_head=self.n_head, dim=self.dim, max_seqlen=self.max_seqlen)
        self.ln_2 = nn.LayerNorm(self.dim)
        self.mlp = nn.Sequential(
            nn.Linear(self.dim, 4 * self.dim),
            nn.GELU(),
            nn.Linear(4 * self.dim, self.dim)
        )

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, n_layer, out_dim, n_head, dim, max_seqlen=64):
        super(Transformer, self).__init__()
        self.n_head = n_head
        self.dim = dim
        self.max_seqlen = max_seqlen
        self.n_layer = n_layer
        self.out_dim = out_dim
        self.transformer = nn.ModuleDict(dict(
            wpe = nn.Embedding(self.max_seqlen, self.dim),
            h = nn.ModuleList([Block(n_head=self.n_head, dim=self.dim, max_seqlen=self.max_seqlen) for _ in range(self.n_layer)]),
            ln_f = nn.LayerNorm(self.dim)
        ))
        self.lm_head = nn.Linear(self.dim, self.out_dim)
        # count parameters; if this is too large, reduce n_layer
        n_params = sum(p.numel() for p in self.transformer.parameters())
        print("Transformer parameters: %.2fM" % (n_params / 1e6,))

    def forward(self, x):
        # x: already-embedded inputs of shape [B, T, C]
        B, T, C = x.shape
        device = x.device
        pos = torch.arange(0, T, dtype=torch.long, device=device).unsqueeze(0)  # [1, T]
        pos_emb = self.transformer.wpe(pos)  # learned positional embeddings
        x = x + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)


class RLTransformer(nn.Module):
    def __init__(self, n_layer, state_dim, out_dim, n_head, dim, max_seqlen=64):
        super(RLTransformer, self).__init__()
        self.expand_layer = nn.Linear(state_dim, dim)
        self.transformer = Transformer(n_layer=n_layer, out_dim=out_dim, n_head=n_head, dim=dim, max_seqlen=max_seqlen)

    def forward(self, x):
        # x: [step, state_dim]; add a batch dimension -> [1, step, state_dim]
        x = x.unsqueeze(0)
        # [batch, step, dim]
        x = self.expand_layer(x)
        x = self.transformer(x)
        # the output at the last step has attended to the entire preceding sequence
        x = x.squeeze(0)
        return x[-1, :]
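

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). The hyperparameters
    # below are illustrative assumptions chosen only to exercise the modules.
    model = RLTransformer(n_layer=2, state_dim=8, out_dim=4, n_head=4, dim=64, max_seqlen=64)
    # A trajectory of 16 states with 8 features each: RLTransformer expects an
    # unbatched [step, state_dim] input and returns the prediction for the last step.
    states = torch.randn(16, 8)
    out = model(states)
    print(out.shape)  # torch.Size([4])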