
Commit 832141a

example back to origin
1 parent 5819ce4 commit 832141a

File tree: 7 files changed (+81 / -140 lines)


example/layers/attention.py

Lines changed: 21 additions & 52 deletions
@@ -1,14 +1,8 @@
 from typing import Optional
 import torch
 import bmtrain as bmt
-from bmtrain.nn import (
-    Linear,
-    ColumnParallelLinear,
-    RowParallelLinear,
-)
+from bmtrain.nn import Linear
 import math
-from bmtrain.global_var import config
-from bmtrain.distributed import all_gather

 class Attention(bmt.DistributedModule):
     def __init__(self,
@@ -18,21 +12,14 @@ def __init__(self,
         ) -> None:
         super().__init__()

-        if config['tp_size'] > 1:
-            self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
-            self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
-            self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
-            self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
-        else:
-            self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-            self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-            self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
-            self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
+        self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+        self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+        self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)

+        self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)

         self.softmax = torch.nn.Softmax(dim=-1)
         self.num_heads = num_heads
-        self.num_kv_heads = num_heads
         self.dim_head = dim_head
         self.dim_model = dim_model

@@ -45,50 +32,32 @@ def forward(self,
         batch_size, seq_q, dim_model = hidden_q.size()
         seq_kv = hidden_kv.size(1)

-        if isinstance(self.project_q, ColumnParallelLinear):
-            assert hidden_q.data_ptr() == hidden_kv.data_ptr()
-            hidden_q = bmt.nn.OpParallelLinear.apply(
-                hidden_q,
-                torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
-                torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0) if self.project_q.bias is not None else None,
-                True, False,
-                False, None
-            )
-            h_q, h_k, h_v = hidden_q.chunk(3, dim=-1)
-        else:
-            h_q : torch.Tensor = self.project_q(hidden_q)
-            h_k : torch.Tensor = self.project_k(hidden_q)
-            h_v : torch.Tensor = self.project_v(hidden_q)
-        if config['tp_size'] > 1:
-            #batch_size will changed in TensorParallel
-            batch_size = h_v.shape[0]
-
-        h_q = h_q.view(batch_size, seq_q, -1, self.dim_head)
-        h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head)
-        h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head)
+        h_q : torch.Tensor = self.project_q(hidden_q)
+        h_k : torch.Tensor = self.project_k(hidden_kv)
+        h_v : torch.Tensor = self.project_v(hidden_kv)
+
+        h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head)
+        h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head)
+        h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head)

         h_q = h_q.permute(0, 2, 1, 3).contiguous()
         h_k = h_k.permute(0, 2, 1, 3).contiguous()
         h_v = h_v.permute(0, 2, 1, 3).contiguous()

-        h_q = h_q.view(-1, seq_q, self.dim_head)
-        h_k = h_k.view(-1, seq_kv, self.dim_head)
-        h_v = h_v.view(-1, seq_kv, self.dim_head)
+        h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head)
+        h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head)
+        h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head)

         score = torch.bmm(
             h_q, h_k.transpose(1, 2)
         )
         score = score / math.sqrt(self.dim_head)

-        score = score.view(batch_size, -1, seq_q, seq_kv)
+        score = score.view(batch_size, self.num_heads, seq_q, seq_kv)

         if position_bias is not None:
-            score = score + position_bias.view(batch_size, -1, seq_q, seq_kv)
-
-        if config['tp_size'] > 1:
-            with torch.no_grad():
-                mask = all_gather(mask, config['tp_comm']).flatten(0,1)
-
+            score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv)
+
         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             score,
@@ -101,14 +70,14 @@ def forward(self,
             torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
         )

-        score = score.view(-1, seq_q, seq_kv)
+        score = score.view(batch_size * self.num_heads, seq_q, seq_kv)

         h_out = torch.bmm(
             score, h_v
         )
-        h_out = h_out.view(batch_size, -1, seq_q, self.dim_head)
+        h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head)
         h_out = h_out.permute(0, 2, 1, 3).contiguous()
-        h_out = h_out.view(batch_size, seq_q, -1)
+        h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head)

         attn_out = self.project_out(h_out)
         return attn_out
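
The restored forward pass spells out every reshape in terms of batch_size and num_heads instead of -1. Not part of the commit, but as a sanity check, the same reshape/bmm shape flow can be reproduced with plain torch and made-up sizes:

# Standalone sketch (illustrative sizes, no bmtrain): the multi-head attention
# shape flow that attention.py implements around its Linear projections.
import math
import torch

batch_size, num_heads, dim_head, seq_q, seq_kv = 2, 4, 8, 5, 7

h_q = torch.randn(batch_size, seq_q, num_heads * dim_head)
h_k = torch.randn(batch_size, seq_kv, num_heads * dim_head)
h_v = torch.randn(batch_size, seq_kv, num_heads * dim_head)

def split_heads(x, seq_len):
    # (batch, seq, heads * dim_head) -> (batch * heads, seq, dim_head)
    return (x.view(batch_size, seq_len, num_heads, dim_head)
             .permute(0, 2, 1, 3)
             .contiguous()
             .view(batch_size * num_heads, seq_len, dim_head))

q, k, v = split_heads(h_q, seq_q), split_heads(h_k, seq_kv), split_heads(h_v, seq_kv)

score = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(dim_head)  # (batch*heads, seq_q, seq_kv)
score = torch.softmax(score, dim=-1)
h_out = torch.bmm(score, v)                                    # (batch*heads, seq_q, dim_head)

# merge heads back to (batch, seq_q, heads * dim_head), ready for project_out
h_out = (h_out.view(batch_size, num_heads, seq_q, dim_head)
              .permute(0, 2, 1, 3)
              .contiguous()
              .view(batch_size, seq_q, num_heads * dim_head))
assert h_out.shape == (batch_size, seq_q, num_heads * dim_head)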

example/layers/embedding.py

Lines changed: 3 additions & 5 deletions
@@ -77,13 +77,11 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,

     def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor:
         if not projection:
-            out = F.embedding(
+            return F.embedding(
                 input, self.weight, self.padding_idx, self.max_norm,
                 self.norm_type, self.scale_grad_by_freq, self.sparse)
-            return out
         else:
-            out = F.linear(input, self.weight)
-            return out
+            return F.linear(input, self.weight) / math.sqrt(self.embedding_dim)

     def extra_repr(self) -> str:
         s = '{num_embeddings}, {embedding_dim}'
@@ -99,4 +97,4 @@ def extra_repr(self) -> str:
             s += ', sparse=True'
         return s.format(**self.__dict__)

-
+
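
The restored projection branch reuses the embedding matrix as the output head, scaled by 1/sqrt(embedding_dim). A minimal standalone illustration of that weight tying (made-up sizes, plain torch, not part of the commit):

import math
import torch
import torch.nn.functional as F

vocab_size, embedding_dim = 100, 16
weight = torch.randn(vocab_size, embedding_dim)   # one matrix, used twice

ids = torch.tensor([[3, 7, 42]])
hidden = F.embedding(ids, weight)                              # (1, 3, embedding_dim)
logits = F.linear(hidden, weight) / math.sqrt(embedding_dim)   # (1, 3, vocab_size)
assert logits.shape == (1, 3, vocab_size)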

example/layers/feedforward.py

Lines changed: 4 additions & 11 deletions
@@ -1,23 +1,16 @@
 import torch
 import bmtrain as bmt
-from bmtrain.nn import (
-    Linear,
-    ColumnParallelLinear,
-    RowParallelLinear)
-from bmtrain.global_var import config
+from bmtrain.nn import Linear

 class Feedforward(bmt.DistributedModule):
     def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None:
         super().__init__()

-        if config['tp_size'] > 1:
-            self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype)
-            self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype)
-        else:
-            self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype)
-            self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype)
+        self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype)
+        self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype)

         self.relu = torch.nn.ReLU()

     def forward(self, input : torch.Tensor) -> torch.Tensor:
+
         return self.w_out(self.relu(self.w_in(input)))
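
The restored Feedforward is a plain two-layer MLP with a ReLU in between; bmtrain's Linear keeps the same call signature while managing the weights as distributed parameters. The computation itself is equivalent to this stock-torch sketch (illustrative sizes, not the example's real ones):

import torch

dim_model, dim_ff = 16, 64
ffn = torch.nn.Sequential(
    torch.nn.Linear(dim_model, dim_ff),
    torch.nn.ReLU(),
    torch.nn.Linear(dim_ff, dim_model),
)
x = torch.randn(2, 8, dim_model)
assert ffn(x).shape == x.shape   # w_out(relu(w_in(x))) preserves dim_model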

example/layers/transformer.py

Lines changed: 3 additions & 3 deletions
@@ -20,10 +20,10 @@ def forward(self,
             hidden : torch.Tensor,                          # (batch, seq_len, dim_model)
             mask : torch.BoolTensor,                        # (batch, seq_len, dim_model)
             position_bias : Optional[torch.Tensor] = None,  # (batch, num_head, seq_len, seq_len)
-        ):
-        # bmt.inspect.record_tensor(hidden, "hidden")
+        ):
+        bmt.inspect.record_tensor(hidden, "hidden")
         x = self.ln_attn(hidden)
-        x = self.attn(x, x, mask)
+        x = self.attn(x, x, mask, position_bias)
         hidden = hidden + x

         x = self.ln_ff(hidden)
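
Two behavioural changes here: position_bias is now forwarded to the attention layer instead of being silently dropped, and record_tensor(hidden, "hidden") is re-enabled so the block input shows up in the inspector summaries that train.py prints every 100 iterations.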

example/models/gpt.py

Lines changed: 7 additions & 20 deletions
@@ -1,38 +1,28 @@
 import torch
 import bmtrain as bmt
 from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder
-from bmtrain.global_var import config

 class GPT(bmt.DistributedModule):
     def __init__(self,
             num_layers : int, vocab_size : int,
             dim_model : int, dim_head : int, num_heads : int, dim_ff : int,
             max_distance : int,
-            bias : bool = True, dtype = None, offload = False, offload_level = 0
+            bias : bool = True, dtype = None
         ) -> None:
         super().__init__()

         self.max_distance = max_distance

-        if config['tp_size'] > 1:
-            self.word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype)
-        else:
-            self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype)
+        self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype)
         self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype)
-        if offload:
-            offload_mask = [True if i%4 == 0 else False for i in range(num_layers)]
-            ckpt_mask = [not offload_mask[i] for i in range(num_layers)]
-            offload_level = offload_level
-        else:
-            ckpt_mask = [ True for i in range(num_layers) ]
-            offload_mask = [ False for i in range(num_layers) ]
+
         self.transformers = bmt.TransformerBlockList([
             bmt.CheckpointBlock(
                 TransformerEncoder(
                     dim_model, dim_head, num_heads, dim_ff, bias, dtype
-                ),use_checkpoint=ckpt_mask[i],use_offload=offload_mask[i],offload_level=offload_level
+                )
             )
-            for i in range(num_layers)
+            for _ in range(num_layers)
         ])

         self.layernorm = Layernorm(dim_model, dtype=dtype)
@@ -52,10 +42,7 @@ def forward(self,
         out = self.transformers(out, mask_2d, None)
         out = self.layernorm(out)

-        if config['tp_size'] > 1:
-            logits = self.word_emb.projection(out)
-        else:
-            logits = self.word_emb(out, projection=True)
+        logits = self.word_emb(out, projection=True)
         bmt.inspect.record_tensor(logits, "logits")

-        return logits
+        return logits
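
With the offload masks removed, every TransformerEncoder is wrapped in a bmt.CheckpointBlock with its default settings, so per-layer checkpoint/offload selection is no longer configurable from the constructor. The tied word embedding now handles the output projection through projection=True, which is the branch restored in embedding.py above.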

example/run.sh

Lines changed: 3 additions & 1 deletion
@@ -1 +1,3 @@
-torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost $1
+export NCCL_P2P_DISABLE=1
+export CUDA_LAUNCH_BLOCKING=1
+torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost train.py
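
The launcher now hardcodes train.py instead of taking the target script as $1, so the example is started with plain "bash run.sh". NCCL_P2P_DISABLE=1 turns off NCCL's peer-to-peer transport (a common workaround for GPUs without working P2P), and CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so CUDA errors surface at the offending call, at some cost in speed.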

example/train.py

Lines changed: 40 additions & 48 deletions
@@ -3,32 +3,28 @@
 from models import GPT
 import time
 from bmtrain import optim
-from bmtrain.global_var import config
 from bmtrain import inspect

 def main():
     bmt.init_distributed(
         seed=0,
-        tp_size=2,
+        zero_level=2,
     )
-    offload = True
-    seq_len = 4096
-    offload_level = 0
+
     model = GPT(
-        num_layers=24,
-        vocab_size=80000,
-        dim_model=1024,
-        dim_head=64,
-        num_heads=16,
-        dim_ff=4096,
-        max_distance=seq_len,
-        bias=False,
-        dtype=torch.half,
-        offload=offload,
-        offload_level=offload_level
+        num_layers=8,
+        vocab_size=10240,
+        dim_model=2560,
+        dim_head=80,
+        num_heads=32,
+        dim_ff=8192,
+        max_distance=1024,
+        bias=True,
+        dtype=torch.half
     )

     bmt.init_parameters(model)
+    # print_inspect(model, "*")

     bmt.print_rank("Model memory")
     bmt.print_rank(torch.cuda.memory_summary())
@@ -37,7 +33,10 @@ def main():
     # data
     # generate dummy data for each rank
     torch.manual_seed(1234)
-    batch_size = 4
+
+    batch_size = 2
+    seq_len = 512
+
     for i in range(bmt.world_size()):
         sent = torch.randint(0, 10240, (batch_size, seq_len + 1))
         enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda()
@@ -53,11 +52,7 @@ def main():
         if i == bmt.rank():
             break

-    if config['tp_size'] > 1:
-        loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True)
-    else:
-        loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
-
+    loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
     optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2)
     lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)

@@ -69,43 +64,40 @@ def main():
     avg_time_recorder = bmt.utils.AverageRecorder()
     avg_loss_recorder = bmt.utils.AverageRecorder()

-    for iteration in range(30):
+    for iteration in range(1000):
         # load data
         st = time.time()

-        # with bmt.inspect.inspect_tensor() as inspector:
-        pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
-        logits = model(
-            enc_input,
-            pos,
-            pos < enc_length[:, None]
-        )
-        batch, seq_len, vocab_out_size = logits.size()
+        with inspect.inspect_tensor() as inspector:
+            pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
+            logits = model(
+                enc_input,
+                pos,
+                pos < enc_length[:, None]
+            )
+            batch, seq_len, vocab_out_size = logits.size()

-        if config['tp_size'] > 1:
-            loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets)
-        else:
-            loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
+            loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))

         global_loss = bmt.sum_loss(loss).item()

-        optim_manager.zero_grad()
+        optim_manager.zero_grad()

-        optim_manager.backward(loss)
+        optim_manager.backward(loss)

         # print inspected tensors in the forward & backward pass
         # print parameters of the model
-        # if iteration % 100 == 0:
-        #     bmt.print_rank(
-        #         bmt.inspect.format_summary(
-        #             inspector.get_summary()
-        #         )
-        #     )
-        #     bmt.print_rank(
-        #         bmt.inspect.format_summary(
-        #             bmt.inspect.inspect_model(model, "*")
-        #         )
-        #     )
+        if iteration % 100 == 0:
+            bmt.print_rank(
+                inspect.format_summary(
+                    inspector.get_summary()
+                )
+            )
+            bmt.print_rank(
+                inspect.format_summary(
+                    inspect.inspect_model(model, "*")
+                )
+            )

         optim_manager.step()
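
The dummy-batch construction this loop consumes is only partly visible in the hunks above. Not part of the commit, but for reference, a CPU-only sketch of the same pattern: next-token targets with padded positions set to -100 so that CrossEntropyLoss(ignore_index=-100) skips them, and logits/targets flattened the same way as the restored loss line. The torch.where masking is an assumption based on the ignore_index; that part of the file lies outside the hunks shown.

import torch

torch.manual_seed(1234)
batch_size, seq_len, vocab = 2, 512, 10240

sent = torch.randint(0, vocab, (batch_size, seq_len + 1))
enc_length = torch.randint(128, seq_len, (batch_size,)).long()

enc_input = sent[:, :-1].long()          # tokens fed to the model
targets = sent[:, 1:].long()             # the same tokens shifted by one
mask = torch.arange(seq_len)[None, :] < enc_length[:, None]
targets = torch.where(mask, targets, torch.full_like(targets, -100))

loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
logits = torch.randn(batch_size, seq_len, vocab)     # stand-in for the model output
loss = loss_func(logits.view(batch_size * seq_len, vocab), targets.view(-1))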