Skip to content

Commit 662494f

Browse files
authored
Add files via upload
1 parent 8d64bcd commit 662494f

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

examples/test_simroformer.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
jax torch tensorflow测试通过
4+
理论上roformerV2也能用,但我没roformerV2的unilm模型,所以没测
5+
"""
6+
import os

# The backend must be selected via the environment variable BEFORE keras /
# bert4keras3 are imported below — presumably the backend is fixed at import
# time (TODO confirm against bert4keras3.backend).
os.environ["KERAS_BACKEND"] = "jax"

from bert4keras3.backend import keras, K
from bert4keras3.models import *
from bert4keras3.tokenizers import Tokenizer
from bert4keras3.snippets import sequence_padding
import numpy as np
14+
15+
# Local checkpoint directory of chinese_roformer-sim (L-12, H-768, A-12).
# Raw string: the original non-raw literal contained invalid escape
# sequences ('\e', '\c') in the Windows path — same bytes, no SyntaxWarning.
base_path = r'D:\ea下载\chinese_roformer-sim-char_L-12_H-768_A-12/'
config_path = base_path + 'bert_config.json'
checkpoint_path = base_path + 'bert_model.ckpt'
dict_path = base_path + 'vocab.txt'

# Build the tokenizer from the model vocabulary.
tokenizer = Tokenizer(dict_path, do_lower_case=True)
end_token = tokenizer._token_end_id

# Load RoFormer with a UniLM attention pattern and an MLM head.
# return_keras_model=False keeps the bert4keras3 wrapper object, whose
# build_cache_model() is used further down in this script.
self = bert = build_transformer_model(
    config_path,
    checkpoint_path,
    model='roformer',
    application='unilm',
    with_mlm=True,
    return_keras_model=False,
)
31+
# Greedy search with the plain (no-cache) model: feed the whole growing
# sequence back in every step and append the argmax token until the end
# token is produced.
tokens, segments = tokenizer.encode('广东省的省会是广州')
l = len(tokens)  # prompt length; generated tokens start at index l
# UniLM style: append a start token and mark the generated part with
# segment id 1.
tokens = tokens + [tokenizer._token_start_id]
segments = segments + [1]
# search
while tokens[-1] != end_token:
    inputs = [np.expand_dims(tokens, 0), np.expand_dims(segments, 0)]
    pred = bert.model.predict(inputs, verbose=3)
    # Greedy decoding: argmax over the vocabulary at the last position.
    pred = pred.argmax(-1)[0][-1]
    tokens.append(pred)
    segments.append(1)
# Show the result: decode only the generated continuation (after the
# prompt). The unused alias `s2 = segments` from the original was removed.
outs = tokens
print(tokenizer.decode(outs[l:]))
47+
48+
49+
# Greedy search again, this time through the cached incremental-decoding model.
max_len = 32
# segments and tokens share the same maximum length.
input_lengths = [max_len, max_len]

# Build padded inputs: prompt plus a start token, zero-padded out to
# max_len; the region to be generated carries segment id 1.
tokens, segments = tokenizer.encode('广东省的省会是广州')
n_pad = max_len - len(tokens) - 1
tokens = np.expand_dims(tokens + [tokenizer._token_start_id] + [0] * n_pad, 0)
segments = np.expand_dims(segments + [1] * (max_len - len(segments)), 0)
inputs = [tokens, segments]

# Build the cache model, run it, and print the decoded continuation.
cache_model = bert.build_cache_model(
    input_lengths, end_token=end_token, progress_print=True
)
o1 = cache_model.predict([tokens, segments])
print(tokenizer.decode(o1[0][l:]))

0 commit comments

Comments
 (0)