# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from transformers import AutoConfig, AutoTokenizer
-from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model
+
import torch
+from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model
+from transformers import AutoConfig, AutoTokenizer
+
model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}

+
class Model:
    def __init__(self):
        self.module = None
@@ -28,55 +31,68 @@ def __init__(self):
        self.bin_file = None
        self.generate_round = 0

-    def __import_package(self, model_name):
+    def __import_package(self, model_type):
        if self.module:
            return
-        if model_name == "gptj":
+        if model_type == "gptj":
            import intel_extension_for_transformers.llm.runtime.graph.gptj_cpp as cpp_model
-        elif model_name == "falcon":
+        elif model_type == "falcon":
            import intel_extension_for_transformers.llm.runtime.graph.falcon_cpp as cpp_model
-        elif model_name == "gptneox":
+        elif model_type == "gptneox":
            import intel_extension_for_transformers.llm.runtime.graph.gptneox_cpp as cpp_model
-        elif model_name == "dolly":
+        elif model_type == "dolly":
            import intel_extension_for_transformers.llm.runtime.graph.dolly_cpp as cpp_model
-        elif model_name == "llama" or model_name == "llama2":
+        elif model_type == "llama" or model_type == "llama2":
            import intel_extension_for_transformers.llm.runtime.graph.llama_cpp as cpp_model
-        elif model_name == "mpt":
+        elif model_type == "mpt":
            import intel_extension_for_transformers.llm.runtime.graph.mpt_cpp as cpp_model
-        elif model_name == "gpt_bigcode" or model_name == "starcoder":
+        elif model_type == "gpt_bigcode" or model_type == "starcoder":
            import intel_extension_for_transformers.llm.runtime.graph.starcoder_cpp as cpp_model
-        elif model_name == "opt":
+        elif model_type == "opt":
            import intel_extension_for_transformers.llm.runtime.graph.opt_cpp as cpp_model
-        elif model_name == "bloom":
+        elif model_type == "bloom":
            import intel_extension_for_transformers.llm.runtime.graph.bloom_cpp as cpp_model
-        elif model_name == "chatglm":
+        elif model_type == "chatglm":
            import intel_extension_for_transformers.llm.runtime.graph.chatglm_cpp as cpp_model
-        elif model_name == "chatglm2":
+        elif model_type == "chatglm2":
            import intel_extension_for_transformers.llm.runtime.graph.chatglm2_cpp as cpp_model
-        elif model_name == "baichuan":
+        elif model_type == "baichuan":
            import intel_extension_for_transformers.llm.runtime.graph.baichuan_cpp as cpp_model
-        elif model_name == "polyglot":
+        elif model_type == "polyglot":
            import intel_extension_for_transformers.llm.runtime.graph.polyglot_cpp as cpp_model
-        elif model_name == "mistral":
+        elif model_type == "mistral":
            import intel_extension_for_transformers.llm.runtime.graph.mistral_cpp as cpp_model
        else:
-            raise TypeError("Unspported model type {}!".format(model_name))
+            raise TypeError("Unsupported model type {}!".format(model_type))
        self.module = cpp_model

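+    # Resolve the runtime model type from a Hugging Face config: map config.model_type
+    # through model_maps, and detect chatglm2 (which reports "chatglm") by its name.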
+    @staticmethod
+    def get_model_type(model_config):
+        model_type = model_maps.get(model_config.model_type, model_config.model_type)
+        if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
+            model_type = "chatglm2"
+        return model_type
+
    def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-        model_type = model_maps.get(self.config.model_type, self.config.model_type)
-        if model_type == "chatglm" and "chatglm2" in self.config._name_or_path:
-            model_type = "chatglm2"
+        model_type = Model.get_model_type(self.config)
        self.__import_package(model_type)

        # check cache and quantization
        output_path = "runtime_outs"
-        if not os.path.exists(output_path):
-            os.makedirs(output_path)
+        os.makedirs(output_path, exist_ok=True)
        fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type)
-        quant_bin = "{}/ne_{}_q.bin".format(output_path, model_type)
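+        # Build a descriptive suffix from the quantization settings so cached files
+        # produced with different settings get distinct names.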
+        quant_desc = quant_kwargs['weight_dtype']
+        if quant_kwargs['use_ggml']:
+            quant_desc += "_ggml"
+        else:
+            quant_desc += "_jblas_c" + quant_kwargs['compute_dtype']
+        if quant_kwargs['group_size'] == -1:
+            quant_desc += "_pc"
+        else:
+            quant_desc += "_g{}".format(quant_kwargs['group_size'])
+        quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)
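+        # With weight_dtype "int4", compute_dtype "int8", group_size 32 (illustrative values),
+        # this becomes runtime_outs/ne_llama_q_int4_jblas_cint8_g32.bin for a llama model.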

        if not_quant:
            self.bin_file = fp32_bin
@@ -85,20 +101,22 @@ def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
        if use_cache and os.path.exists(self.bin_file):
            return

-        convert_model(model_name, fp32_bin, "f32")
-        assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
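+        # Skip the FP32 conversion only when caching is enabled and a converted file already exists.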
+        if not use_cache or not os.path.exists(fp32_bin):
+            convert_model(model_name, fp32_bin, "f32")
+            assert os.path.exists(fp32_bin), "Fail to convert pytorch model"

        if not_quant:
            print("FP32 model will be used.")
            return
-        self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, **quant_kwargs)
+        self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, **quant_kwargs)
        assert os.path.exists(quant_bin), "Fail to quantize model"
-
+
        # clean
-        os.remove(fp32_bin)
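+        # Keep the FP32 intermediate only when use_cache is set, so later runs can reuse it.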
+        if not use_cache:
+            os.remove(fp32_bin)

-    def init_from_bin(self, model_name, model_path, **generate_kwargs):
-        self.__import_package(model_name)
+    def init_from_bin(self, model_type, model_path, **generate_kwargs):
+        self.__import_package(model_type)
        self.model = self.module.Model()
        if "threads" not in generate_kwargs:
            threads = os.getenv("OMP_NUM_THREADS")
@@ -108,11 +126,9 @@ def init_from_bin(self, model_name, model_path, **generate_kwargs):
                generate_kwargs["threads"] = int(threads)
        self.model.init_model(model_path, **generate_kwargs)

-    def quant_model(self, model_name, model_path, out_path, **quant_kwargs):
-        self.__import_package(model_name)
-        self.module.Model.quant_model(model_path=model_path,
-                                      out_path=out_path, **quant_kwargs)
-
+    def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
+        self.__import_package(model_type)
+        self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs)

    def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
        max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
@@ -129,8 +145,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
            ret = input_ids.tolist()

        beam_search = False
-        if ("num_beams" in generate_kwargs and generate_kwargs["num_beams"] > 1) and not \
-                generate_kwargs.get("do_sample", False):
+        if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False):
            beam_search = True
        if not beam_search:
            # TODO support multi batch
@@ -142,30 +157,43 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
                Make sure that `num_beams` is set to 1."
            if self.generate_round == 0 and not ignore_prompt:
                streamer.put(input_ids)
-
+
        if interactive:
            self.model.reset_token_end()
        out_count = 0
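+        # The full prompt is consumed on the first iteration; later iterations continue from the previous output.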
+        input_list = input_ids.tolist()
        while True:
-            response = self.model.generate(input_ids=input_ids.tolist())
+            response = self.model.generate(input_ids=input_list)
+            input_list = []  # next-token stage will use previous output
            if len(response) == 0:
                break
            if streamer:
                streamer.put(torch.tensor([response[0]]))
            for i in range(len(response)):
                ret[i].extend(response[i])
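+            # Beam search returns the complete generated sequences in a single call, so stop after one pass.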
+            if beam_search:
+                break
            if stopping_criteria is not None:
                if stopping_criteria(torch.tensor(ret), None):
                    break
            elif ret[0][-1] == self.tokenizer.eos_token_id or \
-                    (max_new_tokens != -1 and out_count > max_new_tokens):
+                    (max_new_tokens != -1 and out_count > max_new_tokens):
                break
            out_count += 1
        if streamer:
            streamer.end()
-
+
        self.generate_round += 1
        return ret

    def is_token_end(self):
        return self.model.is_token_end()
+
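+    # Callable interface: lazily load the quantized binary on first use; reinit=True
+    # reinitializes the backend model and restarts the generation round counter.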
+    def __call__(self, input_ids, reinit=False, **kwargs):
+        if self.model is None:
+            self.init_from_bin(self.model_type, self.bin_file, **kwargs)
+            self.generate_round = 0
+        elif reinit:
+            self.model.reinit()
+            self.generate_round = 0
+        return self.model.evaluate(input_ids.tolist())
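For orientation, a rough usage sketch of the API touched by this diff; the checkpoint name and quantization kwargs are illustrative assumptions, not values taken from the patch:

```python
from transformers import AutoTokenizer

from intel_extension_for_transformers.llm.runtime.graph import Model

model_name = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = Model()
# Convert and quantize once; with use_cache=True the files under runtime_outs/ are reused.
model.init(model_name, use_cache=True,
           weight_dtype="int4", compute_dtype="int8", group_size=32, use_ggml=False)
model.init_from_bin(Model.get_model_type(model.config), model.bin_file)

input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0]))
```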