# -*- coding: utf-8 -*-
"""
Greedy-search demo for a RoFormer UniLM model using bert4keras3.

Tested with the jax, torch and tensorflow Keras backends.
In theory RoFormerV2 would also work, but no RoFormerV2 UniLM
checkpoint was available, so that was not tested.
"""
import os

# NOTE: the Keras backend must be selected *before* bert4keras3 (and
# thus Keras) is imported, otherwise the default backend is used.
os.environ["KERAS_BACKEND"] = "jax"

from bert4keras3.backend import keras, K
from bert4keras3.models import *
from bert4keras3.tokenizers import Tokenizer
from bert4keras3.snippets import sequence_padding
import numpy as np

# Raw string for the Windows path: the original non-raw literal only
# worked because '\e' and '\c' happen not to be escape sequences
# (invalid escapes are deprecated and will break in future Python).
base_path = r'D:\ea下载\chinese_roformer-sim-char_L-12_H-768_A-12/'
config_path = base_path + 'bert_config.json'
checkpoint_path = base_path + 'bert_model.ckpt'
dict_path = base_path + 'vocab.txt'
# Build the tokenizer from the checkpoint's vocabulary file.
tokenizer = Tokenizer(dict_path, do_lower_case=True)
end_token = tokenizer._token_end_id

# Load the pretrained RoFormer weights as a UniLM model with an MLM head.
# return_keras_model=False keeps the bert4keras3 wrapper object; its
# `.model` attribute holds the underlying Keras model.
self = bert = build_transformer_model(
    config_path,
    checkpoint_path,
    model='roformer',
    application='unilm',
    with_mlm=True,
    return_keras_model=False,
)
# Greedy search with the plain (cache-free) model.
tokens, segments = tokenizer.encode('广东省的省会是广州')
l = len(tokens)
# Append a start token (segment id 1) to kick off generation.
tokens = tokens + [tokenizer._token_start_id]
segments = segments + [1]

# One full forward pass per generated token, until the end token appears.
while tokens[-1] != end_token:
    batch = [np.expand_dims(tokens, 0), np.expand_dims(segments, 0)]
    logits = bert.model.predict(batch, verbose=3)
    next_id = logits.argmax(-1)[0][-1]
    tokens.append(next_id)
    segments.append(1)

# Show the generated continuation (everything after the prompt).
s2 = segments
outs = tokens
print(tokenizer.decode(outs[l:]))
# Greedy search with the cached (incremental-decoding) model.
max_len = 32
# tokens and segments share the same maximum length.
input_lengths = [max_len, max_len]

# Build padded inputs: prompt + start token, tokens zero-padded and
# segments one-padded out to max_len.
tokens, segments = tokenizer.encode('广东省的省会是广州')
tokens = np.expand_dims(
    tokens + [tokenizer._token_start_id] + [0] * (max_len - len(tokens) - 1),
    0,
)
segments = np.expand_dims(segments + [1] * (max_len - len(segments)), 0)
inputs = [tokens, segments]

# Build the cache model, then decode the whole sequence in a single call.
cache_model = bert.build_cache_model(
    input_lengths,
    end_token=end_token,
    progress_print=True,
)

# Show the generated continuation (everything after the prompt).
o1 = cache_model.predict([tokens, segments])
print(tokenizer.decode(o1[0][l:]))