Commit 8bd6475

delete test file
1 parent 832141a commit 8bd6475

File tree: 9 files changed, +85 / -1137 lines

9 files changed

+85
-1137
lines changed

example/layers/attention.py

Lines changed: 48 additions & 20 deletions
@@ -1,8 +1,14 @@
 from typing import Optional
 import torch
 import bmtrain as bmt
-from bmtrain.nn import Linear
+from bmtrain.nn import (
+    Linear,
+    ColumnParallelLinear,
+    RowParallelLinear,
+)
 import math
+from bmtrain.global_var import config
+from bmtrain.distributed import all_gather
 
 class Attention(bmt.DistributedModule):
     def __init__(self,
@@ -12,11 +18,17 @@ def __init__(self,
     ) -> None:
         super().__init__()
 
-        self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-        self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-        self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+        if config['tp_size'] > 1:
+            self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
+            self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
+            self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
+            self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
+        else:
+            self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+            self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+            self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+            self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
 
-        self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
 
         self.softmax = torch.nn.Softmax(dim=-1)
         self.num_heads = num_heads
@@ -32,32 +44,48 @@ def forward(self,
         batch_size, seq_q, dim_model = hidden_q.size()
         seq_kv = hidden_kv.size(1)
 
-        h_q : torch.Tensor = self.project_q(hidden_q)
-        h_k : torch.Tensor = self.project_k(hidden_kv)
-        h_v : torch.Tensor = self.project_v(hidden_kv)
+        assert hidden_q.data_ptr() == hidden_kv.data_ptr()
 
-        h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head)
-        h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head)
-        h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head)
+        hidden_q = bmt.nn.OpParallelLinear.apply(
+            hidden_q,
+            torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
+            torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0),
+            True, False,
+            False, None
+        )
+
+        h_q, h_k, h_v = hidden_q.chunk(3, dim=-1)
+
+        if config['tp_size'] > 1:
+            #batch_size will changed in TensorParallel
+            batch_size = h_v.shape[0]
+
+        h_q = h_q.view(batch_size, seq_q, -1, self.dim_head)
+        h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head)
+        h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head)
 
         h_q = h_q.permute(0, 2, 1, 3).contiguous()
         h_k = h_k.permute(0, 2, 1, 3).contiguous()
         h_v = h_v.permute(0, 2, 1, 3).contiguous()
 
-        h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head)
-        h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head)
-        h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head)
+        h_q = h_q.view(-1, seq_q, self.dim_head)
+        h_k = h_k.view(-1, seq_kv, self.dim_head)
+        h_v = h_v.view(-1, seq_kv, self.dim_head)
 
         score = torch.bmm(
             h_q, h_k.transpose(1, 2)
         )
         score = score / math.sqrt(self.dim_head)
 
-        score = score.view(batch_size, self.num_heads, seq_q, seq_kv)
+        score = score.view(batch_size, -1, seq_q, seq_kv)
 
         if position_bias is not None:
-            score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv)
-
+            score = score + position_bias.view(batch_size, -1, seq_q, seq_kv)
+
+        if config['tp_size'] > 1:
+            with torch.no_grad():
+                mask = all_gather(mask, config['tp_comm']).flatten(0,1)
+
         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             score,
@@ -70,14 +98,14 @@ def forward(self,
             torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
         )
 
-        score = score.view(batch_size * self.num_heads, seq_q, seq_kv)
+        score = score.view(-1, seq_q, seq_kv)
 
         h_out = torch.bmm(
             score, h_v
         )
-        h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head)
+        h_out = h_out.view(batch_size, -1, seq_q, self.dim_head)
         h_out = h_out.permute(0, 2, 1, 3).contiguous()
-        h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head)
+        h_out = h_out.view(batch_size, seq_q, -1)
 
         attn_out = self.project_out(h_out)
         return attn_out
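The change above fuses the Q/K/V projections into a single matmul over concatenated weights and biases, then splits the result back into three tensors. A minimal single-process sketch of that fusion in plain PyTorch (illustrative only, not the bmt.nn.OpParallelLinear kernel; all shapes and names below are made up for the example):

import torch
import torch.nn.functional as F

dim_model, dim_head, num_heads = 16, 4, 4
proj_dim = dim_head * num_heads
x = torch.randn(2, 8, dim_model)                      # (batch, seq, dim_model)

w_q, w_k, w_v = (torch.randn(proj_dim, dim_model) for _ in range(3))
b_q, b_k, b_v = (torch.randn(proj_dim) for _ in range(3))

# One projection over the concatenated weight, mirroring the fused call in the diff.
fused = F.linear(x,
                 torch.cat([w_q, w_k, w_v], dim=0),   # (3 * proj_dim, dim_model)
                 torch.cat([b_q, b_k, b_v], dim=0))
h_q, h_k, h_v = fused.chunk(3, dim=-1)

# Equivalent to running the three projections separately.
assert torch.allclose(h_q, F.linear(x, w_q, b_q), atol=1e-5)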

example/layers/embedding.py

Lines changed: 5 additions & 3 deletions
@@ -77,11 +77,13 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
 
     def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor:
         if not projection:
-            return F.embedding(
+            out = F.embedding(
                 input, self.weight, self.padding_idx, self.max_norm,
                 self.norm_type, self.scale_grad_by_freq, self.sparse)
+            return out
         else:
-            return F.linear(input, self.weight) / math.sqrt(self.embedding_dim)
+            out = F.linear(input, self.weight)
+            return out
 
     def extra_repr(self) -> str:
         s = '{num_embeddings}, {embedding_dim}'
@@ -97,4 +99,4 @@ def extra_repr(self) -> str:
             s += ', sparse=True'
         return s.format(**self.__dict__)
 
-
+
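In the projection branch this embedding doubles as the output head: the same weight does the token lookup and, through F.linear, the projection back onto the vocabulary (the commit also drops the 1/sqrt(embedding_dim) scaling). A short sketch of that weight-tying pattern (illustrative only; the sizes are made up):

import torch
import torch.nn.functional as F

vocab_size, dim_model = 100, 16
weight = torch.randn(vocab_size, dim_model)    # one shared embedding matrix

tokens = torch.randint(0, vocab_size, (2, 8))  # (batch, seq)
hidden = F.embedding(tokens, weight)           # lookup: (batch, seq, dim_model)
logits = F.linear(hidden, weight)              # tied output head: (batch, seq, vocab_size)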

example/layers/feedforward.py

Lines changed: 11 additions & 4 deletions
@@ -1,16 +1,23 @@
 import torch
 import bmtrain as bmt
-from bmtrain.nn import Linear
+from bmtrain.nn import (
+    Linear,
+    ColumnParallelLinear,
+    RowParallelLinear)
+from bmtrain.global_var import config
 
 class Feedforward(bmt.DistributedModule):
     def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None:
         super().__init__()
 
-        self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype)
-        self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype)
+        if config['tp_size'] > 1:
+            self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype)
+            self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype)
+        else:
+            self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype)
+            self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype)
 
         self.relu = torch.nn.ReLU()
 
     def forward(self, input : torch.Tensor) -> torch.Tensor:
-
         return self.w_out(self.relu(self.w_in(input)))
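With tp_size > 1 the feedforward pairs a column-parallel w_in (output features sharded across ranks) with a row-parallel w_out (input features sharded, partial outputs summed), so the block needs only one reduction. A single-process sketch of why that split reproduces the unsharded result (plain PyTorch, illustrative only; not the BMTrain kernels):

import torch

tp_size = 2                                   # pretend tensor-parallel world size
dim_model, dim_ff = 8, 16
x = torch.randn(4, dim_model)

w_in = torch.randn(dim_ff, dim_model)         # full column-parallel weight
w_out = torch.randn(dim_model, dim_ff)        # full row-parallel weight

# Column parallel: shard w_in along its output dim; each "rank" keeps a slice of
# the hidden state. The elementwise ReLU can be applied per shard because the
# column split never mixes hidden units.
h_shards = [torch.relu(x @ w_in.chunk(tp_size, dim=0)[r].t()) for r in range(tp_size)]

# Row parallel: shard w_out along its input dim; summing the partial outputs
# plays the role of the all-reduce over the tensor-parallel group.
y = sum(h @ w_out.chunk(tp_size, dim=1)[r].t() for r, h in enumerate(h_shards))

assert torch.allclose(y, torch.relu(x @ w_in.t()) @ w_out.t(), atol=1e-4)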
