forked from bytedance/lightseq
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path ls_transformer_export.py
233 lines (201 loc) · 6.75 KB
/
ls_transformer_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""
Export LightSeq Transformer models to protobuf format.
Refer to the `examples/training/custom` directory for more training details.
"""
import argparse
import time
import numpy as np
import torch
from transformers import BertTokenizer
from proto.transformer_pb2 import Transformer
from lightseq.training import (
export_ls_config,
export_ls_embedding,
export_ls_encoder,
export_ls_decoder,
LSTransformer,
)
import lightseq.inference as lsi
def _extract_weight(state_dict):
encoder_state_dict = {}
decoder_state_dict = {}
for k in state_dict:
if k.startswith("encoder."):
encoder_state_dict[k] = state_dict[k]
if k.startswith("decoder."):
decoder_state_dict[k] = state_dict[k]
return encoder_state_dict, decoder_state_dict
def export_other_weights(ls_infer_model, state_dict, vocab_size):
    """Write the final encoder/decoder layer-norm params and a zero output bias.

    Args:
        ls_infer_model: the ``Transformer`` protobuf being filled; its
            ``src_embedding`` / ``trg_embedding`` fields are overwritten in place.
        state_dict: full checkpoint state dict; must contain the four
            ``{encoder,decoder}.layer_norm.{weight,bias}`` tensors.
        vocab_size: length of the all-zero ``trg_embedding.shared_bias``.

    Raises:
        KeyError: if any layer-norm key is missing (raised before any field
            of ``ls_infer_model`` is modified).
    """
    # Read everything first so a missing key leaves the proto untouched.
    enc_scale, enc_shift, dec_scale, dec_shift = (
        state_dict[key].flatten().tolist()
        for key in (
            "encoder.layer_norm.weight",
            "encoder.layer_norm.bias",
            "decoder.layer_norm.weight",
            "decoder.layer_norm.bias",
        )
    )
    # The trained model has no output-projection bias here, so a zero vector
    # is exported in its place.
    zero_bias = torch.zeros(vocab_size).tolist()

    src_emb = ls_infer_model.src_embedding
    trg_emb = ls_infer_model.trg_embedding
    src_emb.norm_scale[:] = enc_scale
    src_emb.norm_bias[:] = enc_shift
    trg_emb.norm_scale[:] = dec_scale
    trg_emb.norm_bias[:] = dec_shift
    trg_emb.shared_bias[:] = zero_bias
def export_pb(state_dict, pb_path, pad_id, start_id, end_id, config):
    """Convert a LightSeq training checkpoint into a ``Transformer`` protobuf file.

    Args:
        state_dict: checkpoint parameters; "encoder.*" and "decoder.*" keys
            are routed to the respective exporters.
        pb_path: output file path (opened with "wb", so it is overwritten).
        pad_id, start_id, end_id: special token ids written into the model config.
        config: object exposing ``max_seq_len``, ``hidden_size``,
            ``intermediate_size``, ``vocab_size``, ``nhead``,
            ``num_encoder_layer`` and ``num_decoder_layer``.
    """
    enc_sd, dec_sd = _extract_weight(state_dict)
    pb_model = Transformer()

    # Embeddings: encoder part first, then decoder part. The trailing flag is
    # True for the encoder call and False for the decoder call (presumably an
    # "is encoder" switch in lightseq.training — confirm there).
    for part_sd, flag in ((enc_sd, True), (dec_sd, False)):
        export_ls_embedding(
            pb_model, part_sd, config.max_seq_len, config.hidden_size, flag
        )

    export_ls_encoder(
        pb_model, enc_sd, config.hidden_size, config.intermediate_size
    )
    export_ls_decoder(
        pb_model,
        dec_sd,
        config.hidden_size,
        config.intermediate_size,
        config.num_decoder_layer,
    )
    # Final layer norms and the zero shared bias come from the full state dict.
    export_other_weights(pb_model, state_dict, config.vocab_size)
    # Beam size is fixed to 1 in the exported config.
    export_ls_config(
        pb_model,
        config.nhead,
        pad_id,
        start_id,
        end_id,
        config.num_encoder_layer,
        config.num_decoder_layer,
        beam_size=1,
    )

    with open(pb_path, "wb") as out_file:
        out_file.write(pb_model.SerializeToString())
def create_data(device="cuda:0"):
    """Build a tiny toy source/target batch with the ``bert-base-cased`` tokenizer.

    Args:
        device: torch device (string or ``torch.device``) that the token-id
            tensors are moved to. Defaults to ``"cuda:0"``, the device the
            script always used.

    Returns:
        A 13-tuple ``(tokenizer, src_text, src_tokens, trg_text, trg_tokens,
        target, pad_id, start_id, end_id, vocab_size, batch_size, src_seq_len,
        trg_seq_len)`` where:

        * ``src_tokens`` / ``trg_tokens`` are padded id tensors on ``device``.
        * ``target`` is the target ids shifted left by one token, and
          ``trg_tokens`` has its last column dropped to match.
        * ``trg_seq_len`` is the padded target length measured *before* that
          shift, so it is one more than ``trg_tokens.size(1)``.
    """
    device = torch.device(device)

    # create Hugging Face tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    vocab_size = tokenizer.vocab_size
    # Read the special-token ids directly rather than re-encoding the token
    # strings and taking element [0].
    pad_id = tokenizer.pad_token_id
    start_id = tokenizer.cls_token_id
    end_id = tokenizer.sep_token_id

    # source text to id
    src_text = [
        "What is the fastest library in the world?",
        "You are so pretty!",
        "What do you love me for?",
        "The sparrow outside the window hovering on the telephone pole.",
    ]
    src_tokens = tokenizer.batch_encode_plus(
        src_text, padding=True, return_tensors="pt"
    )
    src_tokens = src_tokens["input_ids"].to(device)
    batch_size, src_seq_len = src_tokens.size(0), src_tokens.size(1)

    # target text to id
    trg_text = [
        "I guess it must be LightSeq, because ByteDance is the fastest.",
        "Thanks very much and you are pretty too.",
        "Love your beauty, smart, virtuous and kind.",
        "You said all this is very summery.",
    ]
    trg_tokens = tokenizer.batch_encode_plus(
        trg_text, padding=True, return_tensors="pt"
    )
    trg_tokens = trg_tokens["input_ids"].to(device)
    # Measured before the shift below; kept that way for existing callers.
    trg_seq_len = trg_tokens.size(1)

    # left shift 1 token as the target output
    target = trg_tokens.clone()[:, 1:]
    trg_tokens = trg_tokens[:, :-1]

    return (
        tokenizer,
        src_text,
        src_tokens,
        trg_text,
        trg_tokens,
        target,
        pad_id,
        start_id,
        end_id,
        vocab_size,
        batch_size,
        src_seq_len,
        trg_seq_len,
    )
def create_config(vocab_size):
    """Return an ``LSTransformer`` config for a 6-encoder / 6-decoder fp16 model.

    Uses the "transformer-base" preset with a 512-token max sequence length and
    2048 max batch tokens; only ``vocab_size`` varies per call.
    """
    options = dict(
        model="transformer-base",
        max_batch_tokens=2048,
        max_seq_len=512,
        vocab_size=vocab_size,
        # NOTE(review): hard-coded 0; presumably matches the tokenizer's pad id
        # (bert-base-cased) — keep in sync with create_data's pad_id.
        padding_idx=0,
        num_encoder_layer=6,
        num_decoder_layer=6,
        fp16=True,
        local_rank=0,
    )
    return LSTransformer.get_config(**options)
def ls_predict(ls_infer_model, src_tokens):
    """Run ``ls_infer_model.infer`` once and measure its wall-clock time.

    The CUDA device is synchronized before starting and after finishing the
    timer, so queued GPU work is not misattributed to or left out of the call.

    Returns:
        ``(output, seconds)``: whatever ``infer`` returned and the elapsed
        time in seconds (``time.perf_counter`` based).
    """
    sync = torch.cuda.synchronize
    sync()
    began = time.perf_counter()
    output = ls_infer_model.infer(src_tokens)
    sync()
    elapsed = time.perf_counter() - began
    return output, elapsed
def parse_args():
    """Read this script's command-line options from ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with one attribute, ``model``: path of the
        LightSeq training checkpoint to export (``-m`` / ``--model``,
        default "checkpoint_best.pt").
    """
    cli = argparse.ArgumentParser(description="export LightSeq checkpoint", usage="")
    cli.add_argument(
        "--model",
        "-m",
        type=str,
        default="checkpoint_best.pt",
        help="path of LightSeq checkpoint",
    )
    return cli.parse_args()
if __name__ == "__main__":
    import os  # only the script entry point needs path handling

    # Parse the CLI first so `--help` or a bad flag exits before the tokenizer
    # download and GPU tensor allocation done by create_data().
    args = parse_args()
    (
        tokenizer,
        src_text,
        src_tokens,
        trg_text,
        trg_tokens,
        target,
        pad_id,
        start_id,
        end_id,
        vocab_size,
        batch_size,
        src_seq_len,
        trg_seq_len,
    ) = create_data()

    # "<checkpoint without its last extension>.pb". splitext also handles
    # extension-less names and dotted directories (e.g. "./ckpt/model"),
    # which splitting on every "." turned into the bare file name ".pb".
    model_name = os.path.splitext(args.model)[0]
    pb_path = f"{model_name}.pb"

    # NOTE(review): torch.load unpickles arbitrary objects — only export
    # checkpoints from a trusted source.
    with open(args.model, "rb") as fin:
        state_dict = torch.load(fin, map_location=torch.device("cpu"))
    config = create_config(vocab_size)
    export_pb(state_dict, pb_path, pad_id, start_id, end_id, config)

    # NOTE(review): the second argument (8) is presumably the engine's max
    # batch size; the toy batch has 4 sentences — confirm against lightseq.inference.
    ls_infer_model = lsi.Transformer(pb_path, 8)
    src_tokens_np = np.array(src_tokens.cpu())

    print("========================WARM UP========================")
    ls_predict(ls_infer_model, src_tokens_np)
    print("========================LIGHTSEQ TEST========================")
    ls_output, ls_time = ls_predict(ls_infer_model, src_tokens_np)
    # Keep only element [0] of each sentence's results (presumably the best
    # beam) before decoding.
    ls_output = [ids[0] for ids in ls_output[0]]
    ls_predict_text = tokenizer.batch_decode(ls_output, skip_special_tokens=True)

    print(">>>>> source text")
    print("\n".join(src_text))
    print(">>>>> target text")
    print("\n".join(trg_text))
    print(">>>>> lightseq (infer) predict text")
    print("\n".join(ls_predict_text))
    print("lightseq (infer) predict time: {}ms".format(ls_time * 1000))